From 29e74d6e71826da9a3fe3c5790fed1329fc4d1e8 Mon Sep 17 00:00:00 2001
From: Melan
Date: Thu, 20 Oct 2022 16:26:16 +0200
Subject: Add support for Tensorboard for training embeddings
---
modules/textual_inversion/textual_inversion.py | 31 +++++++++++++++++++++++++-
1 file changed, 30 insertions(+), 1 deletion(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..c57d3ace 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,9 +7,11 @@ import tqdm
import html
import datetime
import csv
+import numpy as np
+import torchvision.transforms
from PIL import Image, PngImagePlugin
-
+from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -199,6 +201,19 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer.add_scalar(tag=tag,
+ scalar_value=value, global_step=step)
+
+def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
+ if shared.opts.training_enable_tensorboard:
+ # Convert a pil image to a torch tensor
+ img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
+ img_tensor = img_tensor.permute((2, 0, 1))
+
+ tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
@@ -252,6 +267,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+ if shared.opts.training_enable_tensorboard:
+ os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+ tensorboard_writer = SummaryWriter(
+ log_dir=os.path.join(log_directory, "tensorboard"),
+ flush_secs=shared.opts.training_tensorboard_flush_every)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
@@ -270,6 +291,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
del x
losses[embedding.step % losses.shape[0]] = loss.item()
+
optimizer.zero_grad()
loss.backward()
@@ -285,6 +307,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step)
+ tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step)
+ tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step)
+ tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step)
+
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
@@ -349,6 +377,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
image.save(last_saved_image)
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
last_saved_image += f", prompt: {preview_text}"
--
cgit v1.2.3
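For reference, the PIL-to-tensor conversion in `tensorboard_add_image` can be exercised on its own. A minimal sketch, assuming an RGB image and a throwaway log directory (both are placeholders):

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/tensorboard")          # placeholder path
pil_image = Image.new("RGB", (64, 64), color=(255, 0, 0))   # stand-in preview image

# HWC uint8 array -> torch tensor; PIL size is (width, height), so the view
# is (height, width, channels), then permute to the CHW layout add_image expects.
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))

writer.add_image("Validation at epoch 0", img_tensor, global_step=0)
writer.close()
```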
From a6d593a6b51dc6a8443f2aa5c24caa391a04cd56 Mon Sep 17 00:00:00 2001
From: Melan
Date: Thu, 20 Oct 2022 19:43:21 +0200
Subject: Fixed a typo in a variable
---
modules/textual_inversion/textual_inversion.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c57d3ace..ec8176bf 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -260,11 +260,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_image = ""
embedding_yet_to_be_embedded = False
- ititial_step = embedding.step or 0
- if ititial_step > steps:
+ initial_step = embedding.step or 0
+ if initial_step > steps:
return embedding, filename
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
if shared.opts.training_enable_tensorboard:
@@ -273,9 +273,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
log_dir=os.path.join(log_directory, "tensorboard"),
flush_secs=shared.opts.training_tensorboard_flush_every)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
for i, entries in pbar:
- embedding.step = i + ititial_step
+ embedding.step = i + initial_step
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
--
cgit v1.2.3
From 8f5912984794c4c69e429c4636e984854d911b6a Mon Sep 17 00:00:00 2001
From: Melan
Date: Thu, 20 Oct 2022 22:37:16 +0200
Subject: Some changes to the tensorboard code and hypernetwork support
---
modules/hypernetworks/hypernetwork.py | 18 ++++++++++-
modules/textual_inversion/textual_inversion.py | 45 +++++++++++++++-----------
2 files changed, 44 insertions(+), 19 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..5e919775 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -4,6 +4,7 @@ import html
import os
import sys
import traceback
+import tensorboard
import tqdm
import csv
@@ -18,7 +19,6 @@ import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
@@ -291,6 +291,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -315,6 +318,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.zero_grad()
loss.backward()
optimizer.step()
+
mean_loss = losses.mean()
if torch.isnan(mean_loss):
raise RuntimeError("Loss diverged.")
@@ -323,6 +327,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
+
+ if shared.opts.training_enable_tensorboard:
+ epoch_num = hypernetwork.step // len(ds)
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss,
+ global_step=hypernetwork.step, step=epoch_step,
+ learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{mean_loss:.7f}",
@@ -360,6 +372,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}",
+ image, hypernetwork.step)
+
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ec8176bf..b1dc2596 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -201,19 +201,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def tensorboard_setup(log_directory):
+ os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+ return SummaryWriter(
+ log_dir=os.path.join(log_directory, "tensorboard"),
+ flush_secs=shared.opts.training_tensorboard_flush_every)
+
+def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
+ tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
+ tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
+
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- if shared.opts.training_enable_tensorboard:
- tensorboard_writer.add_scalar(tag=tag,
- scalar_value=value, global_step=step)
+ tensorboard_writer.add_scalar(tag=tag,
+ scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
- if shared.opts.training_enable_tensorboard:
- # Convert a pil image to a torch tensor
- img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
- img_tensor = img_tensor.permute((2, 0, 1))
+ # Convert a pil image to a torch tensor
+ img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ len(pil_image.getbands()))
+ img_tensor = img_tensor.permute((2, 0, 1))
- tensorboard_writer.add_image(tag, img_tensor, global_step=step)
+ tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
@@ -268,10 +279,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
if shared.opts.training_enable_tensorboard:
- os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
- tensorboard_writer = SummaryWriter(
- log_dir=os.path.join(log_directory, "tensorboard"),
- flush_secs=shared.opts.training_tensorboard_flush_every)
+ tensorboard_writer = tensorboard_setup(log_directory)
pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
for i, entries in pbar:
@@ -308,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = True
if shared.opts.training_enable_tensorboard:
- tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step)
- tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step)
- tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step)
- tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step)
+ tensorboard_add(tensorboard_writer, loss=losses.mean(), global_step=embedding.step,
+ step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
@@ -377,7 +383,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
image.save(last_saved_image)
- tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}",
+ image, embedding.step)
last_saved_image += f", prompt: {preview_text}"
--
cgit v1.2.3
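The epoch bookkeeping that feeds `tensorboard_add` is plain integer arithmetic on the global step; a minimal sketch with illustrative numbers:

```python
global_step = 1234   # embedding.step / hypernetwork.step (illustrative)
dataset_len = 100    # len(ds) (illustrative)

epoch_num = global_step // dataset_len                     # completed passes over the data
epoch_step = global_step - (epoch_num * dataset_len) + 1   # 1-based step within this epoch

print(epoch_num, epoch_step)  # 12 35
```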
From 18f86e41f6f289042c075bff1498e620ab997b8c Mon Sep 17 00:00:00 2001
From: Melan
Date: Mon, 24 Oct 2022 17:21:18 +0200
Subject: Removed two unused imports
---
modules/hypernetworks/hypernetwork.py | 1 -
modules/textual_inversion/textual_inversion.py | 1 -
2 files changed, 2 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 0cd94f49..2263e95e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -4,7 +4,6 @@ import html
import os
import sys
import traceback
-import tensorboard
import tqdm
import csv
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b1dc2596..589314fe 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -9,7 +9,6 @@ import datetime
import csv
import numpy as np
-import torchvision.transforms
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models
--
cgit v1.2.3
From 1618df41bad092e068c61bf510b1e20856821ad5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 10:31:27 +0700
Subject: Gradient clipping for textual embedding
---
modules/textual_inversion/textual_inversion.py | 11 ++++++++++-
modules/ui.py | 2 ++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff002d3e..7bad73a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -206,7 +206,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -256,6 +256,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if ititial_step > steps:
return embedding, filename
+ clip_grad_mode_value = clip_grad_mode == "value"
+ clip_grad_mode_norm = clip_grad_mode == "norm"
+
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -280,6 +283,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
+
+ if clip_grad_mode_value:
+ torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
+ elif clip_grad_mode_norm:
+ torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
+
optimizer.step()
diff --git a/modules/ui.py b/modules/ui.py
index ba5e92a7..97de7da2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1409,6 +1409,8 @@ def create_ui(wrap_gradio_gpu_call):
training_width,
training_height,
steps,
+ clip_grad_mode,
+ clip_grad_value,
create_image_every,
save_embedding_every,
template_file,
--
cgit v1.2.3
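The two modes behave differently: `clip_grad_value_` clamps each gradient element into `[-v, v]`, while `clip_grad_norm_` rescales the whole gradient so its total norm is at most `v`. A standalone sketch:

```python
import torch

def make_param():
    p = torch.nn.Parameter(torch.zeros(4))
    p.grad = torch.tensor([3.0, -4.0, 0.5, 0.0])
    return p

p = make_param()
torch.nn.utils.clip_grad_value_(p, clip_value=1.0)
print(p.grad)  # tensor([ 1.0000, -1.0000,  0.5000,  0.0000]) -- each element clamped

p = make_param()
torch.nn.utils.clip_grad_norm_(p, max_norm=1.0)
print(p.grad)  # whole vector rescaled so its norm is 1.0 (was ~5.02)
```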
From 16451ca573220e49f2eaaab97580b6b91287c8c4 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 17:16:23 +0700
Subject: Learning rate sched syntax support for grad clipping
---
modules/hypernetworks/hypernetwork.py | 13 ++++++++++---
modules/textual_inversion/learn_schedule.py | 11 ++++++++---
modules/textual_inversion/textual_inversion.py | 12 +++++++++---
modules/ui.py | 7 +++----
4 files changed, 30 insertions(+), 13 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c5d60654..86532063 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -383,11 +383,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
-
+
clip_grad_mode_value = clip_grad_mode == "value"
clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
@@ -407,6 +411,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(hypernetwork.step)
+
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
@@ -430,9 +437,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value)
+ torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value)
+ torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..ffec3e1b 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -51,14 +51,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number <= self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7bad73a6..6b00c6a1 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
-
+
clip_grad_mode_value = clip_grad_mode == "value"
clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(embedding.step)
+
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
@@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
+ torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
+ torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/ui.py b/modules/ui.py
index 97de7da2..47d16429 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1305,7 +1305,9 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
-
+ with gr.Row():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="1.0", show_label=False)
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1313,9 +1315,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
- with gr.Row():
- clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
- clip_grad_value = gr.Number(value=1.0, show_label=False)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
--
cgit v1.2.3
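With this change the clip value accepts the same schedule syntax as learning rates: comma-separated `value:until_step` pairs, with a bare value running to the end of training. A minimal parser sketch of that syntax, independent of `LearnRateScheduler` (the schedule string is an example):

```python
def parse_schedule(schedule: str, max_steps: int):
    """Yield (value, end_step) pairs from e.g. "0.1:500, 0.01:2000, 0.001"."""
    for part in schedule.split(","):
        if ":" in part:
            value, end = part.split(":")
            yield float(value), int(end)
        else:
            yield float(part), max_steps

print(list(parse_schedule("0.1:500, 0.01:2000, 0.001", 10000)))
# [(0.1, 500), (0.01, 2000), (0.001, 10000)]
```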
From cffc240a7327ae60671ff533469fc4ed4bf605de Mon Sep 17 00:00:00 2001
From: Nerogar
Date: Sun, 23 Oct 2022 14:05:25 +0200
Subject: fixed textual inversion training with inpainting models
---
modules/textual_inversion/textual_inversion.py | 27 +++++++++++++++++++++++++-
1 file changed, 26 insertions(+), 1 deletion(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0aeb0459..2630c7c9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,6 +224,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+def create_dummy_mask(x, width=None, height=None):
+ if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -286,6 +306,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
+ img_c = None
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
@@ -299,8 +320,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ loss = shared.sd_model(x, cond)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
--
cgit v1.2.3
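The core of `create_dummy_mask` for inpainting models is the `F.pad` call: padding the channel dimension with a single leading channel filled with 1.0 attaches an all-ones ("everything is masked") mask to the encoded conditioning. A standalone sketch with a made-up latent shape:

```python
import torch

# Stand-in for the encoded masked-image conditioning: batch 2, 4 channels, 8x8 latent.
image_conditioning = torch.zeros(2, 4, 8, 8)

# F.pad reads the spec from the last dim backward: (W_left, W_right, H_top,
# H_bottom, C_front, C_back), so (0, 0, 0, 0, 1, 0) prepends one channel of 1.0.
padded = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)

print(padded.shape)            # torch.Size([2, 5, 8, 8])
print(padded[:, 0].unique())   # tensor([1.]) -- the full mask channel
```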
From bb832d7725187f8a8ab44faa6ee1b38cb5f600aa Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 5 Nov 2022 11:48:38 +0700
Subject: Simplify grad clip
---
modules/hypernetworks/hypernetwork.py | 16 +++++++---------
modules/textual_inversion/textual_inversion.py | 16 +++++++---------
2 files changed, 14 insertions(+), 18 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f4c2668f..02b624e1 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
@@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(hypernetwork.step)
with torch.autocast("cuda"):
@@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c567ec3f..687d97bb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(embedding.step)
with torch.autocast("cuda"):
@@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
optimizer.step()
--
cgit v1.2.3
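The simplification relies on `clip_grad_value_` and `clip_grad_norm_` both accepting `(parameters, value)` positionally, so the mode check collapses to choosing a callable once instead of branching on every step:

```python
import torch

clip_grad_mode = "norm"  # example; "value", "norm", or "disabled"

clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
    torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
    None

p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([3.0, 4.0, 0.0])
if clip_grad:
    clip_grad(p, 1.0)
print(p.grad)  # norm mode: tensor([0.6000, 0.8000, 0.0000])
```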
From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sat, 31 Dec 2022 11:27:02 -0500
Subject: validate textual inversion embeddings
---
modules/sd_models.py | 3 ++
modules/textual_inversion/textual_inversion.py | 43 +++++++++++++++++++++++---
modules/ui.py | 2 --
3 files changed, 41 insertions(+), 7 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ecdd91c5..ebd4dff7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -325,6 +325,9 @@ def load_model(checkpoint_info=None):
script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.")
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model
+
return sd_model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f6112578..103ace60 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -23,6 +23,8 @@ class Embedding:
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
@@ -57,8 +59,10 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
+ self.skipped_embeddings = []
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
+ self.expected_shape = -1
def register_embedding(self, embedding, model):
@@ -75,14 +79,35 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
+ def get_expected_shape(self):
+ expected_shape = -1 # initialize with unknown
+ idx = torch.tensor(0).to(shared.device)
+ if expected_shape == -1:
+ try: # matches sd15 signature
+ first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
+ expected_shape = first_embedding.shape[0]
+ except:
+ pass
+ if expected_shape == -1:
+ try: # matches sd20 signature
+ first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
+ expected_shape = first_embedding.shape[0]
+ except:
+ pass
+ if expected_shape == -1:
+ print('Could not determine expected embeddings shape from model')
+ return expected_shape
+
+ def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
+ if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
+ self.skipped_embeddings = []
+ self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
@@ -122,7 +147,14 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- self.register_embedding(embedding, shared.sd_model)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings.append(name)
+ # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir):
try:
@@ -137,8 +169,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
+ if (len(self.skipped_embeddings) > 0):
+ print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
diff --git a/modules/ui.py b/modules/ui.py
index 57ee0465..397dd804 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1157,8 +1157,6 @@ def create_ui():
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="See wiki for detailed explanation.
")
--
cgit v1.2.3
From bdbe09827b39be63c9c0b3636132ca58da38ebf6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 31 Dec 2022 22:49:09 +0300
Subject: changed embedding accepted shape detection to use existing code and
support the new alt-diffusion model, and reformatted messages a bit #6149
---
modules/textual_inversion/textual_inversion.py | 30 ++++++--------------------
1 file changed, 6 insertions(+), 24 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 103ace60..66f40367 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -80,23 +80,8 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
- expected_shape = -1 # initialize with unknown
- idx = torch.tensor(0).to(shared.device)
- if expected_shape == -1:
- try: # matches sd15 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- try: # matches sd20 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- print('Could not determine expected embeddings shape from model')
- return expected_shape
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
@@ -112,8 +97,6 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = []
-
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@@ -150,11 +133,10 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
- if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings.append(name)
- # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir):
try:
@@ -169,9 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
- print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
- if (len(self.skipped_embeddings) > 0):
- print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
--
cgit v1.2.3
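The compatibility test itself is a width comparison between the model's token-embedding dimension and the loaded vector's last dimension; a minimal sketch with illustrative widths (768 for SD1.x-style CLIP, 1024 for SD2.x-style OpenCLIP):

```python
import torch

expected_shape = 768  # what get_expected_shape() would report for the loaded model

def is_compatible(vec: torch.Tensor) -> bool:
    # vec is (num_vectors, embedding_width); only the width must match.
    return expected_shape == -1 or vec.shape[-1] == expected_shape

print(is_compatible(torch.zeros(2, 768)))   # True  -> registered
print(is_compatible(torch.zeros(2, 1024)))  # False -> skipped
```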
From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 00:38:09 +0300
Subject: fix the issue with training on SD2.0
---
modules/sd_models.py | 2 ++
modules/textual_inversion/textual_inversion.py | 3 +--
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ebd4dff7..bff8d6c9 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ model.logvar = model.logvar.to(devices.device) # fix for training
+
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 66f40367..1e5722e7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- # dataset loading may take a while, so input validations and early returns should be done before this
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
-
last_saved_file = ""
last_saved_image = ""
forced_filename = ""
--
cgit v1.2.3
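The one-line fix moves the model's per-timestep `logvar` buffer onto the active device so the training loss no longer mixes CPU and CUDA tensors; the pattern in isolation:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# logvar stands in for the diffusion model's log-variance buffer,
# which is created on the CPU when the checkpoint is loaded.
logvar = torch.zeros(1000)
logvar = logvar.to(device)  # move once after load; indexing it in the loss stays on-device
print(logvar.device)
```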
From c65909ad16a1962129114c6251de092f49479b06 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Mon, 2 Jan 2023 12:21:22 +1100
Subject: feat(api): return more data for embeddings
---
modules/api/api.py | 17 +++++++++++++++--
modules/api/models.py | 11 +++++++++--
modules/textual_inversion/textual_inversion.py | 8 ++++----
3 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 30bf3dac..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -330,9 +330,22 @@ class Api:
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
return {
- "loaded": sorted(db.word_embeddings.keys()),
- "skipped": sorted(db.skipped_embeddings),
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
diff --git a/modules/api/models.py b/modules/api/models.py
index a8472dc9..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
class EmbeddingsResponse(BaseModel):
- loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1e5722e7..fd253477 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
- self.skipped_embeddings = []
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
- self.skipped_embeddings = []
+ self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
- self.skipped_embeddings.append(name)
+ self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
--
cgit v1.2.3
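A quick way to see the richer payload is to query the endpoint directly; a sketch assuming a local instance with the API enabled and the usual `/sdapi/v1/embeddings` route:

```python
import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings")
data = resp.json()

for name, item in data["loaded"].items():
    print(f"{name}: {item['vectors']} vector(s) of width {item['shape']}, "
          f"step {item['step']}, checkpoint {item['sd_checkpoint_name']}")
```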
From bddebe09edeb6a18f2c06986d5658a7be3a563ea Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Tue, 3 Jan 2023 10:26:37 +0100
Subject: Save Optimizer next to TI embedding
Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
---
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 40 ++++++++++++++++++++------
2 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..c541d18c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..16176e90 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
--
cgit v1.2.3
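The resume path is guarded by the embedding's checksum so a stale `.optim` file is ignored; the save/load pattern in isolation (the filename and hash are placeholders):

```python
import os
import torch

filename = "my-embedding.pt"  # placeholder for the embedding file path
checksum = "deadbeef"         # stands in for embedding.checksum()

vec = torch.nn.Parameter(torch.zeros(2, 768))
optimizer = torch.optim.AdamW([vec], lr=5e-3, weight_decay=0.0)

# Save: the .optim file sits next to the embedding and records which
# embedding state it belongs to via the hash.
torch.save({"hash": checksum, "optimizer_state_dict": optimizer.state_dict()},
           filename + ".optim")

# Resume: restore only if the hash still matches the embedding on disk.
if os.path.exists(filename + ".optim"):
    saved = torch.load(filename + ".optim", map_location="cpu")
    if saved.get("hash") == checksum:
        optimizer.load_state_dict(saved["optimizer_state_dict"])
        print("Loaded existing optimizer from checkpoint")
```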
From 192ddc04d6de0d780f73aa5fbaa8c66cd4642e1c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 10:34:51 -0500
Subject: add job info to modules
---
modules/extras.py | 17 +++++++++++++----
modules/hypernetworks/hypernetwork.py | 1 +
modules/textual_inversion/preprocess.py | 1 +
modules/textual_inversion/textual_inversion.py | 1 +
4 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 7e222313..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -58,6 +58,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +97,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +108,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +119,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +186,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -193,6 +202,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
basename = ''
+ if opts.enable_pnginfo: # append info before save
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +216,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
- if opts.enable_pnginfo:
- image.info = existing_pnginfo
- image.info["extras"] = info
-
if extras_mode != 2 or show_extras_results :
outputs.append(image)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 109e8078..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -417,6 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..2c1251d6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -245,6 +245,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
--
cgit v1.2.3
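The commit above tags each long-running operation with a short `shared.state.job` label so the progress UI and API can tell what is running. A minimal sketch of that bookkeeping pattern, assuming a simple state object (the class below is illustrative, not webui's actual implementation):

class State:
    def __init__(self):
        self.job = ""        # machine-readable label, e.g. "extras-upscale"
        self.textinfo = ""   # human-readable status line for the UI
        self.job_count = 0   # total number of steps, used for progress

state = State()

def run_job(label, steps, work):
    state.job = label
    state.job_count = steps
    for i in range(steps):
        state.textinfo = f"{label}: step {i + 1}/{steps}"
        work(i)

run_job("preprocess", 3, lambda i: None)  # usage example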
From 184e670126f5fc50ba56fa0fedcf0cf60e45ed7e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:45:01 +0300
Subject: fix the merge
---
modules/textual_inversion/textual_inversion.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5421a758..8731ea5d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -251,6 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+
def create_dummy_mask(x, width=None, height=None):
if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
@@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
break
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
-
-
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
-
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3
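The reordering above matters because the dummy inpainting mask is built from the text conditioning `c`, so `c` has to exist before `img_c` can be created; the mask is then cached and reused for every batch. A condensed sketch of the resulting loop body (function and attribute names follow the diff; the model call is stubbed):

img_c = None  # cached across steps - the dummy mask is identical for every batch

def training_step(batch, sd_model, create_dummy_mask, training_width, training_height):
    global img_c
    x = batch.latent_sample                         # latents for this batch
    c = sd_model.cond_stage_model(batch.cond_text)  # text conditioning, needed first
    if img_c is None:                               # built once, sized from c
        img_c = create_dummy_mask(c, training_width, training_height)
    cond = {"c_concat": [img_c], "c_crossattn": [c]}  # hybrid/concat models take a dict
    return sd_model(x, cond)[0]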
From 525cea924562afd676f55470095268a0f6fca59e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:58:07 +0300
Subject: use shared function from processing for creating dummy mask when
training inpainting model
---
modules/processing.py | 39 +++++++++++++-------------
modules/textual_inversion/textual_inversion.py | 33 ++++++----------------
2 files changed, 29 insertions(+), 43 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/processing.py b/modules/processing.py
index c03e77e7..c7264aff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since it's not going to be used except for its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -139,26 +157,9 @@ class StableDiffusionProcessing():
self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8731ea5d..2250e41b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def create_dummy_mask(x, width=None, height=None):
- if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
-
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
print("No saved optimizer exists in checkpoint")
-
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = ""
embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3
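The shared helper above builds the "everything is masked" conditioning an inpainting model expects: an encoded all-zeros image with a full-ones mask channel padded onto the front. A toy sketch of just the padding step, with arbitrary shapes chosen for illustration:

import torch

latent = torch.zeros(2, 4, 8, 8)  # stand-in for a VAE-encoded all-zeros image
# Pad one extra leading channel filled with 1.0 - the fake "fully masked" mask.
conditioned = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
print(conditioned.shape)            # torch.Size([2, 5, 8, 8])
print(conditioned[:, 0].unique())   # tensor([1.])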
From eea8fc40e16664ddc8a9aec77206da704a35dde0 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 5 Jan 2023 07:24:22 -0800
Subject: Add option to save ti settings to file.
---
modules/shared.py | 1 +
modules/textual_inversion/textual_inversion.py | 30 +++++++++++++++++++++++---
2 files changed, 28 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/shared.py b/modules/shared.py
index e0f44c6d..933cd738 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -362,6 +362,7 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
+ "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 71e07bcc..2bed2ecb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,7 @@
import os
import sys
import traceback
+import inspect
import torch
import tqdm
@@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ checkpoint = sd_models.select_checkpoint()
+ model_name = checkpoint.model_name
+ model_hash = '[{}]'.format(checkpoint.hash)
+
+ # Get a list of the argument names.
+ arg_names = inspect.getfullargspec(save_settings_to_file).args
+
+ # Create a list of the argument names to include in the settings string.
+ names = arg_names[:16] # Include all arguments up until the preview-related ones.
+ if preview_from_txt2img:
+ names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True.
+
+ # Build the settings string.
+ settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
+ for name in names:
+ value = locals()[name]
+ settings_str += f"{name}: {value}\n"
+
+ with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
+ fout.write(settings_str + "\n\n")
+
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
@@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
-
+ if shared.opts.save_train_settings_to_txt:
+ save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
--
cgit v1.2.3
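The trick in the commit above is to recover the function's own parameter names with `inspect` and read the matching values back out of `locals()`, so the settings dump stays in sync with the signature. A standalone sketch of that introspection pattern (parameter names abbreviated for brevity):

import inspect

def save_settings(learn_rate, batch_size, steps, preview_from_txt2img, preview_prompt=""):
    arg_names = inspect.getfullargspec(save_settings).args  # names in declaration order
    names = arg_names[:4]             # everything before the preview-related arguments
    if preview_from_txt2img:
        names.extend(arg_names[4:])   # include preview arguments only when used
    values = locals()                 # maps each parameter name to its value
    return "\n".join(f"{name}: {values[name]}" for name in names)

print(save_settings(5e-3, 1, 1000, False))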
From b85c2b5cf4a6809bc871718cf4680d49c3e95e94 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 5 Jan 2023 08:14:38 -0800
Subject: Clean up ti, add same behavior to hypernetwork.
---
modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++++++-
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 14 +++++++-----
3 files changed, 40 insertions(+), 7 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6a9b1398..d5985263 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -401,7 +401,33 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
hypernet.save(fn)
shared.reload_hypernetworks()
+# Note: textual_inversion.py has a nearly identical function of the same name.
+def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ checkpoint = sd_models.select_checkpoint()
+ model_name = checkpoint.model_name
+ model_hash = '[{}]'.format(checkpoint.hash)
+ # Starting index of preview-related arguments.
+ border_index = 19
+
+ # Get a list of the argument names, excluding default argument.
+ sig = inspect.signature(save_settings_to_file)
+ arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty]
+
+ # Create a list of the argument names to include in the settings string.
+ names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
+
+ # Include preview-related arguments if applicable.
+ if preview_from_txt2img:
+ names.extend(arg_names[border_index:])
+
+ # Build the settings string.
+ settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
+ for name in names:
+ value = locals()[name]
+ settings_str += f"{name}: {value}\n"
+ with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
+ fout.write(settings_str + "\n\n")
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
@@ -457,7 +483,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
-
+
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
diff --git a/modules/shared.py b/modules/shared.py
index 933cd738..10231a75 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
- "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."),
+ "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 2bed2ecb..68648550 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -230,18 +230,20 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+# Note: hypernetwork.py has a nearly identical function of the same name.
def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
checkpoint = sd_models.select_checkpoint()
model_name = checkpoint.model_name
model_hash = '[{}]'.format(checkpoint.hash)
-
+ # Starting index of preview-related arguments.
+ border_index = 16
# Get a list of the argument names.
arg_names = inspect.getfullargspec(save_settings_to_file).args
# Create a list of the argument names to include in the settings string.
- names = arg_names[:16] # Include all arguments up until the preview-related ones.
+ names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
if preview_from_txt2img:
- names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True.
+ names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True.
# Build the settings string.
settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
@@ -329,8 +331,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
- if shared.opts.save_train_settings_to_txt:
- save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
--
cgit v1.2.3
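The hypernetwork variant above gathers names through `inspect.signature` and drops any parameter that carries a default, so only required arguments end up in the log. A minimal demonstration of that filter:

import inspect

def f(a, b, c=3):
    pass

sig = inspect.signature(f)
required = [p.name for p in sig.parameters.values() if p.default is p.empty]
print(required)  # ['a', 'b'] - 'c' has a default and is excluded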
From b6bab2f052b32c0ffebe6aecc1819ccf20cf8c5d Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 5 Jan 2023 09:14:56 -0800
Subject: Include model in log file. Exclude directory.
---
modules/hypernetworks/hypernetwork.py | 28 +++++++++-----------------
modules/textual_inversion/textual_inversion.py | 22 +++++++++-----------
2 files changed, 19 insertions(+), 31 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index d5985263..3237c37a 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -402,30 +402,22 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
# Note: textual_inversion.py has a nearly identical function of the same name.
-def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- checkpoint = sd_models.select_checkpoint()
- model_name = checkpoint.model_name
- model_hash = '[{}]'.format(checkpoint.hash)
+def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# Starting index of preview-related arguments.
- border_index = 19
-
- # Get a list of the argument names, excluding default argument.
- sig = inspect.signature(save_settings_to_file)
- arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty]
-
+ border_index = 21
+ # Get a list of the argument names.
+ arg_names = inspect.getfullargspec(save_settings_to_file).args
# Create a list of the argument names to include in the settings string.
names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
-
- # Include preview-related arguments if applicable.
if preview_from_txt2img:
- names.extend(arg_names[border_index:])
-
+ names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable.
# Build the settings string.
settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
for name in names:
- value = locals()[name]
- settings_str += f"{name}: {value}\n"
-
+ if name != 'log_directory': # It's useless and redundant to save log_directory.
+ value = locals()[name]
+ settings_str += f"{name}: {value}\n"
+ # Create or append to the file.
with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
fout.write(settings_str + "\n\n")
@@ -485,7 +477,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+ save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
latent_sampling_method = ds.latent_sampling_method
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 68648550..ce7e4f5d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -231,26 +231,22 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
# Note: hypernetwork.py has a nearly identical function of the same name.
-def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- checkpoint = sd_models.select_checkpoint()
- model_name = checkpoint.model_name
- model_hash = '[{}]'.format(checkpoint.hash)
+def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# Starting index of preview-related arguments.
- border_index = 16
+ border_index = 18
# Get a list of the argument names.
- arg_names = inspect.getfullargspec(save_settings_to_file).args
-
+ arg_names = inspect.getfullargspec(save_settings_to_file).args
# Create a list of the argument names to include in the settings string.
names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
if preview_from_txt2img:
- names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True.
-
+ names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable.
# Build the settings string.
settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
for name in names:
- value = locals()[name]
- settings_str += f"{name}: {value}\n"
-
+ if name != 'log_directory': # It's useless and redundant to save log_directory.
+ value = locals()[name]
+ settings_str += f"{name}: {value}\n"
+ # Create or append to the file.
with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
fout.write(settings_str + "\n\n")
@@ -333,7 +329,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+ save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
latent_sampling_method = ds.latent_sampling_method
--
cgit v1.2.3
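With the commit above, model_name and model_hash arrive as ordinary parameters and are picked up by the same introspection loop, while log_directory is filtered out at write time since it only names the folder the settings file already lives in. A compact sketch of that write-time filter (the values are hypothetical):

settings = {"model_name": "v1-5", "model_hash": "[abc123]", "log_directory": "/tmp/ti"}
lines = [f"{k}: {v}" for k, v in settings.items() if k != "log_directory"]
print("\n".join(lines))  # prints everything except log_directory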
From fda04e620d529031e2134520e74756d0efa30464 Mon Sep 17 00:00:00 2001
From: Kuma <36082288+KumiIT@users.noreply.github.com>
Date: Thu, 5 Jan 2023 18:44:19 +0100
Subject: typo in TI
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 71e07bcc..24b43045 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -298,7 +298,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
--
cgit v1.2.3
From 81133d4168ae0bae9bf8bf1a1d4983319a589112 Mon Sep 17 00:00:00 2001
From: Faber
Date: Fri, 6 Jan 2023 03:38:37 +0700
Subject: allow loading embeddings from subdirectories
---
modules/textual_inversion/textual_inversion.py | 23 ++++++++++++-----------
1 file changed, 12 insertions(+), 11 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 24b43045..0a059044 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -149,19 +149,20 @@ class EmbeddingDatabase:
else:
self.skipped_embeddings[name] = embedding
- for fn in os.listdir(self.embeddings_dir):
- try:
- fullfn = os.path.join(self.embeddings_dir, fn)
-
- if os.stat(fullfn).st_size == 0:
+ for root, dirs, fns in os.walk(self.embeddings_dir):
+ for fn in fns:
+ try:
+ fullfn = os.path.join(root, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
continue
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
--
cgit v1.2.3
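Replacing `os.listdir` with `os.walk` is what makes embedding loading recursive: the old loop only saw files at the top level, while `os.walk` also visits every subdirectory. A side-by-side sketch of the two traversals (the directory layout is hypothetical):

import os

def flat(root):
    # old behavior: only files directly inside root
    return [os.path.join(root, fn) for fn in os.listdir(root)
            if os.path.isfile(os.path.join(root, fn))]

def recursive(root):
    # new behavior: files in root and in every subdirectory
    return [os.path.join(dirpath, fn)
            for dirpath, dirs, files in os.walk(root)
            for fn in files]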
From 683287d87f6401083a8d63eedc00ca7410214ca1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 6 Jan 2023 08:52:06 +0300
Subject: rework saving training params to file #6372
---
modules/hypernetworks/hypernetwork.py | 28 +++++++-------------------
modules/shared.py | 2 +-
modules/textual_inversion/logging.py | 24 ++++++++++++++++++++++
modules/textual_inversion/textual_inversion.py | 23 +++------------------
4 files changed, 35 insertions(+), 42 deletions(-)
create mode 100644 modules/textual_inversion/logging.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3237c37a..b0cfbe71 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -13,7 +13,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers
-from modules.textual_inversion import textual_inversion
+from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
@@ -401,25 +401,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
hypernet.save(fn)
shared.reload_hypernetworks()
-# Note: textual_inversion.py has a nearly identical function of the same name.
-def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # Starting index of preview-related arguments.
- border_index = 21
- # Get a list of the argument names.
- arg_names = inspect.getfullargspec(save_settings_to_file).args
- # Create a list of the argument names to include in the settings string.
- names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
- if preview_from_txt2img:
- names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable.
- # Build the settings string.
- settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
- for name in names:
- if name != 'log_directory': # It's useless and redundant to save log_directory.
- value = locals()[name]
- settings_str += f"{name}: {value}\n"
- # Create or append to the file.
- with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
- fout.write(settings_str + "\n\n")
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
@@ -477,7 +459,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+ saved_params = dict(
+ model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+ )
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
latent_sampling_method = ds.latent_sampling_method
diff --git a/modules/shared.py b/modules/shared.py
index f0e10b35..57e489d0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
- "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."),
+ "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
new file mode 100644
index 00000000..8b1981d5
--- /dev/null
+++ b/modules/textual_inversion/logging.py
@@ -0,0 +1,24 @@
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+ now = datetime.datetime.now()
+ params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+
+ keys = saved_params_all
+ if all_params.get('preview_from_txt2img'):
+ keys = keys | saved_params_previews
+
+ params.update({k: v for k, v in all_params.items() if k in keys})
+
+ filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
+ with open(os.path.join(log_directory, filename), "w") as file:
+ json.dump(params, file, indent=4)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e9cf432f..f9f5e8cd 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
+from modules.textual_inversion.logging import save_settings_to_file
+
class Embedding:
def __init__(self, vec, name, step=None):
@@ -231,25 +233,6 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
-# Note: hypernetwork.py has a nearly identical function of the same name.
-def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # Starting index of preview-related arguments.
- border_index = 18
- # Get a list of the argument names.
- arg_names = inspect.getfullargspec(save_settings_to_file).args
- # Create a list of the argument names to include in the settings string.
- names = arg_names[:border_index] # Include all arguments up until the preview-related ones.
- if preview_from_txt2img:
- names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable.
- # Build the settings string.
- settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
- for name in names:
- if name != 'log_directory': # It's useless and redundant to save log_directory.
- value = locals()[name]
- settings_str += f"{name}: {value}\n"
- # Create or append to the file.
- with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout:
- fout.write(settings_str + "\n\n")
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
@@ -330,7 +313,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height)
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
--
cgit v1.2.3
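The rework above replaces two near-duplicate string builders with one JSON writer fed from `locals()`: each caller merges a handful of computed values over its local namespace, and the writer keeps only whitelisted keys. A minimal sketch of that call pattern (train() is a stub standing in for the real training entry point):

import json

SAVED_KEYS = {"learn_rate", "batch_size", "steps", "model_name"}

def save_settings(all_params):
    return json.dumps({k: v for k, v in all_params.items() if k in SAVED_KEYS}, indent=4)

def train(learn_rate, batch_size, steps):
    model_name = "stub-checkpoint"  # stands in for sd_models.select_checkpoint().model_name
    print(save_settings({**dict(model_name=model_name), **locals()}))

train(5e-3, 1, 1000)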
From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 7 Jan 2023 01:45:28 +0300
Subject: CLIP hijack rework
---
modules/sd_hijack.py | 6 +-
modules/sd_hijack_clip.py | 348 ++++++++++++-------------
modules/sd_hijack_clip_old.py | 81 ++++++
modules/textual_inversion/textual_inversion.py | 1 -
modules/ui.py | 2 +-
5 files changed, 256 insertions(+), 182 deletions(-)
create mode 100644 modules/sd_hijack_clip_old.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa2cd4bb..71cc145a 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
- def tokenize(self, text):
- _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+ def get_prompt_lengths(self, text):
+ _, token_count = self.clip.process_texts([text])
- return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index ca92b142..ac3020d7 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,12 +1,28 @@
import math
+from collections import namedtuple
import torch
from modules import prompt_parser, devices
from modules.shared import opts
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+ """
+ This object contains token ids, weights (multipliers such as 1.4 from attention syntax) and textual inversion embedding info for a chunk of prompt.
+ If a prompt is short, it is represented by one PromptChunk; otherwise, multiple chunks are necessary.
+ Each PromptChunk contains an exact number of tokens - 77, which includes one start token and one end token,
+ so just 75 tokens from the prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
@@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
raise NotImplementedError
def encode_with_transformers(self, tokens):
+ """
+ converts a batch of token ids (in python lists) into a single tensor with the numeric representation of those tokens;
+ All python lists with tokens are assumed to have the same length, usually 77.
+ If the input is a list with B elements and each element has T tokens, the expected output shape is (B, T, C), where C depends on
+ the model - it can be 768 or 1024
+ """
+
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+ transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
+
raise NotImplementedError
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ def tokenize_line(self, line):
+ """
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
- fixes = []
- remade_tokens = []
- multipliers = []
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
last_comma = -1
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
+ last_comma = len(chunk.tokens)
+
+ # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [self.id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [self.id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with caching. Returns the list of results and the maximum
+ length, in tokens, of all texts.
+ """
+
token_count = 0
cache = {}
- batch_multipliers = []
+ batch_chunks = []
for line in texts:
if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
+ chunks = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
- cache[line] = (remade_tokens, fixes, multipliers)
+ cache[line] = chunks
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
+ batch_chunks.append(chunks)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+ return batch_chunks, token_count
- def process_text_old(self, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
+ def forward(self, texts):
+ """
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ An example shape returned by this function can be: (2, 77, 768).
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.id_end] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
+ batch_chunks, token_count = self.process_texts(texts)
- return z
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+ Sends one prompt chunk to be encoded by the transformers network.
+ remade_batch_tokens is a batch of tokens - a list where every element is a list of tokens; usually
+ there are exactly 77 tokens in each list. batch_multipliers is the same but for multipliers instead of tokens.
+ Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
+ corresponds to one token.
+ """
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
new file mode 100644
index 00000000..6d9fbbe6
--- /dev/null
+++ b/modules/sd_hijack_clip_old.py
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f9f5e8cd..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -79,7 +79,6 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
diff --git a/modules/ui.py b/modules/ui.py
index b79d24ee..5d2f5bad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -368,7 +368,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"{token_count}/{max_length}"
--
cgit v1.2.3
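
The chunk splitting introduced above reduces to the following minimal sketch, assuming a flat list of token ids; CHUNK_LENGTH and COMMA_TOKEN are stand-ins for self.chunk_length (75) and the tokenizer's comma id, and the real tokenize_line() additionally tracks multipliers, embedding fixes, and end-of-text padding:

    CHUNK_LENGTH = 75
    COMMA_TOKEN = 267  # stand-in; the actual id comes from the tokenizer vocab

    def split_into_chunks(tokens, backtrack=20):
        # split a flat token list into chunks of at most CHUNK_LENGTH tokens;
        # when a chunk fills up and a comma sits close enough to its end,
        # break after the comma and carry the tail into the next chunk
        chunks, chunk, last_comma = [], [], -1

        for token in tokens:
            if len(chunk) == CHUNK_LENGTH:
                if backtrack != 0 and last_comma != -1 and len(chunk) - last_comma <= backtrack:
                    break_location = last_comma + 1
                    tail = chunk[break_location:]
                    chunks.append(chunk[:break_location])
                    chunk = tail
                else:
                    chunks.append(chunk)
                    chunk = []
                last_comma = -1  # simplification: commas in the carried tail are forgotten

            if token == COMMA_TOKEN:
                last_comma = len(chunk)
            chunk.append(token)

        if chunk:
            chunks.append(chunk)
        return chunks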
From 448b9cedab66e05b5b2800513ca334a769b42aa7 Mon Sep 17 00:00:00 2001
From: dan
Date: Sat, 7 Jan 2023 21:07:27 +0800
Subject: Allow variable img size
---
modules/textual_inversion/dataset.py | 18 +++++++++++-------
modules/textual_inversion/textual_inversion.py | 4 ++--
2 files changed, 13 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 88d68c76..375178ed 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None):
self.filename = filename
self.filename_text = filename_text
self.latent_dist = latent_dist
@@ -25,6 +25,7 @@ class DatasetEntry:
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
+ self.img_shape = img_shape
class PersonalizedBase(Dataset):
@@ -33,8 +34,6 @@ class PersonalizedBase(Dataset):
self.placeholder_token = placeholder_token
- self.width = width
- self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@@ -59,7 +58,11 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB')
+ if width < 2000:
+ image = image.resize((width, height), PIL.Image.BICUBIC)
+ else:
+ assert batch_size == 1, 'variable img size must have batch size 1'
except Exception:
continue
@@ -88,14 +91,14 @@ class PersonalizedBase(Dataset):
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
elif latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
@@ -151,6 +154,7 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ self.img_shape = [entry.img_shape for entry in data]
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 45882ed6..9f96d0fd 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -451,8 +451,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
p.prompt = batch.cond_text[0]
p.steps = 20
- p.width = training_width
- p.height = training_height
+ p.width = batch.img_shape[0][0]
+ p.height = batch.img_shape[0][1]
preview_text = p.prompt
--
cgit v1.2.3
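
Stripped of the dataset plumbing, the new loading behaviour amounts to the sketch below; the width < 2000 guard above is a stopgap that the next commit replaces with an explicit checkbox (varsize here), and load_training_image is a name invented for illustration:

    from PIL import Image

    def load_training_image(path, width, height, varsize=False, batch_size=1):
        # keep the native resolution when variable sizes are requested,
        # otherwise resize to the configured training dimensions
        image = Image.open(path).convert('RGB')
        if not varsize:
            image = image.resize((width, height), Image.BICUBIC)
        else:
            # differently-shaped latents can't be stacked into one batch
            assert batch_size == 1, 'variable img size must have batch size 1'
        return image, image.size  # image.size feeds DatasetEntry.img_shape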
From 669fb18d5222f53ae48abe0f30393d846c50ad91 Mon Sep 17 00:00:00 2001
From: dan
Date: Sun, 8 Jan 2023 01:34:52 +0800
Subject: Add checkbox for variable training dims
---
modules/hypernetworks/hypernetwork.py | 2 +-
modules/textual_inversion/dataset.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 4 ++--
modules/ui.py | 3 +++
4 files changed, 8 insertions(+), 5 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b0cfbe71..dba52841 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -403,7 +403,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 375178ed..7f8a314f 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -29,7 +29,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -59,7 +59,7 @@ class PersonalizedBase(Dataset):
raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB')
- if width < 2000:
+ if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC)
else:
assert batch_size == 1, 'variable img size must have batch size 1'
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9f96d0fd..110efd19 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -255,7 +255,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -309,7 +309,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
diff --git a/modules/ui.py b/modules/ui.py
index 99483130..4e709a71 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1343,6 +1343,7 @@ def create_ui():
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
with FormRow():
@@ -1449,6 +1450,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
@@ -1480,6 +1482,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
--
cgit v1.2.3
From a0c87f1fdf2b76b2ae4ef6c4b01ddaede3afab06 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 8 Jan 2023 08:52:26 +0300
Subject: skip images in embeddings dir if they have a second .preview
extension
---
modules/textual_inversion/textual_inversion.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 45882ed6..e85dd549 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -109,6 +109,10 @@ class EmbeddingDatabase:
ext = ext.upper()
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
+ return
+
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
--
cgit v1.2.3
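
The check itself is just a double application of os.path.splitext; as a standalone sketch (is_preview_file is a name used here for illustration):

    import os

    def is_preview_file(filename):
        # "mystyle.preview.png" splits into ("mystyle.preview", ".png");
        # splitting the stem once more exposes the ".preview" marker
        name, ext = os.path.splitext(filename)
        _, second_ext = os.path.splitext(name)
        return second_ext.upper() == '.PREVIEW'

    # is_preview_file("mystyle.preview.png") -> True
    # is_preview_file("mystyle.png")         -> False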
From 085427de0efc9e9e7a6e9a5aebc6b5a69f0365e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 8 Jan 2023 09:37:33 +0300
Subject: make it possible for extensions/scripts to add their own embedding
directories
---
modules/sd_hijack.py | 7 +-
modules/textual_inversion/textual_inversion.py | 170 +++++++++++++++----------
2 files changed, 108 insertions(+), 69 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfdb09d6..6b0d95af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
- def hijack(self, m):
+ def __init__(self):
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
-
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e85dd549..217fe9eb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -66,17 +66,41 @@ class Embedding:
return self.cached_checksum
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+
+ mt = os.path.getmtime(self.path)
+ if self.mtime is None or mt > self.mtime:
+ return True
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+
+ self.mtime = os.path.getmtime(self.path)
+
+
class EmbeddingDatabase:
- def __init__(self, embeddings_dir):
+ def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
- self.dir_mtime = None
- self.embeddings_dir = embeddings_dir
self.expected_shape = -1
+ self.embedding_dirs = {}
- def register_embedding(self, embedding, model):
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+ def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0]
@@ -93,69 +117,62 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
- def load_textual_inversion_embeddings(self, force_reload = False):
- mt = os.path.getmtime(self.embeddings_dir)
- if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
- return
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.expected_shape = self.get_expected_shape()
-
- def process_file(path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- _, second_ext = os.path.splitext(name)
- if second_ext.upper() == '.PREVIEW':
- return
-
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- else:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
return
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
-
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
else:
- self.skipped_embeddings[name] = embedding
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ else:
+ return
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
- for root, dirs, fns in os.walk(self.embeddings_dir):
+ def load_from_dir(self, embdir):
+ if not os.path.isdir(embdir.path):
+ return
+
+ for root, dirs, fns in os.walk(embdir.path):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -163,12 +180,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0:
continue
- process_file(fullfn, fn)
+ self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
+ def load_textual_inversion_embeddings(self, force_reload=False):
+ if not force_reload:
+ need_reload = False
+ for path, embdir in self.embedding_dirs.items():
+ if embdir.has_changed():
+ need_reload = True
+ break
+
+ if not need_reload:
+ return
+
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
+
+ for path, embdir in self.embedding_dirs.items():
+ self.load_from_dir(embdir)
+ embdir.update()
+
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0 , "Max steps must be positive"
+ assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
- assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0 , "Create image must be positive or 0"
+ assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
--
cgit v1.2.3
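
The reload machinery reduces to per-directory mtime bookkeeping; a minimal sketch, with DirWithEmbeddings shortened from DirWithTextualInversionEmbeddings for illustration:

    import os

    class DirWithEmbeddings:
        # remember a directory's mtime so embeddings are only reloaded
        # when its contents may actually have changed
        def __init__(self, path):
            self.path = path
            self.mtime = None

        def has_changed(self):
            if not os.path.isdir(self.path):
                return False
            mt = os.path.getmtime(self.path)
            return self.mtime is None or mt > self.mtime

        def update(self):
            if os.path.isdir(self.path):
                self.mtime = os.path.getmtime(self.path)

    # add_embedding_dir(path) stores embedding_dirs[path] = DirWithEmbeddings(path);
    # load_textual_inversion_embeddings() then reloads only if any registered
    # directory reports has_changed(), and calls update() on each dir afterwards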
From 43bb5190fc9e7ae479a5dc6640be202c9a71e464 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 9 Jan 2023 22:52:23 +0300
Subject: remove/simplify some changes from #6481
---
modules/textual_inversion/dataset.py | 14 +++++---------
modules/textual_inversion/textual_inversion.py | 4 ++--
modules/ui.py | 2 +-
3 files changed, 8 insertions(+), 12 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index bcad6848..fa48708e 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename
self.filename_text = filename_text
self.latent_dist = latent_dist
@@ -25,7 +25,6 @@ class DatasetEntry:
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
- self.img_shape = img_shape
class PersonalizedBase(Dataset):
@@ -46,12 +45,10 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- if varsize:
- assert batch_size == 1, 'variable img size must have batch size 1'
+ assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
@@ -91,14 +88,14 @@ class PersonalizedBase(Dataset):
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
@@ -154,7 +151,6 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
- self.img_shape = [entry.img_shape for entry in data]
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ad76297e..14be2c96 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -492,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
p.prompt = batch.cond_text[0]
p.steps = 20
- p.width = batch.img_shape[0][0]
- p.height = batch.img_shape[0][1]
+ p.width = training_width
+ p.height = training_height
preview_text = p.prompt
diff --git a/modules/ui.py b/modules/ui.py
index 9d6b141e..ddfe1b1a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1348,7 +1348,7 @@ def create_ui():
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
- varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
with FormRow():
--
cgit v1.2.3
From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 9 Jan 2023 23:35:40 +0300
Subject: make a dropdown for prompt template selection
---
modules/hypernetworks/hypernetwork.py | 7 ++++--
modules/shared.py | 1 +
modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------
modules/ui.py | 11 ++++++--
webui.py | 3 +++
5 files changed, 45 insertions(+), 12 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 32c67ccc..ea3f1db9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,6 +24,7 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
diff --git a/modules/shared.py b/modules/shared.py
index a1e10201..aa37c8ce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 14be2c96..5420903f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
import inspect
+from collections import namedtuple
import torch
import tqdm
@@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
@@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert template_file, "Prompt template file is empty"
- assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0, "Max steps must be positive"
@@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
diff --git a/modules/ui.py b/modules/ui.py
index ddfe1b1a..b6079aec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,7 +37,7 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
+from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
@@ -1322,6 +1322,9 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
with gr.Tab(label="Train"):
gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
with FormRow():
@@ -1345,7 +1348,11 @@ def create_ui():
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refresh_train_template_file")
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
diff --git a/webui.py b/webui.py
index 8737e593..47d372c7 100644
--- a/webui.py
+++ b/webui.py
@@ -33,6 +33,7 @@ import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
+import modules.textual_inversion.textual_inversion
import modules.ui
from modules import modelloader
@@ -67,6 +68,8 @@ def initialize():
modules.sd_vae.refresh_vae_list()
+ modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
+
try:
modules.sd_models.load_model()
except Exception as e:
--
cgit v1.2.3
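
Template discovery is a filename-keyed walk over the templates directory; a simplified sketch, where templates_dir stands in for shared.cmd_opts.textual_inversion_templates_dir:

    import os
    from collections import namedtuple

    TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])

    def list_templates(templates_dir):
        # map each filename under the templates directory to its full path;
        # the UI dropdown shows sorted(templates) and training code resolves
        # the selected filename back to templates[filename].path
        templates = {}
        for root, dirs, fns in os.walk(templates_dir):
            for fn in fns:
                templates[fn] = TextualInversionTemplate(fn, os.path.join(root, fn))
        return templates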
From f9706acf431f77e0ce9e4270e5be7299922ee963 Mon Sep 17 00:00:00 2001
From: Lee Bousfield
Date: Tue, 10 Jan 2023 18:40:34 -0700
Subject: Support loading textual inversion embeddings from safetensors files
---
modules/textual_inversion/textual_inversion.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5420903f..3866c154 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -9,6 +9,7 @@ import tqdm
import html
import datetime
import csv
+import safetensors.torch
from PIL import Image, PngImagePlugin
@@ -150,6 +151,8 @@ class EmbeddingDatabase:
name = data.get('name', name)
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
else:
return
--
cgit v1.2.3
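
With this change the loader dispatches on file extension; roughly as follows (load_embedding_data is a name used for illustration):

    import os
    import torch
    import safetensors.torch

    def load_embedding_data(path):
        # .bin/.pt go through torch.load, .safetensors through the
        # safetensors loader; anything else is skipped by the caller
        ext = os.path.splitext(path)[1].upper()
        if ext in ['.BIN', '.PT']:
            return torch.load(path, map_location="cpu")
        if ext == '.SAFETENSORS':
            return safetensors.torch.load_file(path, device="cpu")
        return None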
From 3f43d8a966ba8462ba019a5ad573f94508cd45f8 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 11 Jan 2023 10:28:55 -0500
Subject: set descriptions
---
modules/hypernetworks/hypernetwork.py | 4 +++-
modules/textual_inversion/preprocess.py | 7 ++++++-
modules/textual_inversion/textual_inversion.py | 4 +++-
3 files changed, 12 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 300d3975..194679e8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index feb876c6..3c1042ad 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.process_caption_deepbooru = process_caption_deepbooru
params.preprocess_txt_action = preprocess_txt_action
- for index, imagefile in enumerate(tqdm.tqdm(files)):
+ pbar = tqdm.tqdm(files)
+ for index, imagefile in enumerate(pbar):
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
@@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
except Exception:
continue
+ description = f"Preprocessing [Image {index}/{len(files)}]"
+ pbar.set_description(description)
+ shared.state.textinfo = description
+
params.src = filename
existing_caption = None
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3866c154..b915b091 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -476,7 +476,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
--
cgit v1.2.3
From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 09:22:29 +0100
Subject: Allow creation of zero vectors for TI
---
modules/textual_inversion/textual_inversion.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b915b091..853246a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
--
cgit v1.2.3
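
Condensed, the new create_embedding behaviour looks like the sketch below; encode_init_text stands in for cond_model.encode_embedding_init_text:

    import torch

    def init_embedding_vec(encode_init_text, init_text, num_vectors_per_token, device):
        # the cond model can't encode an empty string, so '*' is passed as a
        # stand-in; the encoded rows are only copied when real init text was
        # given, otherwise the vectors stay zero
        embedded = encode_init_text(init_text or '*', num_vectors_per_token)
        vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=device)
        if init_text:
            for i in range(num_vectors_per_token):
                vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
        return vec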
From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 13 Jan 2023 14:32:15 +0300
Subject: print bucket sizes for training without resizing images #6620 fix an
error when generating a picture with embedding in it
---
modules/textual_inversion/dataset.py | 16 ++++++++++++++++
modules/textual_inversion/image_embedding.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 2 +-
3 files changed, 19 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index b47414f3..d31963d4 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
+
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
@@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ super().__init__(data_source)
+
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
@@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
+
def __len__(self):
return self.len
+
def __iter__(self):
b = self.batch_size
+
for g in self.groups:
shuffle(g)
+
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
+
shuffle(batches)
+
yield from batches
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index ea653806..5593f88c 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
- data_np_low.resize(next_size)
+ data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d))
- data_np_high.resize(next_size)
+ data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
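The two replaced calls matter because ndarray.resize mutates in place and refuses to work on an array that does not own its buffer, while np.resize always returns a fresh array, repeating the data to fill the new size. A small demonstration of the difference:

import numpy as np

data = np.arange(6, dtype=np.uint8)
view = data[:4]              # a view; it does not own its buffer

# view.resize(8) raises ValueError: cannot resize this array: it does not own its data
padded = np.resize(view, 8)  # new array, values repeated to fill the new size
print(padded)                # [0 1 2 3 0 1 2 3]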
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 853246a6..e23906ca 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
--
cgit v1.2.3
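Taken together, the dataset changes group images by resolution and the sampler draws each batch from a single group, so variable-size training never mixes resolutions within one batch; the new print makes the group sizes visible. A condensed sketch of the bucketing idea, independent of the Dataset machinery above:

import random
from collections import defaultdict

def bucketed_batches(sizes, batch_size):
    groups = defaultdict(list)
    for index, (w, h) in enumerate(sizes):
        groups[(w, h)].append(index)

    print("Buckets:")
    for (w, h), ids in sorted(groups.items()):
        print(f" {w}x{h}: {len(ids)}")

    batches = []
    for ids in groups.values():
        random.shuffle(ids)
        # leftovers are dropped here; the real sampler fills extra random batches
        batches += [ids[i:i + batch_size] for i in range(0, len(ids) - batch_size + 1, batch_size)]
    random.shuffle(batches)
    yield from batches

for batch in bucketed_batches([(512, 512)] * 5 + [(640, 512)] * 3, batch_size=2):
    print(batch)  # every batch holds indices of images with one resolution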
From 82725f0ac439f7e3b67858d55900e95330bbd326 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 13 Jan 2023 15:04:37 +0300
Subject: fix a bug caused by merge
---
modules/textual_inversion/textual_inversion.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 85210b0e..6939efcc 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,6 +11,7 @@ import datetime
import csv
import safetensors.torch
+import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
--
cgit v1.2.3
From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 09:56:59 +0300
Subject: change hash to sha256
---
.gitignore | 1 +
modules/api/api.py | 2 +-
modules/api/models.py | 3 +-
modules/hashes.py | 72 +++++++++++++++
modules/hypernetworks/hypernetwork.py | 4 +-
modules/sd_models.py | 116 ++++++++++++++++---------
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 6 +-
webui.py | 2 +
9 files changed, 158 insertions(+), 50 deletions(-)
create mode 100644 modules/hashes.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/.gitignore b/.gitignore
index 21fa26a7..0b1d17ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
+/cache.json
diff --git a/modules/api/api.py b/modules/api/api.py
index 5767ba90..9814bbc2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -371,7 +371,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/api/models.py b/modules/api/models.py
index c78095ca..1eb1fcf1 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -224,7 +224,8 @@ class UpscalerItem(BaseModel):
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
- hash: str = Field(title="Hash")
+ hash: Optional[str] = Field(title="Short hash")
+ sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
diff --git a/modules/hashes.py b/modules/hashes.py
new file mode 100644
index 00000000..ebfbd90c
--- /dev/null
+++ b/modules/hashes.py
@@ -0,0 +1,72 @@
+import hashlib
+import json
+import os.path
+
+import filelock
+
+
+cache_filename = "cache.json"
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title in hashes:
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime <= cached_mtime and cached_sha256 is not None:
+ return cached_sha256
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": ondisk_mtime,
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
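The cache in the new module is keyed by title and invalidated by file mtime, so the expensive full-file digest runs once per changed file. The same idea reduced to one function, with the cache kept in memory instead of cache.json:

import hashlib
import os
import tempfile

_cache = {}

def file_sha256(path):
    mtime = os.path.getmtime(path)
    entry = _cache.get(path)
    if entry is not None and entry["mtime"] >= mtime:
        return entry["sha256"]  # file unchanged since the last hash

    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)

    _cache[path] = {"mtime": mtime, "sha256": h.hexdigest()}
    return _cache[path]["sha256"]

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"model weights")
print(file_sha256(f.name))  # computed
print(file_sha256(f.name))  # served from the cache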
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 83cbb4f0..9b5f2e79 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c466f273..7babb9ae 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,17 +14,56 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
+checkpoint_alisases = {}
checkpoints_loaded = collections.OrderedDict()
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.filename = filename
+ abspath = os.path.abspath(filename)
+
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
+ else:
+ name = os.path.basename(filename)
+
+ if name.startswith("\\") or name.startswith("/"):
+ name = name[1:]
+
+ self.title = name
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ self.hash = model_hash(filename)
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
+ self.shorthash = None
+ self.sha256 = None
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for id in self.ids:
+ checkpoint_alisases[id] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, self.title)
+ self.shorthash = self.sha256[0:10]
+
+ if self.shorthash not in self.ids:
+ self.ids += [self.shorthash, self.sha256]
+ self.register()
+
+ return self.shorthash
+
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +82,14 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- convert = lambda name: int(name) if name.isdigit() else name.lower()
- alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def checkpoint_tiles():
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +105,38 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
+ checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
- def modeltitle(path, shorthash):
- abspath = os.path.abspath(path)
-
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
- return f'{name} [{shorthash}]', shortname
-
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
- h = model_hash(cmd_ckpt)
- title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.data['sd_model_checkpoint'] = title
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
+ checkpoint_info.register()
+
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
for filename in model_list:
- h = model_hash(filename)
- title, short_model_name = modeltitle(filename, h)
+ checkpoint_info = CheckpointInfo(filename)
+ checkpoint_info.register()
+
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+def get_closet_checkpoint_match(search_string):
+ checkpoint_info = checkpoint_alisases.get(search_string, None)
+ if checkpoint_info is not None:
+        return checkpoint_info
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
-def get_closet_checkpoint_match(searchString):
- applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable) > 0:
- return applicable[0]
return None
def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +152,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
- checkpoint_file = checkpoint_info.filename
- sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- sd = read_state_dict(checkpoint_file)
+ sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+ vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
diff --git a/modules/shared.py b/modules/shared.py
index b90ded52..d74c069d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6939efcc..63935878 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
@@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
- embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint = checkpoint.shorthash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
diff --git a/webui.py b/webui.py
index 47d372c7..1fff80da 100644
--- a/webui.py
+++ b/webui.py
@@ -78,6 +78,8 @@ def initialize():
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
+ shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
--
cgit v1.2.3
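The new CheckpointInfo registers itself under every identifier a user might supply (title, model name, legacy hash), and calculate_shorthash later extends that set with the sha256-based ids; all aliases resolve to the same object. A reduced sketch of the alias registry (names are illustrative):

class Info:
    def __init__(self, title, model_name, legacy_hash):
        self.title = title
        self.ids = [legacy_hash, model_name, title]
        self.sha256 = None
        self.shorthash = None

aliases = {}

def register(info):
    for alias in info.ids:
        aliases[alias] = info

info = Info("v1-5.ckpt [abcd1234]", "v1-5", "abcd1234")
register(info)

# Once the full hash is known, the sha256-based ids become valid lookups too:
info.sha256 = "e" * 64
info.shorthash = info.sha256[:10]
info.ids += [info.shorthash, info.sha256]
register(info)

assert aliases["v1-5"] is aliases[info.shorthash]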
From d8b90ac121cbf0c18b1dc9d56a5e1d14ca51e74e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 18:50:56 +0300
Subject: big rework of progressbar/preview system to allow multiple users to
 submit prompts at the same time without getting previews of each other
---
javascript/progressbar.js | 249 ++++++++++++++++---------
javascript/textualInversion.js | 13 +-
javascript/ui.js | 33 +++-
modules/call_queue.py | 19 +-
modules/hypernetworks/hypernetwork.py | 6 +-
modules/img2img.py | 2 +-
modules/progress.py | 96 ++++++++++
modules/sd_samplers.py | 2 +-
modules/shared.py | 16 +-
modules/textual_inversion/preprocess.py | 2 +-
modules/textual_inversion/textual_inversion.py | 6 +-
modules/txt2img.py | 2 +-
modules/ui.py | 41 ++--
modules/ui_progress.py | 101 ----------
style.css | 74 +++++---
webui.py | 3 +
16 files changed, 390 insertions(+), 275 deletions(-)
create mode 100644 modules/progress.py
delete mode 100644 modules/ui_progress.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index d6323ed9..b7524ef7 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -1,82 +1,25 @@
// code related to showing and updating progressbar shown as the image is being made
-global_progressbars = {}
-galleries = {}
-galleryObservers = {}
-
-// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
-timeoutIds = {}
-function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
- // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
- // every time you use gr.HTML(elem_id='xxx'), so we handle this here
- var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
- var progressbarParent
- if(progressbar){
- progressbarParent = gradioApp().querySelector("#"+id_progressbar)
- } else{
- progressbar = gradioApp().getElementById(id_progressbar)
- progressbarParent = null
- }
- var skip = id_skip ? gradioApp().getElementById(id_skip) : null
- var interrupt = gradioApp().getElementById(id_interrupt)
-
- if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
- if(progressbar.innerText){
- let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
- if(document.title != newtitle){
- document.title = newtitle;
- }
- }else{
- let newtitle = 'Stable Diffusion'
- if(document.title != newtitle){
- document.title = newtitle;
- }
- }
- }
-
- if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
- global_progressbars[id_progressbar] = progressbar
-
- var mutationObserver = new MutationObserver(function(m){
- if(timeoutIds[id_part]) return;
-
- preview = gradioApp().getElementById(id_preview)
- gallery = gradioApp().getElementById(id_gallery)
+galleries = {}
+storedGallerySelections = {}
+galleryObservers = {}
- if(preview != null && gallery != null){
- preview.style.width = gallery.clientWidth + "px"
- preview.style.height = gallery.clientHeight + "px"
- if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
+function rememberGallerySelection(id_gallery){
+ storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
+}
- //only watch gallery if there is a generation process going on
- check_gallery(id_gallery);
+function getGallerySelectedIndex(id_gallery){
+ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
+ let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
- var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
- if(progressDiv){
- timeoutIds[id_part] = window.setTimeout(function() {
- timeoutIds[id_part] = null
- requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
- }, 500)
- } else{
- if (skip) {
- skip.style.display = "none"
- }
- interrupt.style.display = "none"
+ let currentlySelectedIndex = -1
+ galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
- //disconnect observer once generation finished, so user can close selected image if they want
- if (galleryObservers[id_gallery]) {
- galleryObservers[id_gallery].disconnect();
- galleries[id_gallery] = null;
- }
- }
- }
-
- });
- mutationObserver.observe( progressbar, { childList:true, subtree:true })
- }
+ return currentlySelectedIndex
}
+// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
function check_gallery(id_gallery){
let gallery = gradioApp().getElementById(id_gallery)
// if gallery has no change, no need to setting up observer again.
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){
if(galleryObservers[id_gallery]){
galleryObservers[id_gallery].disconnect();
}
- let prevSelectedIndex = selected_gallery_index();
+
+ storedGallerySelections[id_gallery] = -1
+
galleryObservers[id_gallery] = new MutationObserver(function (){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
+ let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
+ prevSelectedIndex = storedGallerySelections[id_gallery]
+ storedGallerySelections[id_gallery] = -1
+
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
// automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement;
@@ -120,30 +69,150 @@ function check_gallery(id_gallery){
}
onUiUpdate(function(){
- check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
- check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
- check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
+ check_gallery('txt2img_gallery')
+ check_gallery('img2img_gallery')
})
-function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
- btn = gradioApp().getElementById(id_part+"_check_progress");
- if(btn==null) return;
-
- btn.click();
- var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
- var skip = id_skip ? gradioApp().getElementById(id_skip) : null
- var interrupt = gradioApp().getElementById(id_interrupt)
- if(progressDiv && interrupt){
- if (skip) {
- skip.style.display = "block"
+function request(url, data, handler, errorHandler){
+ var xhr = new XMLHttpRequest();
+ var url = url;
+ xhr.open("POST", url, true);
+ xhr.setRequestHeader("Content-Type", "application/json");
+ xhr.onreadystatechange = function () {
+ if (xhr.readyState === 4) {
+ if (xhr.status === 200) {
+ var js = JSON.parse(xhr.responseText);
+ handler(js)
+ } else{
+ errorHandler()
+ }
}
- interrupt.style.display = "block"
+ };
+ var js = JSON.stringify(data);
+ xhr.send(js);
+}
+
+function pad2(x){
+ return x<10 ? '0'+x : x
+}
+
+function formatTime(secs){
+ if(secs > 3600){
+ return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
+ } else if(secs > 60){
+ return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
+ } else{
+ return Math.floor(secs) + "s"
}
}
-function requestProgress(id_part){
- btn = gradioApp().getElementById(id_part+"_check_progress_initial");
- if(btn==null) return;
+function randomId(){
+ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+}
+
+// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
+// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
+// calls onProgress every time there is a progress update
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+ var dateStart = new Date()
+ var wasEverActive = false
+ var parentProgressbar = progressbarContainer.parentNode
+ var parentGallery = gallery.parentNode
+
+ var divProgress = document.createElement('div')
+ divProgress.className='progressDiv'
+ var divInner = document.createElement('div')
+ divInner.className='progress'
+
+ divProgress.appendChild(divInner)
+ parentProgressbar.insertBefore(divProgress, progressbarContainer)
+
+ var livePreview = document.createElement('div')
+ livePreview.className='livePreview'
+ parentGallery.insertBefore(livePreview, gallery)
+
+ var removeProgressBar = function(){
+ parentProgressbar.removeChild(divProgress)
+ parentGallery.removeChild(livePreview)
+ atEnd()
+ }
+
+ var fun = function(id_task, id_live_preview){
+ request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
+ console.log(res)
+
+ if(res.completed){
+ removeProgressBar()
+ return
+ }
+
+ var rect = progressbarContainer.getBoundingClientRect()
+
+ if(rect.width){
+ divProgress.style.width = rect.width + "px";
+ }
+
+ progressText = ""
+
+ divInner.style.width = ((res.progress || 0) * 100.0) + '%'
+
+ if(res.progress > 0){
+ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
+ }
+
+ if(res.eta){
+ progressText += " ETA: " + formatTime(res.eta)
+ } else if(res.textinfo){
+ progressText += " " + res.textinfo
+ }
+
+ divInner.textContent = progressText
+
+ var elapsedFromStart = (new Date() - dateStart) / 1000
+
+ if(res.active) wasEverActive = true;
+
+ if(! res.active && wasEverActive){
+ removeProgressBar()
+ return
+ }
+
+ if(elapsedFromStart > 5 && !res.queued && !res.active){
+ removeProgressBar()
+ return
+ }
+
+
+ if(res.live_preview){
+ var img = new Image();
+ img.onload = function() {
+ var rect = gallery.getBoundingClientRect()
+ if(rect.width){
+ livePreview.style.width = rect.width + "px"
+ livePreview.style.height = rect.height + "px"
+ }
+
+ livePreview.innerHTML = ''
+ livePreview.appendChild(img)
+ if(livePreview.childElementCount > 2){
+ livePreview.removeChild(livePreview.firstElementChild)
+ }
+ }
+ img.src = res.live_preview;
+ }
+
+
+ if(onProgress){
+ onProgress(res)
+ }
+
+ setTimeout(() => {
+ fun(id_task, res.id_live_preview);
+ }, 500)
+ }, function(){
+ removeProgressBar()
+ })
+ }
- btn.click();
+ fun(id_task, 0)
}
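requestProgress above polls the server every 500 ms and echoes back the id of the last preview it received, so unchanged previews are not re-sent. The same client loop, sketched in Python against the /internal/progress route added below (base URL assumed; requests is a third-party stand-in for the browser's XHR):

import time
import requests  # third-party HTTP client, assumed installed

def poll_progress(id_task, base_url="http://127.0.0.1:7860"):
    id_live_preview = -1
    while True:
        res = requests.post(f"{base_url}/internal/progress",
                            json={"id_task": id_task, "id_live_preview": id_live_preview}).json()
        if res["completed"]:
            return
        if res.get("progress"):
            print(f"{res['progress'] * 100:.0f}%  eta {res.get('eta')}")
        id_live_preview = res.get("id_live_preview", id_live_preview)
        time.sleep(0.5)  # same 500 ms cadence as the JS loop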
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
index 8061be08..0354b860 100644
--- a/javascript/textualInversion.js
+++ b/javascript/textualInversion.js
@@ -1,8 +1,17 @@
+
function start_training_textual_inversion(){
- requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
- return args_to_array(arguments)
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
+ gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
+ })
+
+ var res = args_to_array(arguments)
+
+ res[0] = id
+
+ return res
}
diff --git a/javascript/ui.js b/javascript/ui.js
index f8279124..ecf97cb3 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -126,18 +126,41 @@ function create_submit_args(args){
return res
}
+function showSubmitButtons(tabname, show){
+ gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
+ gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
+}
+
function submit(){
- requestProgress('txt2img')
+ rememberGallerySelection('txt2img_gallery')
+ showSubmitButtons('txt2img', false)
+
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
+ showSubmitButtons('txt2img', true)
+
+ })
- return create_submit_args(arguments)
+ var res = create_submit_args(arguments)
+
+ res[0] = id
+
+ return res
}
function submit_img2img(){
- requestProgress('img2img')
+ rememberGallerySelection('img2img_gallery')
+ showSubmitButtons('img2img', false)
+
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
+ showSubmitButtons('img2img', true)
+ })
- res = create_submit_args(arguments)
+ var res = create_submit_args(arguments)
- res[0] = get_tab_index('mode_img2img')
+ res[0] = id
+ res[1] = get_tab_index('mode_img2img')
return res
}
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 4cd49533..92097c15 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -4,7 +4,7 @@ import threading
import traceback
import time
-from modules import shared
+from modules import shared, progress
queue_lock = threading.Lock()
@@ -22,12 +22,23 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
- shared.state.begin()
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
with queue_lock:
- res = func(*args, **kwargs)
+ shared.state.begin()
+ progress.start_task(id_task)
+
+ try:
+ res = func(*args, **kwargs)
+ finally:
+ progress.finish_task(id_task)
- shared.state.end()
+ shared.state.end()
return res
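The queue wrapper only treats the first positional argument as a job id when it matches the exact "task(...)" shape that randomId() produces; everything else passes through untouched. The check, isolated:

def extract_task_id(args):
    # Mirrors the condition in wrap_gradio_gpu_call above.
    if len(args) > 0 and isinstance(args[0], str) and args[0][0:5] == "task(" and args[0][-1] == ")":
        return args[0]
    return None

print(extract_task_id(("task(abc12def34ghj56)", "a prompt")))  # the id
print(extract_task_id(("a prompt",)))                          # None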
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3aebefa8..ae6af516 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
torch.cuda.set_rng_state_all(cuda_rng_state)
hypernetwork.train()
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/img2img.py b/modules/img2img.py
index f62783c6..f4a03c57 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
diff --git a/modules/progress.py b/modules/progress.py
new file mode 100644
index 00000000..3327b883
--- /dev/null
+++ b/modules/progress.py
@@ -0,0 +1,96 @@
+import base64
+import io
+import time
+
+import gradio as gr
+from pydantic import BaseModel, Field
+
+from modules.shared import opts
+
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+
+
+def start_task(id_task):
+ global current_task
+
+ current_task = id_task
+ pending_tasks.pop(id_task, None)
+
+
+def finish_task(id_task):
+ global current_task
+
+ if current_task == id_task:
+ current_task = None
+
+ finished_tasks.append(id_task)
+ if len(finished_tasks) > 16:
+ finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+ pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
+
+
+class ProgressResponse(BaseModel):
+ active: bool = Field(title="Whether the task is being worked on right now")
+ queued: bool = Field(title="Whether the task is in queue")
+ completed: bool = Field(title="Whether the task has already finished")
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ eta: float = Field(default=None, title="ETA in secs")
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with the next request to avoid receiving the same image again")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+ active = req.id_task == current_task
+ queued = req.id_task in pending_tasks
+ completed = req.id_task in finished_tasks
+
+ if not active:
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+ progress = 0
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ progress = min(progress, 1)
+
+ elapsed_since_start = time.time() - shared.state.time_start
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+ id_live_preview = req.id_live_preview
+ shared.state.set_current_image()
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+ image.save(buffered, format="png")
+ live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+ id_live_preview = shared.state.id_live_preview
+ else:
+ live_preview = None
+ else:
+ live_preview = None
+
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
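progressapi folds completed jobs and the current job's sampling steps into a single 0..1 fraction, then extrapolates the ETA linearly from elapsed wall time. The same arithmetic with concrete numbers:

job_no, job_count = 1, 2                 # second job of two is running
sampling_step, sampling_steps = 10, 20   # halfway through its steps
elapsed = 30.0                           # seconds since shared.state.time_start

progress = job_no / job_count + (1 / job_count) * (sampling_step / sampling_steps)
progress = min(progress, 1)              # 0.5 + 0.25 = 0.75

predicted_duration = elapsed / progress  # 40.0 s predicted total
eta = predicted_duration - elapsed       # 10.0 s remaining
print(progress, eta)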
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 7616fded..76e0e0d5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -140,7 +140,7 @@ def store_latent(decoded):
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
- shared.state.current_image = sample_to_image(decoded)
+ shared.state.assign_current_image(sample_to_image(decoded))
class InterruptedException(BaseException):
diff --git a/modules/shared.py b/modules/shared.py
index 51df056c..de99aca9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -152,6 +152,7 @@ def reload_hypernetworks():
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
skipped = False
interrupted = False
@@ -165,6 +166,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ id_live_preview = 0
textinfo = None
time_start = None
need_restart = False
@@ -207,6 +209,7 @@ class State:
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
+ self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
@@ -220,8 +223,8 @@ class State:
devices.torch_gc()
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
@@ -234,12 +237,16 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
- self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
- self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
+ def assign_current_image(self, image):
+ self.current_image = image
+ self.id_live_preview += 1
+
state = State()
state.server_start = time.time()
@@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
}))
options_templates.update(options_section(('ui', "User interface"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3c1042ad..64abff4d 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 63935878..7e4a6d24 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
@@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 38b5f591..ca5d4550 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
diff --git a/modules/ui.py b/modules/ui.py
index 2425c66f..ff33236b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -356,7 +356,7 @@ def create_toprow(is_img2img):
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
- with gr.Row():
+ with gr.Row(elem_id=f"{id_part}_generate_box"):
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
@@ -384,9 +384,7 @@ def create_toprow(is_img2img):
def setup_progressbar(*args, **kwargs):
- import modules.ui_progress
-
- modules.ui_progress.setup_progressbar(*args, **kwargs)
+ pass
def apply_setting(key, value):
@@ -479,8 +477,8 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel'):
- with gr.Group():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
generation_info = None
@@ -595,15 +593,6 @@ def create_ui():
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -682,6 +671,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
+ dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_style,
@@ -782,16 +772,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
@@ -958,6 +939,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
+ dummy_component,
dummy_component,
img2img_prompt,
img2img_negative_prompt,
@@ -1335,15 +1317,11 @@ def create_ui():
script_callbacks.ui_train_tabs_callback(params)
- with gr.Column():
- progressbar = gr.HTML(elem_id="ti_progressbar")
+ with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
- setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
@@ -1384,6 +1362,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
process_src,
process_dst,
process_width,
@@ -1411,6 +1390,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_embedding_name,
embedding_learn_rate,
batch_size,
@@ -1443,6 +1423,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
deleted file mode 100644
index 7cd312e4..00000000
--- a/modules/ui_progress.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import time
-
-import gradio as gr
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-def calc_time_left(progress, threshold, label, force_display, show_eta):
- if progress == 0:
- return ""
- else:
- time_since_start = time.time() - shared.state.time_start
- eta = (time_since_start/progress)
- eta_relative = eta-time_since_start
- if (eta_relative > threshold and show_eta) or force_display:
- if eta_relative > 3600:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
- elif eta_relative > 60:
- return label + time.strftime('%M:%S', time.gmtime(eta_relative))
- else:
- return label + time.strftime('%Ss', time.gmtime(eta_relative))
- else:
- return ""
-
-
-def check_progress_call(id_part):
- if shared.state.job_count == 0:
- return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
- progress = 0
-
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
- # Show progress percentage and time left at the same moment, and base it also on steps done
- show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
- time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
- if time_left != "":
- shared.state.time_left_force_display = True
-
- progress = min(progress, 1)
-
- progressbar = ""
- if opts.show_progressbar:
- progressbar = f"""{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
"""
-
- image = gr.update(visible=False)
- preview_visibility = gr.update(visible=False)
-
- if opts.live_previews_enable:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr.update(visible=True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr.update(visible=False)
-
- return f"{time.time()}{progressbar}
", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
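For reference, the ETA in the deleted calc_time_left() is a plain linear extrapolation from elapsed wall time and fractional progress: total = elapsed / progress, remaining = total - elapsed. A standalone sketch of the same arithmetic (names here are illustrative, not part of the webui):

    import time

    def eta_seconds(progress, time_start):
        # linear extrapolation, as in the deleted calc_time_left()
        if progress == 0:
            return None  # no estimate until some work has been done
        elapsed = time.time() - time_start
        return elapsed / progress - elapsed

    # e.g. 30 s elapsed at 25% done -> about 90 s remaining
    print(round(eta_seconds(0.25, time.time() - 30)))  # 90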
diff --git a/style.css b/style.css
index 2d484e06..786b71d1 100644
--- a/style.css
+++ b/style.css
@@ -305,26 +305,42 @@ input[type="range"]{
}
.progressDiv{
- width: 100%;
- height: 20px;
- background: #b4c0cc;
- border-radius: 8px;
+ position: absolute;
+ height: 20px;
+ top: -20px;
+ background: #b4c0cc;
+ border-radius: 8px !important;
}
.dark .progressDiv{
- background: #424c5b;
+ background: #424c5b;
}
.progressDiv .progress{
- width: 0%;
- height: 20px;
- background: #0060df;
- color: white;
- font-weight: bold;
- line-height: 20px;
- padding: 0 8px 0 0;
- text-align: right;
- border-radius: 8px;
+ width: 0%;
+ height: 20px;
+ background: #0060df;
+ color: white;
+ font-weight: bold;
+ line-height: 20px;
+ padding: 0 8px 0 0;
+ text-align: right;
+ border-radius: 8px;
+ overflow: visible;
+ white-space: nowrap;
+}
+
+.livePreview{
+ position: absolute;
+ z-index: 300;
+ background-color: white;
+ margin: -4px;
+}
+
+.livePreview img{
+ object-fit: contain;
+ width: 100%;
+ height: 100%;
}
#lightboxModal{
@@ -450,23 +466,25 @@ input[type="range"]{
display:none
}
-#txt2img_interrupt, #img2img_interrupt{
- position: absolute;
- width: 50%;
- height: 72px;
- background: #b4c0cc;
- border-radius: 0px;
- display: none;
+#txt2img_generate_box, #img2img_generate_box{
+ position: relative;
+}
+
+#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{
+ position: absolute;
+ width: 50%;
+ height: 100%;
+ background: #b4c0cc;
+ display: none;
}
+#txt2img_interrupt, #img2img_interrupt{
+ right: 0;
+ border-radius: 0 0.5rem 0.5rem 0;
+}
#txt2img_skip, #img2img_skip{
- position: absolute;
- width: 50%;
- right: 0px;
- height: 72px;
- background: #b4c0cc;
- border-radius: 0px;
- display: none;
+ left: 0;
+ border-radius: 0.5rem 0 0 0.5rem;
}
.red {
diff --git a/webui.py b/webui.py
index 1fff80da..4624fe18 100644
--- a/webui.py
+++ b/webui.py
@@ -34,6 +34,7 @@ import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
+import modules.progress
import modules.ui
from modules import modelloader
@@ -181,6 +182,8 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000)
+ modules.progress.setup_progress_api(app)
+
if launch_api:
create_api(app)
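modules/progress.py itself is not shown in this excerpt, but the call above suggests its shape: it registers a progress route on the existing FastAPI app so the frontend can poll over HTTP instead of clicking the hidden "Check progress" buttons that ui_progress.py used. A rough sketch under assumed names (the real route path, schema, and handler logic may differ):

    from typing import Optional

    from fastapi import FastAPI
    from pydantic import BaseModel


    class ProgressResponse(BaseModel):  # hypothetical schema
        progress: float
        eta: Optional[float] = None


    def progressapi() -> ProgressResponse:
        # stub body for illustration; the real handler would read shared.state
        return ProgressResponse(progress=0.0, eta=None)


    def setup_progress_api(app: FastAPI):
        # assumed route name; registered once at startup in webui()
        return app.add_api_route("/internal/progress", progressapi,
                                 methods=["POST"], response_model=ProgressResponse)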
--
cgit v1.2.3
From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 23:04:24 +0300
Subject: add option to show/hide warnings

removed hiding warnings from LDSR
fixed/reworked a few places that produced warnings
---
extensions-builtin/LDSR/ldsr_model_arch.py | 3 --
javascript/localization.js | 2 +-
modules/hypernetworks/hypernetwork.py | 7 ++++-
modules/sd_hijack.py | 8 ------
modules/sd_hijack_checkpoint.py | 38 +++++++++++++++++++++++++-
modules/shared.py | 1 +
modules/textual_inversion/textual_inversion.py | 6 +++-
modules/ui.py | 31 ++++++++++++---------
scripts/prompts_from_file.py | 2 +-
style.css | 5 ++--
10 files changed, 71 insertions(+), 32 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 0ad49f4e..bc11cc6e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,7 +1,6 @@
import os
import gc
import time
-import warnings
import numpy as np
import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack
-warnings.filterwarnings("ignore", category=UserWarning)
-
cached_ldsr_model: torch.nn.Module = None
diff --git a/javascript/localization.js b/javascript/localization.js
index bf9e1506..1a5a1dbb 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -11,7 +11,7 @@ ignore_ids_for_localization={
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
- img2img_styles 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c963fc40..74e78582 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b0d95af..870eba88 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
-
-
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -106,8 +100,6 @@ class StableDiffusionModelHijack:
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
-
- fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
index 5712972f..2604d969 100644
--- a/modules/sd_hijack_checkpoint.py
+++ b/modules/sd_hijack_checkpoint.py
@@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.openaimodel
+
+
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
+
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
+
def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb)
\ No newline at end of file
+ return checkpoint(self._forward, x, emb)
+
+
+stored = []
+
+
+def add():
+ if len(stored) != 0:
+ return
+
+ stored.extend([
+ ldm.modules.attention.BasicTransformerBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
+ ])
+
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
+
+
+def remove():
+ if len(stored) == 0:
+ return
+
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
+
+ stored.clear()
+
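The add()/remove() pair above is a save-and-restore monkey patch: the original bound functions are stashed in a module-level list, which makes the patch idempotent (a second add() is a no-op) and fully reversible. The same pattern in a generic, self-contained form (all names illustrative):

    _stored = {}

    def patch(cls, attr, replacement):
        key = (cls, attr)
        if key in _stored:    # idempotent: never patch the same attribute twice
            return
        _stored[key] = getattr(cls, attr)
        setattr(cls, attr, replacement)

    def unpatch_all():
        # restore every saved original, then forget them
        for (cls, attr), original in _stored.items():
            setattr(cls, attr, original)
        _stored.clear()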
diff --git a/modules/shared.py b/modules/shared.py
index a708f23c..ddb97f99 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7e4a6d24..5a7be422 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
diff --git a/modules/ui.py b/modules/ui.py
index 6d70a795..25818fb0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -11,6 +11,7 @@ import tempfile
import time
import traceback
from functools import partial, reduce
+import warnings
import gradio as gr
import gradio.routes
@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -417,17 +420,16 @@ def apply_setting(key, value):
return value
-def update_generation_info(args):
- generation_info, html_info, img_index = args
+def update_generation_info(generation_info, html_info, img_index):
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
- return html_info
+ return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
@@ -508,10 +510,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
+ _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
)
save.click(
@@ -526,7 +527,8 @@ Requested path was: {f}
outputs=[
download_files,
html_log,
- ]
+ ],
+ show_progress=False,
)
save_zip.click(
@@ -588,7 +590,7 @@ def create_ui():
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
@@ -768,7 +770,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -1768,7 +1770,10 @@ def create_ui():
if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index f3e711d7..76dc5778 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -116,7 +116,7 @@ class Script(scripts.Script):
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
- file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file"))
+ file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
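type='bytes' is the pre-rename spelling; newer gradio versions call the same mode 'binary', and in both cases the change handler receives the raw file contents as bytes. A minimal sketch of such a handler, under that assumption:

    import gradio as gr

    def read_prompt_file(data):
        # with gr.File(type="binary"), `data` is bytes (or None when cleared)
        if data is None:
            return ""
        return data.decode("utf8", errors="ignore")

    with gr.Blocks() as demo:
        file = gr.File(label="Upload prompt inputs", type="binary")
        text = gr.Textbox(label="List of prompt inputs")
        file.change(fn=read_prompt_file, inputs=[file], outputs=[text])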
diff --git a/style.css b/style.css
index 61279a19..0845519a 100644
--- a/style.css
+++ b/style.css
@@ -299,9 +299,8 @@ input[type="range"]{
}
/* more gradio's garbage cleanup */
-.min-h-\[4rem\] {
- min-height: unset !important;
-}
+.min-h-\[4rem\] { min-height: unset !important; }
+.min-h-\[6rem\] { min-height: unset !important; }
.progressDiv{
position: absolute;
--
cgit v1.2.3
From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 08:36:07 +0300
Subject: extra networks UI

rework of hypernets: rather than via settings, hypernets are added
directly to prompt as <hypernet:name:weight>
---
html/card-no-preview.png | Bin 0 -> 84440 bytes
html/extra-networks-card.html | 11 ++
html/extra-networks-no-cards.html | 8 ++
javascript/extraNetworks.js | 60 ++++++++
javascript/hints.js | 2 +
javascript/ui.js | 9 +-
modules/api/api.py | 7 +-
modules/extra_networks.py | 147 +++++++++++++++++++
modules/extra_networks_hypernet.py | 21 +++
modules/generation_parameters_copypaste.py | 12 +-
modules/hypernetworks/hypernetwork.py | 107 +++++++++-----
modules/hypernetworks/ui.py | 5 +-
modules/processing.py | 24 ++--
modules/sd_hijack_optimizations.py | 10 +-
modules/shared.py | 21 ++-
modules/textual_inversion/textual_inversion.py | 2 +
modules/ui.py | 50 ++++---
modules/ui_components.py | 10 ++
modules/ui_extra_networks.py | 149 +++++++++++++++++++
modules/ui_extra_networks_hypernets.py | 34 +++++
modules/ui_extra_networks_textual_inversion.py | 32 +++++
script.js | 13 +-
scripts/xy_grid.py | 29 ----
style.css | 190 +++++++++++++------------
webui.py | 26 +++-
25 files changed, 765 insertions(+), 214 deletions(-)
create mode 100644 html/card-no-preview.png
create mode 100644 html/extra-networks-card.html
create mode 100644 html/extra-networks-no-cards.html
create mode 100644 javascript/extraNetworks.js
create mode 100644 modules/extra_networks.py
create mode 100644 modules/extra_networks_hypernet.py
create mode 100644 modules/ui_extra_networks.py
create mode 100644 modules/ui_extra_networks_hypernets.py
create mode 100644 modules/ui_extra_networks_textual_inversion.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/html/card-no-preview.png b/html/card-no-preview.png
new file mode 100644
index 00000000..e2beb269
Binary files /dev/null and b/html/card-no-preview.png differ
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
new file mode 100644
index 00000000..7314b063
--- /dev/null
+++ b/html/extra-networks-card.html
@@ -0,0 +1,11 @@
+ [11 lines of card-template HTML stripped by the web viewer; the template
+  consumes the {preview_html}, {prompt}, {tabname}, {local_preview}, {name}
+  and {allow_negative_prompt} placeholders filled in by
+  create_html_for_item() in modules/ui_extra_networks.py below]
diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html
new file mode 100644
index 00000000..389358d6
--- /dev/null
+++ b/html/extra-networks-no-cards.html
@@ -0,0 +1,8 @@
+<div class='nocards'>
+ <h1>Nothing here. Add some content to the following directories:</h1>
+
+ <ul>
+ {dirs}
+ </ul>
+</div>
+
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
new file mode 100644
index 00000000..71e522d1
--- /dev/null
+++ b/javascript/extraNetworks.js
@@ -0,0 +1,60 @@
+
+function setupExtraNetworksForTab(tabname){
+ gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
+
+ gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
+ gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
+}
+
+var activePromptTextarea = null;
+var activePositivePromptTextarea = null;
+
+function setupExtraNetworks(){
+ setupExtraNetworksForTab('txt2img')
+ setupExtraNetworksForTab('img2img')
+
+ function registerPrompt(id, isNegative){
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+ if (activePromptTextarea == null){
+ activePromptTextarea = textarea
+ }
+ if (activePositivePromptTextarea == null && ! isNegative){
+ activePositivePromptTextarea = textarea
+ }
+
+ textarea.addEventListener("focus", function(){
+ activePromptTextarea = textarea;
+ if(! isNegative) activePositivePromptTextarea = textarea;
+ });
+ }
+
+ registerPrompt('txt2img_prompt')
+ registerPrompt('txt2img_neg_prompt', true)
+ registerPrompt('img2img_prompt')
+ registerPrompt('img2img_neg_prompt', true)
+}
+
+onUiLoaded(setupExtraNetworks)
+
+function cardClicked(textToAdd, allowNegativePrompt){
+ textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
+
+ textarea.value = textarea.value + " " + textToAdd
+ updateInput(textarea)
+
+ return false
+}
+
+function saveCardPreview(event, tabname, filename){
+ textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
+ button = gradioApp().getElementById(tabname + '_save_preview')
+
+ textarea.value = filename
+ updateInput(textarea)
+
+ button.click()
+
+ event.stopPropagation()
+ event.preventDefault()
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index e746e20d..f4079f96 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -21,6 +21,8 @@ titles = {
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
+ "\u{1f3b4}": "Show extra networks",
+
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/javascript/ui.js b/javascript/ui.js
index 3ba90ca8..a7e75439 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt]
}
-
-
opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -239,11 +237,14 @@ onUiUpdate(function(){
return
}
+
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
- textarea.addEventListener("input", () => update_token_counter(id_button));
+ textarea.addEventListener("input", function(){
+ update_token_counter(id_button);
+ });
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@@ -261,10 +262,8 @@ onUiUpdate(function(){
})
}
}
-
})
-
onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
diff --git a/modules/api/api.py b/modules/api/api.py
index 9814bbc2..2c371e6e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -480,7 +480,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -491,16 +491,15 @@ class Api:
except Exception as e:
error = e
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
except AssertionError as msg:
shared.state.end()
- return TrainResponse(info = "train embedding error: {error}".format(error = error))
+ return TrainResponse(info="train embedding error: {error}".format(error=error))
def get_memory(self):
try:
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
new file mode 100644
index 00000000..1978673d
--- /dev/null
+++ b/modules/extra_networks.py
@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+ Passes arguments related to this extra network in params_list.
+ User passes arguments by specifying this in his prompt:
+
+ <name:arg1:arg2:arg3>
+
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+ separated by colons.
+
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
+ in this case, all effects of this extra network should be disabled.
+
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+ > "1girl, "
+
+ params_list will be:
+
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. No need to do anything here.
+ """
+
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call
+ activate for all remaining registered networks with an empty argument list"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call
+ deactivate for all remaining registered networks"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+ return ""
+
+ prompt = re.sub(re_extra_net, found, prompt)
+
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+ if extra_data is None:
+ extra_data = parsed_extra_data
+
+ res.append(updated_prompt)
+
+ return res, extra_data
+
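The regex re_extra_net matches <name:args> tokens, removes them from the prompt, and collects the colon-separated arguments per network name. Assuming the webui environment so that the module is importable, the round trip looks like this:

    from modules import extra_networks

    prompt, data = extra_networks.parse_prompt("1girl, <hypernet:agm:1.1>, <hypernet:ray>")

    print(prompt)                      # "1girl, , "  (tokens stripped, other text kept)
    print(data["hypernet"][0].items)   # ["agm", "1.1"]
    print(data["hypernet"][1].items)   # ["ray"]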
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
new file mode 100644
index 00000000..6a0c4ba8
--- /dev/null
+++ b/modules/extra_networks_hypernet.py
@@ -0,0 +1,21 @@
+from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(self, p):
+ pass
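ExtraNetworkHypernet is the first implementation of the contract defined in modules/extra_networks.py; any other network type follows the same shape: subclass, implement activate()/deactivate(), register an instance. A minimal hypothetical example (the 'noop' name and behaviour are invented for illustration):

    from modules import extra_networks

    class ExtraNetworkNoop(extra_networks.ExtraNetwork):
        """Responds to <noop:...> prompt tokens and does nothing."""

        def __init__(self):
            super().__init__('noop')

        def activate(self, p, params_list):
            for params in params_list:
                print(f"noop called with {params.items}")

        def deactivate(self, p):
            pass

    extra_networks.register_extra_network(ExtraNetworkNoop())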
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a381ff59..46e12dc6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Clip skip" not in res:
res["Clip skip"] = "1"
- if "Hypernet strength" not in res:
- res["Hypernet strength"] = "1"
-
- if "Hypernet" in res:
- hypernet_name = res["Hypernet"]
- hypernet_hash = res.get("Hypernet hash", None)
- res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ hypernet = res.get("Hypernet", None)
+ if hypernet is not None:
+ res["Prompt"] += f""""""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74e78582..80a47c79 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,7 +25,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
- multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
+ self.multiplier = 1.0
+
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
return layer_structure
-def apply_strength(value=None):
- HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork:
for param in layer.parameters():
param.requires_grad = mode
+ def to(self, device):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
- print("No saved optimizer exists in checkpoint")
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
return res
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- # Prevent any file named "None.pt" from being loaded.
- if path is not None and filename != "None":
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print("Unloading hypernetwork")
+ if path is None:
+ return None
+
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
- shared.loaded_hypernetwork = None
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
+
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
return applicable[0]
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
- return context, context
+ return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
+ context_k = hypernetwork_layers[0](context_k)
+ context_v = hypernetwork_layers[1](context_v)
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
images_dir = None
- hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
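With the multiplier moved from a class attribute onto each HypernetworkModule instance, several hypernetworks can now be active at once with independent strengths. A sketch of the new entry point, assuming the webui environment (the names are the placeholder ones from the docstring earlier in this commit):

    from modules import shared
    from modules.hypernetworks import hypernetwork

    # loads "agm" at strength 1.1 and "ray" at 1.0; anything already in
    # shared.loaded_hypernetworks is reused rather than re-read from disk
    hypernetwork.load_hypernetworks(["agm", "ray"], multipliers=[1.1, 1.0])

    for hn in shared.loaded_hypernetworks:
        print(hn.name)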
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 81e3f519..76599f5a 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
diff --git a/modules/processing.py b/modules/processing.py
index a3e9f709..b5deeacf 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
- "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork':
- shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights() # make onchange call for changing SD model
+ sd_models.reload_model_weights()
if k == 'sd_vae':
- sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ sd_vae.reload_vae_weights()
res = process_images_inner(p)
@@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights()
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights()
return res
@@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
cache[0] = (required_prompts, steps)
return cache[1]
+ p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
+
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ extra_networks.activate(p, extra_network_data)
+
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
@@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
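process_images_inner now brackets generation with activate/deactivate, and note the contract spelled out in the docstring: every registered network is activated on every run, with an empty params_list when the prompt does not mention it, so that it can switch itself off. A small demonstration of that ordering (assuming the webui environment):

    from modules import extra_networks

    class Logger(extra_networks.ExtraNetwork):
        def activate(self, p, params_list):
            print(self.name, [params.items for params in params_list])

        def deactivate(self, p):
            pass

    extra_networks.initialize()
    extra_networks.register_extra_network(Logger('hypernet'))
    extra_networks.register_extra_network(Logger('other'))

    prompts, data = extra_networks.parse_prompts(["a castle <hypernet:agm:1.1>"])
    extra_networks.activate(None, data)
    # hypernet [['agm', '1.1']]   <- mentioned in the prompt
    # other []                    <- still called, with empty arguments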
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index cdc63ed7..4fa54329 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
@@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
@@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
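apply_hypernetworks threads the key/value contexts through every loaded network in turn, so each network transforms the previous one's output rather than the raw context. A toy illustration of that fold, with plain functions standing in for hypernetwork layers:

    # each "hypernetwork" is a (key transform, value transform) pair
    hn_a = (lambda k: k + 1, lambda v: v * 2)
    hn_b = (lambda k: k * 10, lambda v: v - 3)

    def apply_all(hypernetworks, context):
        context_k = context_v = context
        for fk, fv in hypernetworks:
            # each net sees the previous net's output, not the original context
            context_k, context_v = fk(context_k), fv(context_v)
        return context_k, context_v

    print(apply_all([hn_a, hn_b], 5))   # ((5 + 1) * 10, 5 * 2 - 3) == (60, 7)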
diff --git a/modules/shared.py b/modules/shared.py
index 2f366454..c0e11f18 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -23,6 +23,7 @@ demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
+
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
-loaded_hypernetwork = None
+loaded_hypernetworks = []
def reload_hypernetworks():
@@ -153,8 +154,6 @@ def reload_hypernetworks():
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
- hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
class State:
@@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
- "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
- "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -661,3 +658,17 @@ mem_mon.start()
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+ return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+
+ return ""
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5a7be422..4e90f690 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -50,6 +50,7 @@ class Embedding:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
+ self.filename = None
def save(self, filename):
embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase:
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
+ embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
diff --git a/modules/ui.py b/modules/ui.py
index 06c11848..d23b2b8e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
+extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
@@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
def update_token_counter(text, steps):
try:
+ text, _ = extra_networks.parse_prompt(text)
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
@@ -354,10 +357,10 @@ def create_toprow(is_img2img):
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ paste = ToolButton(value=paste_symbol, elem_id="paste")
+ clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+
token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter")
@@ -395,11 +398,14 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row():
+ with gr.Row(elem_id=f"{id_part}_styles_row"):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
+ prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
+ save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
+
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs):
@@ -616,11 +622,15 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -794,14 +804,20 @@ def create_ui():
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
@@ -1064,6 +1080,8 @@ def create_ui():
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1666,10 +1684,8 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- if os.path.exists("html/licenses.html"):
- with open("html/licenses.html", encoding="utf8") as file:
- with gr.TabItem("Licenses"):
- gr.HTML(file.read(), elem_id="licenses")
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1756,11 +1772,9 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
- if os.path.exists("html/footer.html"):
- with open("html/footer.html", encoding="utf8") as file:
- footer = file.read()
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 97acff06..46324425 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
return "button"
+class ToolButtonTop(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool-top", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
new file mode 100644
index 00000000..253e90f7
--- /dev/null
+++ b/modules/ui_extra_networks.py
@@ -0,0 +1,149 @@
+import os.path
+
+from modules import shared
+import gradio as gr
+import json
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+
+
+def register_page(page):
+ """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
+
+ extra_pages.append(page)
+
+
+class ExtraNetworksPage:
+ def __init__(self, title):
+ self.title = title
+ self.card_page = shared.html("extra-networks-card.html")
+ self.allow_negative_prompt = False
+
+ def refresh(self):
+ pass
+
+ def create_html(self, tabname):
+ items_html = ''
+
+ for item in self.list_items():
+ items_html += self.create_html_for_item(item, tabname)
+
+ if items_html == '':
+ dirs = "".join([f"{x}" for x in self.allowed_directories_for_previews()])
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+ res = ""
+
+ return res
+
+ def list_items(self):
+ raise NotImplementedError()
+
+ def allowed_directories_for_previews(self):
+ return []
+
+ def create_html_for_item(self, item, tabname):
+ preview = item.get("preview", None)
+
+ args = {
+ "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+ "prompt": json.dumps(item["prompt"]),
+ "tabname": json.dumps(tabname),
+ "local_preview": json.dumps(item["local_preview"]),
+ "name": item["name"],
+ "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ }
+
+ return self.card_page.format(**args)
+
+
+def initialize():
+ extra_pages.clear()
+
+
+class ExtraNetworksUi:
+ def __init__(self):
+ self.pages = None
+ self.stored_extra_pages = None
+
+ self.button_save_preview = None
+ self.preview_target_filename = None
+
+ self.tabname = None
+
+
+def create_ui(container, button, tabname):
+ ui = ExtraNetworksUi()
+ ui.pages = []
+ ui.stored_extra_pages = extra_pages.copy()
+ ui.tabname = tabname
+
+ with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+ for page in ui.stored_extra_pages:
+ with gr.Tab(page.title):
+ page_elem = gr.HTML(page.create_html(ui.tabname))
+ ui.pages.append(page_elem)
+
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+ button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+ button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
+ def refresh():
+ res = []
+
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ res.append(pg.create_html(ui.tabname))
+
+ return res
+
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+ return ui
+
+
+def path_is_parent(parent_path, child_path):
+ parent_path = os.path.abspath(parent_path)
+ child_path = os.path.abspath(child_path)
+
+ return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+
+
+def setup_ui(ui, gallery):
+ def save_preview(index, images, filename):
+ if len(images) == 0:
+ print("There is no image in gallery to save as a preview.")
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(images) - 1 if index >= len(images) else index
+
+ img_info = images[index]
+ image = image_from_url_text(img_info)
+
+ is_allowed = False
+ for extra_page in ui.stored_extra_pages:
+ if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ is_allowed = True
+ break
+
+ assert is_allowed, f'writing to {filename} is not allowed'
+
+ image.save(filename)
+
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ ui.button_save_preview.click(
+ fn=save_preview,
+ _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+ inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+ outputs=[*ui.pages]
+ )
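For extensions, the intended entry point is register_page() plus the on_app_started() callback its docstring recommends. A minimal, hypothetical sketch of a custom page (MyNetworksPage, MY_MODELS_DIR and the card field values are illustrative, not part of this patch):

    import os

    from modules import script_callbacks, ui_extra_networks

    MY_MODELS_DIR = "models/my-networks"  # hypothetical directory


    class MyNetworksPage(ui_extra_networks.ExtraNetworksPage):
        def __init__(self):
            super().__init__('My Networks')

        def list_items(self):
            if not os.path.isdir(MY_MODELS_DIR):
                return
            for fn in os.listdir(MY_MODELS_DIR):
                path = os.path.splitext(os.path.join(MY_MODELS_DIR, fn))[0]
                yield {
                    "name": os.path.basename(path),
                    "filename": path,
                    "preview": None,                         # card falls back to the no-preview image
                    "prompt": os.path.basename(path),        # text inserted into the prompt on click
                    "local_preview": path + ".preview.png",  # where "Save preview" writes
                }

        def allowed_directories_for_previews(self):
            return [MY_MODELS_DIR]


    script_callbacks.on_app_started(lambda demo, app: ui_extra_networks.register_page(MyNetworksPage()))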
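path_is_parent() is what keeps save_preview() from writing outside the directories each page whitelists via allowed_directories_for_previews(). A standalone illustration of the commonpath check (paths are made up):

    import os.path

    def path_is_parent(parent_path, child_path):
        parent_path = os.path.abspath(parent_path)
        child_path = os.path.abspath(child_path)
        return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])

    print(path_is_parent("/data/models", "/data/models/previews/a.png"))  # True
    print(path_is_parent("/data/models", "/data/models/../secrets.txt"))  # False: ".." escapes the parent

os.path.commonpath([parent_path]) is just a normalized copy of parent_path, so the check passes only when adding child_path to the list does not shorten the common prefix.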
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
new file mode 100644
index 00000000..312dbaf0
--- /dev/null
+++ b/modules/ui_extra_networks_hypernets.py
@@ -0,0 +1,34 @@
+import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Hypernetworks')
+
+ def refresh(self):
+ shared.reload_hypernetworks()
+
+ def list_items(self):
+ for name, path in shared.hypernetworks.items():
+ path, ext = os.path.splitext(path)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "prompt": f"",
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.hypernetwork_dir]
+
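Here, and in the textual inversion page in the next file, preview URLs go through the webui's ./file= route with the file's mtime appended as a query string. The point is cache busting: once "Save preview" overwrites the image, the changed mtime produces a new URL and the browser re-fetches instead of serving the stale cached preview. The shared pattern, extracted purely for illustration:

    import os

    def preview_url(file):
        # forward slashes for the browser; the mtime query string defeats the cache
        return "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))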
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
new file mode 100644
index 00000000..e4a6e3bf
--- /dev/null
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -0,0 +1,32 @@
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Textual Inversion')
+ self.allow_negative_prompt = True
+
+ def refresh(self):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def list_items(self):
+ for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ path, ext = os.path.splitext(embedding.filename)
+ preview_file = path + ".preview.png"
+
+ preview = None
+ if os.path.isfile(preview_file):
+ preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+
+ yield {
+ "name": embedding.name,
+ "filename": embedding.filename,
+ "preview": preview,
+ "prompt": embedding.name,
+ "local_preview": path + ".preview.png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/script.js b/script.js
index 3345e32b..97e0bfcf 100644
--- a/script.js
+++ b/script.js
@@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
}
uiUpdateCallbacks = []
+uiLoadedCallbacks = []
uiTabChangeCallbacks = []
optionsChangedCallbacks = []
let uiCurrentTab = null
@@ -20,6 +21,9 @@ let uiCurrentTab = null
function onUiUpdate(callback){
uiUpdateCallbacks.push(callback)
}
+function onUiLoaded(callback){
+ uiLoadedCallbacks.push(callback)
+}
function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback)
}
@@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
queue.forEach(function(x){runCallback(x, m)})
}
+var executedOnLoaded = false;
+
document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){
+ if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+ executedOnLoaded = true;
+ executeCallbacks(uiLoadedCallbacks);
+ }
+
executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab();
if ( newTab && ( newTab !== uiCurrentTab ) ) {
@@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() {
/**
* Add a ctrl+enter as a shortcut to start a generation
*/
- document.addEventListener('keydown', function(e) {
+document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 6629f5d5..b1badec9 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -11,7 +11,6 @@ import modules.scripts as scripts
import gradio as gr
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
-from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
raise RuntimeError(f"Unknown checkpoint: {x}")
-def apply_hypernetwork(p, x, xs):
- if x.lower() in ["", "none"]:
- name = None
- else:
- name = hypernetwork.find_closest_hypernetwork_name(x)
- if not name:
- raise RuntimeError(f"Unknown hypernetwork: {x}")
- hypernetwork.load_hypernetwork(name)
-
-
-def apply_hypernetwork_strength(p, x, xs):
- hypernetwork.apply_strength(x)
-
-
-def confirm_hypernetworks(p, xs):
- for x in xs:
- if x.lower() in ["", "none"]:
- continue
- if not hypernetwork.find_closest_hypernetwork_name(x):
- raise RuntimeError(f"Unknown hypernetwork: {x}")
-
-
def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
@@ -208,8 +185,6 @@ axis_options = [
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
- AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
@@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
- self.hypernetwork = opts.sd_hypernetwork
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
@@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights()
- hypernetwork.load_hypernetwork(self.hypernetwork)
- hypernetwork.apply_strength()
-
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
diff --git a/style.css b/style.css
index 3a515ebd..5e8bc2ca 100644
--- a/style.css
+++ b/style.css
@@ -132,13 +132,6 @@
}
#roll_col > button {
- min-width: 2em;
- min-height: 2em;
- max-width: 2em;
- max-height: 2em;
- flex-grow: 0;
- padding-left: 0.25em;
- padding-right: 0.25em;
margin: 0.1em 0;
}
@@ -146,9 +139,10 @@
min-width: 0 !important;
max-width: 8em !important;
margin-right: 1em;
+ gap: 0;
}
#interrogate, #deepbooru{
- margin: 0em 0.25em 0.9em 0.25em;
+ margin: 0em 0.25em 0.5em 0.25em;
min-width: 8em;
max-width: 8em;
}
@@ -157,8 +151,17 @@
min-width: 8em !important;
}
+#txt2img_styles_row, #img2img_styles_row{
+ gap: 0.25em;
+ margin-top: 0.5em;
+}
+
+#txt2img_styles_row > button, #img2img_styles_row > button{
+ margin: 0;
+}
+
#txt2img_styles, #img2img_styles{
- margin-top: 1em;
+ padding: 0;
}
#txt2img_styles ul, #img2img_styles ul{
@@ -635,17 +638,21 @@ canvas[key="mask"] {
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}
-.gr-button-tool{
+.gr-button-tool, .gr-button-tool-top{
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
- margin: 1.6em 0.7em 0.55em 0;
}
-#tab_modelmerger .gr-button-tool{
+.gr-button-tool{
margin: 0.6em 0em 0.55em 0;
}
+.gr-button-tool-top, #settings .gr-button-tool{
+ margin: 1.6em 0.7em 0.55em 0;
+}
+
+
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
@@ -763,81 +770,88 @@ footer {
line-height: 2.4em;
}
-/* The following handles localization for right-to-left (RTL) languages like Arabic.
-The rtl media type will only be activated by the logic in javascript/localization.js.
-If you change anything above, you need to make sure it is RTL compliant by just running
-your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
-Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
-@media rtl {
- /* this part was added manually */
- :host {
- direction: rtl;
- }
- select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
- direction: ltr;
- }
- #script_list > label > select,
- #x_type > label > select,
- #y_type > label > select {
- direction: rtl;
- }
- .gr-radio, .gr-checkbox{
- margin-left: 0.25em;
- }
+#txt2img_extra_networks, #img2img_extra_networks{
+ margin-top: -1em;
+}
- /* automatically generated with few manual modifications */
- .performance .time {
- margin-right: unset;
- margin-left: 0;
- }
- .justify-center.overflow-x-scroll {
- justify-content: right;
- }
- .justify-center.overflow-x-scroll button:first-of-type {
- margin-left: unset;
- margin-right: auto;
- }
- .justify-center.overflow-x-scroll button:last-of-type {
- margin-right: unset;
- margin-left: auto;
- }
- #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
- margin-right: unset;
- margin-left: 8em;
- }
- #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
- right: unset;
- left: 0;
- }
- .progressDiv .progress{
- padding: 0 0 0 8px;
- text-align: left;
- }
- #lightboxModal{
- left: unset;
- right: 0;
- }
- .modalPrev, .modalNext{
- border-radius: 3px 0 0 3px;
- }
- .modalNext {
- right: unset;
- left: 0;
- border-radius: 0 3px 3px 0;
- }
- #imageARPreview{
- left:unset;
- right:0px;
- }
- #txt2img_skip, #img2img_skip{
- right: unset;
- left: 0px;
- }
- #context-menu{
- box-shadow:-1px 1px 2px #CE6400;
- }
- .gr-box > div > div > input.gr-text-input{
- right: unset;
- left: 0.5em;
- }
+.extra-networks > div > [id *= '_extra_']{
+ margin: 0.3em;
}
+
+.extra-network-cards .nocards{
+ margin: 1.25em 0.5em 0.5em 0.5em;
+}
+
+.extra-network-cards .nocards h1{
+ font-size: 1.5em;
+ margin-bottom: 1em;
+}
+
+.extra-network-cards .nocards li{
+ margin-left: 0.5em;
+}
+
+.extra-network-cards .card{
+ display: inline-block;
+ margin: 0.5em;
+ width: 16em;
+ height: 24em;
+ box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
+ border-radius: 0.2em;
+ position: relative;
+
+ background-size: auto 100%;
+ background-position: center;
+ overflow: hidden;
+ cursor: pointer;
+
+ background-image: url('./file=html/card-no-preview.png')
+}
+
+.extra-network-cards .card:hover{
+ box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
+}
+
+.extra-network-cards .card .actions .additional{
+ display: none;
+}
+
+.extra-network-cards .card .actions{
+ position: absolute;
+ bottom: 0;
+ left: 0;
+ right: 0;
+ padding: 0.5em;
+ color: white;
+ background: rgba(0,0,0,0.5);
+ box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
+ text-shadow: 0 0 0.2em black;
+}
+
+.extra-network-cards .card .actions:hover{
+ box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
+}
+
+.extra-network-cards .card .actions .name{
+ font-size: 1.7em;
+ font-weight: bold;
+ line-break: anywhere;
+}
+
+.extra-network-cards .card .actions:hover .additional{
+ display: block;
+}
+
+.extra-network-cards .card ul{
+ margin: 0.25em 0 0.75em 0.25em;
+ cursor: unset;
+}
+
+.extra-network-cards .card ul a{
+ cursor: pointer;
+}
+
+.extra-network-cards .card ul a:hover{
+ color: red;
+}
+
diff --git a/webui.py b/webui.py
index 865a7300..e8dd822a 100644
--- a/webui.py
+++ b/webui.py
@@ -9,16 +9,18 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from modules import import_hook, errors
+from modules import import_hook, errors, extra_networks
+from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
import torch
+
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
@@ -84,10 +86,17 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
- shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None:
try:
@@ -209,6 +218,15 @@ def webui():
modules.sd_models.list_models()
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
if __name__ == "__main__":
if cmd_opts.nowebui:
--
cgit v1.2.3
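webui.py above also registers a prompt-level handler, extra_networks_hypernet.ExtraNetworkHypernet(), via extra_networks.register_extra_network(). Neither of those modules is included in this excerpt, so the base-class shape below is an assumption inferred from the call sites: a handler named after its prompt token, activated with the parsed arguments of every matching token.

    # Hedged sketch only: modules/extra_networks.py is not part of this excerpt,
    # so the ExtraNetwork base API shown here is assumed, not quoted.
    from modules import extra_networks

    class ExtraNetworkMock(extra_networks.ExtraNetwork):
        def __init__(self):
            super().__init__('mock')  # would handle <mock:...> tokens in prompts

        def activate(self, p, params_list):
            for params in params_list:  # one entry per <mock:...> token found
                print("mock activated with args:", params.items)

        def deactivate(self, p):
            pass  # undo whatever activate() changed on the processing object

    extra_networks.register_extra_network(ExtraNetworkMock())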
From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Wed, 25 Jan 2023 08:48:40 -0800
Subject: allow symlinks in the textual inversion embeddings folder
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4e90f690..6cf00e65 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -194,7 +194,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path):
+ for root, dirs, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
--
cgit v1.2.3
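For reference, os.walk() does not descend into symlinked directories unless followlinks=True is passed, which is why embeddings reachable only through a symlink were previously invisible to the database. A minimal standalone demonstration (the directory name is illustrative):

    import os

    # With followlinks=False (the default), a symlinked subfolder inside
    # "embeddings" appears in dirs but is never walked into; with True it is.
    for root, dirs, files in os.walk("embeddings", followlinks=True):
        for fn in files:
            print(os.path.join(root, fn))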