From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 15:03:39 +0300
Subject: initial support for training textual inversion
---
modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++++++++++
1 file changed, 258 insertions(+)
create mode 100644 modules/textual_inversion/textual_inversion.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..c0baaace
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+
+from modules import shared, devices, sd_hijack, processing
+import modules.textual_inversion.dataset
+
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id].append((ids, embedding))
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding
+
+ return None
+
+
+
+def create_embedding(name, num_vectors_per_token):
+ init_text = '*'
+
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = ""
+ last_saved_image = ""
+
+ initial_step = embedding.step or 0
+ if initial_step > steps:
+ return embedding, filename
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
+ for i, (x, text) in pbar:
+ embedding.step = i + initial_step
+
+ if embedding.step > steps:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([text])
+ loss = shared.sd_model(x.unsqueeze(0), c)[0]
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ embedding.save(last_saved_file)
+
+ if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+
+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+
+ embedding.cached_checksum = None
+ embedding.save(filename)
+
+ return embedding, filename
+
--
cgit v1.2.3
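The checksum scheme this patch introduces is self-contained and can be reproduced outside the webui. A minimal sketch, assuming a toy two-vector tensor in place of a real embedding: const_hash folds each scaled weight into a 32-bit rolling accumulator, and the result is truncated to the four hex digits reported for used embeddings.

    import torch

    def const_hash(a):
        # fold each scaled component into a 32-bit rolling accumulator
        r = 0
        for v in a:
            r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
        return r

    vec = torch.randn(2, 768)  # hypothetical stand-in for Embedding.vec
    checksum = f'{const_hash(vec.reshape(-1) * 100) & 0xffff:04x}'
    print(checksum)  # four hex digits, stable for identical weights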
From 88ec0cf5571883d84abd09196652b3679e359f2e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 19:40:51 +0300
Subject: fix for incorrect embedding token length calculation (will break
seeds that use embeddings, you're welcome!) add option to input
initialization text for embeddings
---
modules/sd_hijack.py | 8 ++++----
modules/textual_inversion/textual_inversion.py | 13 +++++--------
modules/textual_inversion/ui.py | 4 ++--
modules/ui.py | 2 ++
4 files changed, 13 insertions(+), 14 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fd57e5c5..3fa06242 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if embedding is None:
remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..0c50161d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
- return None
+ return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
- return embedding
+ return embedding, len(ids)
- return None
+ return None, None
-
-def create_embedding(name, num_vectors_per_token):
- init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index ce3677a9..66c43ffb 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
-def create_embedding(name, nvpt):
- filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+ filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
diff --git a/modules/ui.py b/modules/ui.py
index 3b81a4f7..eca50df0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="Create a new embedding
")
new_embedding_name = gr.Textbox(label="Name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
+ initialization_text,
nvpt,
],
outputs=[
--
cgit v1.2.3
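The off-by-N this patch fixes is easiest to see in isolation: an embedding's name can tokenize into a different number of ids than the number of vectors its tensor holds, and the prompt scan must advance by the token count. A toy sketch with hypothetical ids:

    ids = [523, 18, 1023]  # hypothetical tokenizer ids for the embedding's name
    emb_len = 1            # the embedding tensor holds a single vector

    i = 0
    # before the fix: i += emb_len landed mid-name and re-read ids[1:] as prompt tokens
    # after the fix:  i += len(ids) skips the whole name in one step
    i += len(ids)
    assert i == 3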
From 71fe7fa49f5eb1a2c89932a9d217ed153c12fc8b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 19:56:37 +0300
Subject: fix using aaaa-100 embedding when the prompt has aaaa-10000 and you
have both aaaa-100 and aaaa-10000 in the directory with embeddings.
---
modules/textual_inversion/textual_inversion.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0c50161d..9d2241ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -57,7 +57,8 @@ class EmbeddingDatabase:
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, embedding))
+
+ self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
--
cgit v1.2.3
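A standalone sketch of why the sort matters, with hypothetical token ids: both names share a first token, so candidates must be tried longest-first or the shorter name shadows the longer one.

    ids_lookup = {}

    def register(ids, name):
        first = ids[0]
        ids_lookup[first] = sorted(ids_lookup.get(first, []) + [(ids, name)],
                                   key=lambda x: len(x[0]), reverse=True)

    register([42, 7], 'aaaa-100')       # hypothetical token ids
    register([42, 7, 9], 'aaaa-10000')

    tokens = [42, 7, 9]
    for ids, name in ids_lookup[tokens[0]]:
        if tokens[:len(ids)] == ids:
            print(name)  # 'aaaa-10000' -- the longer match now wins
            break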
From 4ec4af6e0b7addeee5221a03f32d117ccdc875d9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 20:15:25 +0300
Subject: add checkpoint info to saved embeddings
---
modules/textual_inversion/textual_inversion.py | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9d2241ce..1183aab7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,7 +7,7 @@ import tqdm
import html
import datetime
-from modules import shared, devices, sd_hijack, processing
+from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -17,6 +17,8 @@ class Embedding:
self.name = name
self.step = step
self.cached_checksum = None
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
@@ -24,6 +26,8 @@ class Embedding:
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
+ "sd_checkpoint": self.sd_checkpoint,
+ "sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
@@ -41,6 +45,7 @@ class Embedding:
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
@@ -96,6 +101,8 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
@@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}
"""
+ checkpoint = sd_models.select_checkpoint()
+
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
--
cgit v1.2.3
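Because the two new fields ride in the same torch.save payload as the vector, they can be inspected without the webui. A minimal sketch, path hypothetical:

    import torch

    data = torch.load('embeddings/my-style.pt', map_location='cpu')
    print(data.get('step'), data.get('sd_checkpoint'), data.get('sd_checkpoint_name'))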
From c7543d4940da672d970124ae8f2fec9de7bdc1da Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:41:21 +0300
Subject: preprocessing for textual inversion added
---
modules/interrogate.py | 1 +
modules/textual_inversion/preprocess.py | 75 ++++++++++++++++++++++++++
modules/textual_inversion/textual_inversion.py | 1 +
modules/textual_inversion/ui.py | 14 +++--
modules/ui.py | 36 +++++++++++++
5 files changed, 124 insertions(+), 3 deletions(-)
create mode 100644 modules/textual_inversion/preprocess.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..eed87144 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
new file mode 100644
index 00000000..209e928f
--- /dev/null
+++ b/modules/textual_inversion/preprocess.py
@@ -0,0 +1,75 @@
+import os
+from PIL import Image, ImageOps
+import tqdm
+
+from modules import shared, images
+
+
+def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
+ size = 512
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+
+ assert src != dst, 'same directory specified as source and destination'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = os.listdir(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ if process_caption:
+ shared.interrogator.load()
+
+ def save_pic_with_caption(image, index):
+ if process_caption:
+ caption = "-" + shared.interrogator.generate_caption(image)
+ else:
+ caption = ""
+
+ image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+ subindex[0] += 1
+
+ def save_pic(image, index):
+ save_pic_with_caption(image, index)
+
+ if process_flip:
+ save_pic_with_caption(ImageOps.mirror(image), index)
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ subindex = [0]
+ filename = os.path.join(src, imagefile)
+ img = Image.open(filename).convert("RGB")
+
+ if shared.state.interrupted:
+ break
+
+ ratio = img.height / img.width
+ is_tall = ratio > 1.35
+ is_wide = ratio < 1 / 1.35
+
+ if process_split and is_tall:
+ img = img.resize((size, size * img.height // img.width))
+
+ top = img.crop((0, 0, size, size))
+ save_pic(top, index)
+
+ bot = img.crop((0, img.height - size, size, img.height))
+ save_pic(bot, index)
+ elif process_split and is_wide:
+ img = img.resize((size * img.width // img.height, size))
+
+ left = img.crop((0, 0, size, size))
+ save_pic(left, index)
+
+ right = img.crop((img.width - size, 0, img.width, size))
+ save_pic(right, index)
+ else:
+ img = images.resize_image(1, img, size, size)
+ save_pic(img, index)
+
+ shared.state.nextjob()
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1183aab7..d4e250d8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,7 @@ import tqdm
import html
import datetime
+
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 633037d8..f19ac5e0 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -2,24 +2,31 @@ import html
import gradio as gr
-import modules.textual_inversion.textual_inversion as ti
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
- filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+
+ return "Preprocessing finished.", ""
+
+
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
- embedding, filename = ti.train_embedding(*args)
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
@@ -30,3 +37,4 @@ Embedding saved to {html.escape(filename)}
raise
finally:
sd_hijack.apply_optimizations()
+
diff --git a/modules/ui.py b/modules/ui.py
index 8912deff..e7bde53b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -961,6 +961,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
+ gr.HTML(value="See wiki for detailed explanation.
")
+
gr.HTML(value="Create a new embedding
")
new_embedding_name = gr.Textbox(label="Name")
@@ -974,6 +976,24 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
+ with gr.Group():
+ gr.HTML(value="Preprocess images
")
+
+ process_src = gr.Textbox(label='Source directory')
+ process_dst = gr.Textbox(label='Destination directory')
+
+ with gr.Row():
+ process_flip = gr.Checkbox(label='Flip')
+ process_split = gr.Checkbox(label='Split into two')
+ process_caption = gr.Checkbox(label='Add caption')
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
+
with gr.Group():
gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
@@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ process_src,
+ process_dst,
+ process_flip,
+ process_split,
+ process_caption,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
--
cgit v1.2.3
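The split heuristic in preprocess.py needs nothing beyond PIL; here is a sketch of the tall-image branch, using an in-memory stand-in for a source file:

    from PIL import Image

    size = 512
    img = Image.new('RGB', (512, 1024))  # stand-in for a tall source image

    ratio = img.height / img.width
    if ratio > 1.35:
        # scale width to `size`, then crop one square off the top and one off the bottom
        img = img.resize((size, size * img.height // img.width))
        top = img.crop((0, 0, size, size))
        bottom = img.crop((0, img.height - size, size, img.height))
        print(top.size, bottom.size)  # (512, 512) (512, 512)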
From 6785331e22d6a488fbf5905fab56d7fec867e038 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:59:01 +0300
Subject: keep textual inversion dataset latents in CPU memory to save a bit of
VRAM
---
modules/textual_inversion/dataset.py | 2 ++
modules/textual_inversion/textual_inversion.py | 3 +++
modules/ui.py | 4 ++--
3 files changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7e134a08..e8394ff6 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,6 +8,7 @@ from torchvision import transforms
import random
import tqdm
+from modules import devices
class PersonalizedBase(Dataset):
@@ -47,6 +48,7 @@ class PersonalizedBase(Dataset):
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+ init_latent = init_latent.to(devices.cpu)
self.dataset.append((init_latent, filename_tokens))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d4e250d8..8686f534 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -212,7 +212,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
with torch.autocast("cuda"):
c = cond_model([text])
+
+ x = x.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ del x
losses[embedding.step % losses.shape[0]] = loss.item()
diff --git a/modules/ui.py b/modules/ui.py
index e7bde53b..d9d02ece 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1002,8 +1002,8 @@ def create_ui(wrap_gradio_gpu_call):
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
with gr.Row():
with gr.Column(scale=2):
--
cgit v1.2.3
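The VRAM-saving pattern generalizes: cache latents in CPU RAM and move one sample to the GPU only for the forward/backward pass. A minimal sketch with hypothetical shapes:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    latents = [torch.randn(4, 64, 64) for _ in range(8)]  # cached on the CPU

    for x in latents:
        x = x.to(device)  # one sample resides on the GPU at a time
        # ... loss = shared.sd_model(x.unsqueeze(0), c)[0] ...
        del x             # release before the next iteration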
From 2865ef4b9ab16d56326cc805541bebcf01d099bc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 3 Oct 2022 13:10:03 +0300
Subject: fix broken date in TI
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8686f534..cd9f3498 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
--
cgit v1.2.3
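The one-character bug is easy to demonstrate: %Y-%d-%m swaps day and month, so log directories neither sort nor read correctly.

    import datetime

    d = datetime.datetime(2022, 10, 3)
    print(d.strftime("%Y-%d-%m"))  # '2022-03-10' -- October 3rd labeled like March 10th
    print(d.strftime("%Y-%m-%d"))  # '2022-10-03' -- fixed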
From 5841990b0df04906da7321beef6f7f7902b7d57b Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 05:38:38 +0100
Subject: Update textual_inversion.py
---
modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++++++---
1 file changed, 22 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..f6316020 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,9 @@ import tqdm
import html
import datetime
+from PIL import Image, PngImagePlugin
+import base64
+from io import BytesIO
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -80,7 +83,15 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-embedding' in embed_image.text:
+ embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
+ data = torch.load(BytesIO(embeddingData), map_location="cpu")
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
image = processed.images[0]
shared.state.current_image = image
- image.save(last_saved_image)
+
+ if save_image_with_stored_embedding:
+ info = PngImagePlugin.PngInfo()
+ info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
+ image.save(last_saved_image, "PNG", pnginfo=info)
+ else:
+ image.save(last_saved_image)
+
+
last_saved_image += f", prompt: {text}"
--
cgit v1.2.3
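The PNG round trip this patch sets up (base64-encode the saved embedding file into a text chunk, recover it on load) works with plain Pillow. A sketch with dummy bytes; the filename is hypothetical, and the value is decoded to str here, whereas the patch passes the raw b64encode bytes:

    import base64
    from PIL import Image, PngImagePlugin

    payload = base64.b64encode(b'embedding bytes here').decode()

    info = PngImagePlugin.PngInfo()
    info.add_text('sd-embedding', payload)
    Image.new('RGB', (64, 64)).save('sample.png', 'PNG', pnginfo=info)

    restored = base64.b64decode(Image.open('sample.png').text['sd-embedding'])
    print(restored)  # b'embedding bytes here'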
From 03694e1f9915e34cf7d9a31073f1a1a9def2909f Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 21:58:14 +0100
Subject: add embedding load and save from b64 json
---
modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++--------
1 file changed, 21 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f6316020..1b7f8906 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,9 +7,11 @@ import tqdm
import html
import datetime
-from PIL import Image, PngImagePlugin
+from PIL import Image,PngImagePlugin
+from ..images import captionImge
+import numpy as np
import base64
-from io import BytesIO
+import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -87,9 +89,9 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
- if 'sd-embedding' in embed_image.text:
- embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
- data = torch.load(BytesIO(embeddingData), map_location="cpu")
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo()
- info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
- image.save(last_saved_image, "PNG", pnginfo=info)
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embeddingToB64(data))
+
+ pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
+
+ caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
+ caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
+ caption_stepcount = data.get('step',0)
+ caption_stepcount = caption_stepcount if caption_stepcount else 0
+
+ post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
+ caption_stepcount))]
+ captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
+ captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image)
-
-
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
--
cgit v1.2.3
From 969bd8256e5b4f1007d3cc653723d4ad50a92528 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:02:28 +0100
Subject: add alternate checkpoint hash source
---
modules/textual_inversion/textual_inversion.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1b7f8906..d7813084 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -265,8 +265,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
- caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
- caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
+ caption_checkpoint_hash = data.get('sd_checkpoint')
+ if caption_checkpoint_hash is None:
+ caption_checkpoint_hash = data.get('hash')
+ caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN'
+
caption_stepcount = data.get('step',0)
caption_stepcount = caption_stepcount if caption_stepcount else 0
--
cgit v1.2.3
From 5d12ec82d3e13f5ff4c55db2930e4e10aed7015a Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:05:09 +0100
Subject: add encoder and decoder classes
---
modules/textual_inversion/textual_inversion.py | 21 +++++++++++++++++++++
1 file changed, 21 insertions(+)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d7813084..44d4e08b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -16,6 +16,27 @@ import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+ def object_hook(self, d):
+ if 'EMBEDDINGTENSOR' in d:
+ return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
+ return d
+
+def embeddingToB64(data):
+ d = json.dumps(data,cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+def EmbeddingFromB64(data):
+ d = base64.b64decode(data)
+ return json.loads(d,cls=EmbeddingDecoder)
class Embedding:
def __init__(self, vec, name, step=None):
--
cgit v1.2.3
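A round-trip sketch for the pair above: a tensor crosses JSON as a nested list under the marker key and comes back as a tensor (float64 on the return trip, hence the cast before comparing).

    import json
    import numpy as np
    import torch

    class EmbeddingEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, torch.Tensor):
                return {'EMBEDDINGTENSOR': obj.cpu().detach().numpy().tolist()}
            return json.JSONEncoder.default(self, obj)

    class EmbeddingDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, d):
            if 'EMBEDDINGTENSOR' in d:
                return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
            return d

    data = {'name': 'demo', 'vec': torch.randn(1, 4)}
    restored = json.loads(json.dumps(data, cls=EmbeddingEncoder), cls=EmbeddingDecoder)
    print(torch.allclose(data['vec'], restored['vec'].float()))  # True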
From d0184b8f76ce492da699f1926f34b57cd095242e Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:06:12 +0100
Subject: change json tensor key name
---
modules/textual_inversion/textual_inversion.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44d4e08b..ae8d207d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -19,15 +19,15 @@ import modules.textual_inversion.dataset
class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.Tensor):
- return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()}
+ return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, obj)
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, d):
- if 'EMBEDDINGTENSOR' in d:
- return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
return d
def embeddingToB64(data):
--
cgit v1.2.3
From 66846105103cfc282434d0dc2102910160b7a633 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:06:42 +0100
Subject: correct case on embeddingFromB64
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ae8d207d..d2b95fa3 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -34,7 +34,7 @@ def embeddingToB64(data):
d = json.dumps(data,cls=EmbeddingEncoder)
return base64.b64encode(d.encode())
-def EmbeddingFromB64(data):
+def embeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
--
cgit v1.2.3
From 96f1e6be59316ec640cab2435fa95b3688194906 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:14:50 +0100
Subject: source checkpoint hash from current checkpoint
---
modules/textual_inversion/textual_inversion.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d2b95fa3..b16fa84e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -286,10 +286,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
- caption_checkpoint_hash = data.get('sd_checkpoint')
- if caption_checkpoint_hash is None:
- caption_checkpoint_hash = data.get('hash')
- caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN'
+ checkpoint = sd_models.select_checkpoint()
+ caption_checkpoint_hash = checkpoint.hash
caption_stepcount = data.get('step',0)
caption_stepcount = caption_stepcount if caption_stepcount else 0
--
cgit v1.2.3
From 01fd9cf0d28d8b71a113ab1aa62accfe7f0d9c51 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 9 Oct 2022 22:17:02 +0100
Subject: change source of step count
---
modules/textual_inversion/textual_inversion.py | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b16fa84e..e4f339b8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -285,15 +285,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
-
checkpoint = sd_models.select_checkpoint()
- caption_checkpoint_hash = checkpoint.hash
-
- caption_stepcount = data.get('step',0)
- caption_stepcount = caption_stepcount if caption_stepcount else 0
-
- post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
- caption_stepcount))]
+ post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
+ embedding.step))]
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
--
cgit v1.2.3
From d6a599ef9ba18a66ae79b50f2945af5788fdda8f Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 00:07:52 +0100
Subject: change caption method
---
modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++--------
1 file changed, 21 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e4f339b8..21596e78 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -8,7 +8,7 @@ import html
import datetime
from PIL import Image,PngImagePlugin
-from ..images import captionImge
+from ..images import captionImageOverlay
import numpy as np
import base64
import json
@@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.current_image = image
- if save_image_with_stored_embedding:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data))
- pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
+ title = "<{}>".format(data.get('name','???'))
checkpoint = sd_models.select_checkpoint()
- post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
- embedding.step))]
- captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
- captioned_image.save(last_saved_image, "PNG", pnginfo=info)
- else:
- image.save(last_saved_image)
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '[{}]'.format(embedding.step)
+
+ captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
+ image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
--
cgit v1.2.3
From e2c2925eb4d634b186de2c76798162ec56e2f869 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 00:12:53 +0100
Subject: remove braces from steps
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 21596e78..9a18ee5c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -297,7 +297,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '[{}]'.format(embedding.step)
+ footer_right = '{}'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
--
cgit v1.2.3
From 3110f895b2718a3a25aae419fdf5c87c177ec9f4 Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 17:07:46 +0900
Subject: Textual Inversion: Added custom training image size and number of
repeats per input image in a single epoch
---
modules/textual_inversion/dataset.py | 6 +++---
modules/textual_inversion/preprocess.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 15 ++++++++++++---
modules/ui.py | 8 +++++++-
4 files changed, 24 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..acc4ce59 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
- self.width = width
- self.height = height
+ self.width = size
+ self.height = size
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..b3de6fd7 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,8 +7,8 @@ import tqdm
from modules import shared, images
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
+ size = process_size
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..e34dc2e8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
+import math
from modules import shared, devices, sd_hijack, processing, sd_models
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if initial_step > steps:
return embedding, filename
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
for i, (x, text) in pbar:
embedding.step = i + initial_step
@@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = math.floor(embedding.step / epoch_len)
+ epoch_step = embedding.step - (epoch_num * epoch_len)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
sd_model=shared.sd_model,
prompt=text,
steps=20,
+ height=training_size,
+ width=training_size,
do_not_save_grid=True,
do_not_save_samples=True,
)
diff --git a/modules/ui.py b/modules/ui.py
index 2231a8ed..f821fd8d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
- gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
")
+ gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
+ num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
process_src,
process_dst,
+ process_size,
process_flip,
process_split,
process_caption,
@@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
+ training_size,
steps,
+ num_repeats,
create_image_every,
save_embedding_every,
template_file,
--
cgit v1.2.3
From 4ee7519fc2e459ce8eff1f61f1655afba393357c Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 17:31:33 +0900
Subject: Fixed progress bar output for epoch
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e34dc2e8..769682ea 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -228,7 +228,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer.step()
epoch_num = math.floor(embedding.step / epoch_len)
- epoch_step = embedding.step - (epoch_num * epoch_len)
+ epoch_step = embedding.step - (epoch_num * epoch_len) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
--
cgit v1.2.3
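Taken together, the epoch bookkeeping from the last two patches reduces to a few lines of integer arithmetic. A sketch with hypothetical counts (// is equivalent to the math.floor division used in the patch):

    tr_img_len = 20   # hypothetical number of training images
    num_repeats = 100
    epoch_len = (tr_img_len * num_repeats) + tr_img_len  # 2020 steps per epoch

    step = 4250
    epoch_num = step // epoch_len                    # 2
    epoch_step = step - (epoch_num * epoch_len) + 1  # 211, 1-based within the epoch
    print(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]")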
From 04c745ea4f81518999927fee5f78500560c25e29 Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 22:35:35 +0900
Subject: Custom Width and Height
---
modules/textual_inversion/dataset.py | 7 +++----
modules/textual_inversion/preprocess.py | 19 ++++++++++---------
modules/textual_inversion/textual_inversion.py | 11 +++++------
modules/ui.py | 12 ++++++++----
4 files changed, 26 insertions(+), 23 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index acc4ce59..bcf772d2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
- self.size = size
- self.width = size
- self.height = size
+ self.width = width
+ self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index b3de6fd7..d7efdef2 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,8 +7,9 @@ import tqdm
from modules import shared, images
-def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
- size = process_size
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_spl
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
From ea00c1624bbb0dcb5be07f59c9509061baddf5b1 Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 17:07:46 +0900
Subject: Textual Inversion: Added custom training image size and number of
repeats per input image in a single epoch
---
modules/textual_inversion/dataset.py | 6 +++---
modules/textual_inversion/preprocess.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 15 ++++++++++++---
modules/ui.py | 8 +++++++-
4 files changed, 24 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..acc4ce59 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
- self.width = width
- self.height = height
+ self.width = size
+ self.height = size
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..b3de6fd7 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,8 +7,8 @@ import tqdm
from modules import shared, images
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
+ size = process_size
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..e34dc2e8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
+import math
from modules import shared, devices, sd_hijack, processing, sd_models
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if ititial_step > steps:
return embedding, filename
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = math.floor(embedding.step / epoch_len)
+ epoch_step = embedding.step - (epoch_num * epoch_len)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
sd_model=shared.sd_model,
prompt=text,
steps=20,
+ height=training_size,
+ width=training_size,
do_not_save_grid=True,
do_not_save_samples=True,
)
diff --git a/modules/ui.py b/modules/ui.py
index 2231a8ed..f821fd8d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
- gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
")
+ gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
+ num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
process_src,
process_dst,
+ process_size,
process_flip,
process_split,
process_caption,
@@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
+ training_size,
steps,
+ num_repeats,
create_image_every,
save_embedding_every,
template_file,
--
cgit v1.2.3
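
Note: the epoch bookkeeping added here works out to epoch_len = tr_img_len * (num_repeats + 1). A minimal sketch of the arithmetic with illustrative numbers (10 images, the default 100 repeats), not taken from the patch:

    import math

    tr_img_len = 10      # images found in data_root (illustrative)
    num_repeats = 100    # UI default "Number of repeats ..."
    epoch_len = (tr_img_len * num_repeats) + tr_img_len   # 1010 steps

    for step in (0, 1009, 1010):
        epoch_num = math.floor(step / epoch_len)
        print(f"[Epoch {epoch_num}: {step - epoch_num * epoch_len}/{epoch_len}]")
    # [Epoch 0: 0/1010]  [Epoch 0: 1009/1010]  [Epoch 1: 0/1010]
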
From 6ad3a53e368d36535de1a4fca73b3bb78fd40654 Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 17:31:33 +0900
Subject: Fixed progress bar output for epoch
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e34dc2e8..769682ea 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -228,7 +228,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer.step()
epoch_num = math.floor(embedding.step / epoch_len)
- epoch_step = embedding.step - (epoch_num * epoch_len)
+ epoch_step = embedding.step - (epoch_num * epoch_len) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
--
cgit v1.2.3
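
Note: the +1 simply makes the in-epoch counter one-based, so the first step of an epoch reads 1/epoch_len instead of 0/epoch_len. Illustrative check:

    epoch_len = 1010    # illustrative value
    for step in (0, 1009, 1010):
        epoch_num = step // epoch_len
        epoch_step = step - (epoch_num * epoch_len) + 1
        print(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]")
    # prints 1/1010, then 1010/1010, then 1/1010 as epoch 1 begins
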
From 7a20f914eddfdf09c0ccced157ec108205bc3d0f Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Mon, 10 Oct 2022 22:35:35 +0900
Subject: Custom Width and Height
---
modules/textual_inversion/dataset.py | 7 +++----
modules/textual_inversion/preprocess.py | 19 ++++++++++---------
modules/textual_inversion/textual_inversion.py | 11 +++++------
modules/ui.py | 12 ++++++++----
4 files changed, 26 insertions(+), 23 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index acc4ce59..bcf772d2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
- self.size = size
- self.width = size
- self.height = size
+ self.width = width
+ self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index b3de6fd7..d7efdef2 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,8 +7,9 @@ import tqdm
from modules import shared, images
-def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
- size = process_size
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_spl
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ img = img.resize((width, height * img.height // img.width))
- top = img.crop((0, 0, size, size))
+ top = img.crop((0, 0, width, height))
save_pic(top, index)
- bot = img.crop((0, img.height - size, size, img.height))
+ bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ img = img.resize((width * img.width // img.height, height))
- left = img.crop((0, 0, size, size))
+ left = img.crop((0, 0, width, height))
save_pic(left, index)
- right = img.crop((img.width - size, 0, img.width, size))
+ right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
- img = images.resize_image(1, img, size, size)
+ img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 769682ea..5965c5a0 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,7 +6,6 @@ import torch
import tqdm
import html
import datetime
-import math
from modules import shared, devices, sd_hijack, processing, sd_models
@@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
loss.backward()
optimizer.step()
- epoch_num = math.floor(embedding.step / epoch_len)
+ epoch_num = embedding.step // epoch_len
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
@@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
sd_model=shared.sd_model,
prompt=text,
steps=20,
- height=training_size,
- width=training_size,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
diff --git a/modules/ui.py b/modules/ui.py
index f821fd8d..8c06ad7c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1029,7 +1029,8 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
- process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1050,7 +1051,8 @@ def create_ui(wrap_gradio_gpu_call):
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1095,7 +1097,8 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
process_src,
process_dst,
- process_size,
+ process_width,
+ process_height,
process_flip,
process_split,
process_caption,
@@ -1114,7 +1117,8 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
- training_size,
+ training_width,
+ training_height,
steps,
num_repeats,
create_image_every,
--
cgit v1.2.3
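
Note: with width and height independent, the tall-image path resizes to the target width and then takes matching crops from both ends. A sketch mirroring the preprocess.py arithmetic; the 400x1000 source and 512x512 targets are made up:

    from PIL import Image

    width, height = 512, 512
    img = Image.new("RGB", (400, 1000))   # ratio 2.5 > 1.35, so is_tall

    img = img.resize((width, height * img.height // img.width))  # 512x1280
    top = img.crop((0, 0, width, height))                        # top crop
    bot = img.crop((0, img.height - height, width, img.height))  # bottom crop
    print(top.size, bot.size)             # (512, 512) (512, 512)
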
From 707a431100362645e914042bb344d08439f48ac8 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 15:34:49 +0100
Subject: add pixel data footer
---
modules/textual_inversion/textual_inversion.py | 48 ++++++++++++++++++++++++--
1 file changed, 46 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7a24192e..6fb64691 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@ from ..images import captionImageOverlay
import numpy as np
import base64
import json
+import zlib
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.Tensor):
return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
- return json.JSONEncoder.default(self, o)
+ return json.JSONEncoder.default(self, obj)
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
@@ -38,6 +39,45 @@ def embeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
+def appendImageDataFooter(image,data):
+ d = 3
+ data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
+ dnp = np.frombuffer(data_compressed,np.uint8).copy()
+ w = image.size[0]
+ next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
+ next_size = next_size + ((w*d)-(next_size%(w*d)))
+ dnp.resize(next_size)
+ dnp = dnp.reshape((-1,w,d))
+ print(dnp.shape)
+ im = Image.fromarray(dnp,mode='RGB')
+ background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
+ background.paste(image,(0,0))
+ background.paste(im,(0,image.size[1]+1))
+ return background
+
+def crop_black(img,tol=0):
+ mask = (img>tol).all(2)
+ mask0,mask1 = mask.any(0),mask.any(1)
+ col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
+ row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end,col_start:col_end]
+
+def extractImageDataFooter(image):
+ d=3
+ outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
+ lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
+ if lastRow[0].shape[0] == 0:
+ print('Image data block not found.')
+ return None
+ lastRow = lastRow[0]
+
+ lastRow = lastRow.max()
+
+ dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
+ print(lastRow)
+ data = zlib.decompress(dataBlock)
+ return json.loads(data,cls=EmbeddingDecoder)
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -113,6 +153,9 @@ class EmbeddingDatabase:
if 'sd-ti-embedding' in embed_image.text:
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
+ else:
+ data = extractImageDataFooter(embed_image)
+ name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
footer_right = '{}'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+ captioned_image = appendImageDataFooter(captioned_image,data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
--
cgit v1.2.3
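
Note: the footer is zlib-compressed JSON padded out to whole pixel rows and pasted under the captioned image. A standalone approximation of the round trip; the payload and width are illustrative, and the patch additionally serializes tensors via EmbeddingEncoder:

    import json
    import zlib

    import numpy as np

    data = {"name": "my-embedding", "step": 500}   # illustrative payload
    w, d = 512, 3                                  # image width, RGB depth

    packed = zlib.compress(json.dumps(data).encode(), level=9)
    dnp = np.frombuffer(packed, np.uint8).copy()
    next_size = dnp.shape[0] + (w - (dnp.shape[0] % w))        # whole rows
    next_size = next_size + ((w * d) - (next_size % (w * d)))  # whole pixels
    dnp.resize(next_size)                          # zero-fills the tail
    rows = dnp.reshape((-1, w, d))
    print(rows.shape)                              # (1, 512, 3)

    # extraction inflates the flattened rows; zlib stops at the end of
    # the stream, so the zero padding is ignored
    print(json.loads(zlib.decompress(rows.flatten().tobytes())))
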
From df6d0d9286279c41c4c67460c3158fa268697524 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 15:43:09 +0100
Subject: convert back to rgb as some hosts add alpha
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6fb64691..667a7cf2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -64,7 +64,7 @@ def crop_black(img,tol=0):
def extractImageDataFooter(image):
d=3
- outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
if lastRow[0].shape[0] == 0:
print('Image data block not found.')
--
cgit v1.2.3
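
Note: the convert('RGB') matters because a host that re-encodes the PNG with an alpha channel makes getdata() return four values per pixel, and the reshape to (h, w, 3) then fails. Minimal illustration:

    from PIL import Image
    import numpy as np

    rgba = Image.new("RGBA", (4, 2))    # stand-in for a re-encoded upload
    arr = np.array(rgba.convert("RGB").getdata()).reshape(2, 4, 3)
    print(arr.shape)                    # (2, 4, 3) despite the alpha source
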
From bc3e183b739913e7be91213a256f038b10eb71e9 Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Tue, 11 Oct 2022 04:30:13 +0900
Subject: Textual Inversion: Preprocess and Training will only pick-up image
files
---
modules/textual_inversion/dataset.py | 3 ++-
modules/textual_inversion/preprocess.py | 3 ++-
modules/textual_inversion/textual_inversion.py | 3 ++-
3 files changed, 6 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index bcf772d2..d4baf066 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -22,6 +22,7 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+ self.extns = [".jpg",".jpeg",".png"]
self.dataset = []
@@ -32,7 +33,7 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
- self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index d7efdef2..b6c78cf8 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,12 +12,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ extns = [".jpg",".jpeg",".png"]
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
- files = os.listdir(src)
+ files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns]
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..45397be9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -161,6 +161,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
+ extns = [".jpg",".jpeg",".png"]
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
@@ -200,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
--
cgit v1.2.3
From 2536ecbb1790da2af0d61b6a26f38732cba665cd Mon Sep 17 00:00:00 2001
From: Fampai <>
Date: Mon, 10 Oct 2022 17:10:29 -0400
Subject: Refactored learning rate code
---
modules/textual_inversion/textual_inversion.py | 51 ++++++++++++++++++++++++--
modules/ui.py | 2 +-
2 files changed, 48 insertions(+), 5 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..c64a4598 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = ""
@@ -203,12 +201,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
+ scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(scheduleIter)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
- break
+ if embedding.step > end_step:
+ try:
+ (learn_rate, end_step) = next(scheduleIter)
+ except:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
@@ -277,3 +287,36 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
diff --git a/modules/ui.py b/modules/ui.py
index 8c06ad7c..c9e8355b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1047,7 +1047,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
--
cgit v1.2.3
From 907a88b2d0be320575c2129d8d6a1d4f3a68f9eb Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Tue, 11 Oct 2022 06:33:08 +0900
Subject: Added .webp .bmp
---
modules/textual_inversion/dataset.py | 2 +-
modules/textual_inversion/preprocess.py | 2 +-
modules/textual_inversion/textual_inversion.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d4baf066..0dc54fb7 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -22,7 +22,7 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
- self.extns = [".jpg",".jpeg",".png"]
+ self.extns = [".jpg",".jpeg",".png",".webp",".bmp"]
self.dataset = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index b6c78cf8..8290abe8 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
- extns = [".jpg",".jpeg",".png"]
+ extns = [".jpg",".jpeg",".png",".webp",".bmp"]
assert src != dst, 'same directory specified as source and destination'
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a03b299c..33c923d1 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -161,7 +161,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
- extns = [".jpg",".jpeg",".png"]
+ extns = [".jpg",".jpeg",".png",".webp",".bmp"]
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
--
cgit v1.2.3
From 315d5a8ed975c88f670bc484f40a23fbf3a77b63 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 23:14:44 +0100
Subject: update data display style
---
modules/textual_inversion/textual_inversion.py | 88 +++++++++++++++++++-------
1 file changed, 65 insertions(+), 23 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 667a7cf2..95eebea7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -39,20 +39,59 @@ def embeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
-def appendImageDataFooter(image,data):
+def xorBlock(block):
+ return np.bitwise_xor(block.astype(np.uint8),
+ ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F )
+
+def styleBlock(block,sequence):
+ im = Image.new('RGB',(block.shape[1],block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i=0
+ for x in range(-6,im.size[0],8):
+ for yi,y in enumerate(range(-6,im.size[1],8)):
+ offset=0
+ if yi%2==0:
+ offset=4
+ shade = sequence[i%len(sequence)]
+ i+=1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) )
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+ return block ^ fg
+
+def insertImageDataEmbed(image,data):
d = 3
data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
dnp = np.frombuffer(data_compressed,np.uint8).copy()
- w = image.size[0]
- next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
- next_size = next_size + ((w*d)-(next_size%(w*d)))
- dnp.resize(next_size)
- dnp = dnp.reshape((-1,w,d))
- print(dnp.shape)
- im = Image.fromarray(dnp,mode='RGB')
- background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
- background.paste(image,(0,0))
- background.paste(im,(0,image.size[1]+1))
+ dnphigh = dnp >> 4
+ dnplow = dnp & 0x0F
+
+ h = image.size[1]
+ next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h))
+ next_size = next_size + ((h*d)-(next_size%(h*d)))
+
+ dnplow.resize(next_size)
+ dnplow = dnplow.reshape((h,-1,d))
+
+ dnphigh.resize(next_size)
+ dnphigh = dnphigh.reshape((h,-1,d))
+
+ edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8)
+
+ dnplow = styleBlock(dnplow,sequence=edgeStyleWeights)
+ dnplow = xorBlock(dnplow)
+ dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1])
+ dnphigh = xorBlock(dnphigh)
+
+ imlow = Image.fromarray(dnplow,mode='RGB')
+ imhigh = Image.fromarray(dnphigh,mode='RGB')
+
+ background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0))
+ background.paste(imlow,(0,0))
+ background.paste(image,(imlow.size[0]+1,0))
+ background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0))
+
return background
def crop_black(img,tol=0):
@@ -62,19 +101,22 @@ def crop_black(img,tol=0):
row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
return img[row_start:row_end,col_start:col_end]
-def extractImageDataFooter(image):
+def extractImageDataEmbed(image):
d=3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
- lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
- if lastRow[0].shape[0] == 0:
- print('Image data block not found.')
+ outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
+ blackCols = np.where( np.sum(outarr, axis=(0,2))==0)
+ if blackCols[0].shape[0] < 2:
+ print('No Image data blocks found.')
return None
- lastRow = lastRow[0]
-
- lastRow = lastRow.max()
- dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
- print(lastRow)
+ dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8)
+ dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8)
+
+ dataBlocklower = xorBlock(dataBlocklower)
+ dataBlockupper = xorBlock(dataBlockupper)
+
+ dataBlock = (dataBlockupper << 4) | (dataBlocklower)
+ dataBlock = dataBlock.flatten().tobytes()
data = zlib.decompress(dataBlock)
return json.loads(data,cls=EmbeddingDecoder)
@@ -154,7 +196,7 @@ class EmbeddingDatabase:
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
else:
- data = extractImageDataFooter(embed_image)
+ data = extractImageDataEmbed(embed_image)
name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -351,7 +393,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
footer_right = '{}'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
- captioned_image = appendImageDataFooter(captioned_image,data)
+ captioned_image = insertImageDataEmbed(captioned_image,data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
--
cgit v1.2.3
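
Note: the payload now rides in the low 4 bits of two side panels instead of a footer: each byte is split into its high and low nibble, masked, and recombined on extraction. A reduced sketch of just the nibble round trip, leaving out the styling and XOR layers:

    import numpy as np

    payload = np.frombuffer(b"embedding bytes", np.uint8).copy()
    low, high = payload & 0x0F, payload >> 4   # two 4-bit planes

    # ...each plane is styled, XOR-masked and pasted beside the image...

    restored = (high << 4) | low               # extraction recombines them
    print(bytes(restored))                     # b'embedding bytes'
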
From 767202a4c324f9b49f63ab4dabbb5736fe9df6e5 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 23:20:52 +0100
Subject: add dependency
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 95eebea7..f3cacaa0 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,7 +7,7 @@ import tqdm
import html
import datetime
-from PIL import Image,PngImagePlugin
+from PIL import Image,PngImagePlugin,ImageDraw
from ..images import captionImageOverlay
import numpy as np
import base64
--
cgit v1.2.3
From e0fbe6d27e7b4505766c8cb5a4264e1114cf3721 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 10 Oct 2022 23:26:24 +0100
Subject: colour depth conversion fix
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f3cacaa0..ae807268 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -103,7 +103,7 @@ def crop_black(img,tol=0):
def extractImageDataEmbed(image):
d=3
- outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
blackCols = np.where( np.sum(outarr, axis=(0,2))==0)
if blackCols[0].shape[0] < 2:
print('No Image data blocks found.')
--
cgit v1.2.3
From 7aa8fcac1e45c3ad9c6a40df0e44a346afcd5032 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Tue, 11 Oct 2022 04:17:36 +0100
Subject: use simple lcg in xor
---
modules/textual_inversion/textual_inversion.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ae807268..13416a08 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -39,9 +39,15 @@ def embeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed
+
def xorBlock(block):
- return np.bitwise_xor(block.astype(np.uint8),
- ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F )
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F)
def styleBlock(block,sequence):
im = Image.new('RGB',(block.shape[1],block.shape[0]))
--
cgit v1.2.3
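
Note: a hand-rolled LCG keeps the mask bit-identical across numpy versions, and XOR with the same mask is its own inverse, which is what keeps extraction working. A small self-check written under those assumptions:

    import numpy as np

    def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
        while True:
            seed = (a * seed + c) % m
            yield seed

    def xor_block(block):
        g = lcg()
        rand = np.array([next(g) for _ in range(block.size)])
        rand = rand.astype(np.uint8).reshape(block.shape)
        return block ^ (rand & 0x0F)

    block = (np.arange(12, dtype=np.uint8) & 0x0F).reshape(2, 2, 3)
    assert np.array_equal(xor_block(xor_block(block)), block)  # round trip
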
From b2368a3bce663f19a7209d9cb38617e635ca6e3c Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Tue, 11 Oct 2022 17:32:46 +0900
Subject: Switched to exception handling
---
modules/textual_inversion/dataset.py | 10 +++++-----
modules/textual_inversion/preprocess.py | 8 +++++---
modules/textual_inversion/textual_inversion.py | 18 ++++++++----------
3 files changed, 18 insertions(+), 18 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 0dc54fb7..4d006366 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -22,7 +22,6 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
- self.extns = [".jpg",".jpeg",".png",".webp",".bmp"]
self.dataset = []
@@ -33,12 +32,13 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
- self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns]
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 8290abe8..1a672725 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,13 +12,12 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
- extns = [".jpg",".jpeg",".png",".webp",".bmp"]
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
- files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns]
+ files = os.listdir(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
@@ -47,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
if shared.state.interrupted:
break
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 33c923d1..91cde04b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -161,7 +161,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
- extns = [".jpg",".jpeg",".png",".webp",".bmp"]
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
@@ -201,10 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns])
-
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -228,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
loss.backward()
optimizer.step()
- epoch_num = embedding.step // epoch_len
- epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -243,9 +238,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
- steps=20,
- height=training_height,
+ steps=28,
+ height=768,
width=training_width,
+ negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, artist name",
+ cfg_scale=7.0,
+ sampler_index=0,
do_not_save_grid=True,
do_not_save_samples=True,
)
--
cgit v1.2.3
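
Note: the pattern replacing the extension lists is to attempt a decode on every file and skip whatever PIL rejects, which also covers files with a valid extension but corrupt contents. Sketch with made-up paths:

    from PIL import Image

    for path in ["img1.png", "notes.txt", "img2.webp"]:
        try:
            img = Image.open(path).convert("RGB")
        except Exception:
            continue           # unreadable or not an image -- skip it
        print(path, img.size)  # process the survivors
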
From 8bacbca0a1ab9aabcb0ad0cbf070e0006991e98a Mon Sep 17 00:00:00 2001
From: alg-wiki
Date: Tue, 11 Oct 2022 17:35:09 +0900
Subject: Removed my local edits to checkpoint image generation
---
modules/textual_inversion/textual_inversion.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 91cde04b..e9ff80c2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -238,12 +238,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
- steps=28,
- height=768,
+ steps=20,
+ height=training_height,
width=training_width,
- negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, artist name",
- cfg_scale=7.0,
- sampler_index=0,
do_not_save_grid=True,
do_not_save_samples=True,
)
--
cgit v1.2.3
From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 14:53:02 +0300
Subject: fixes related to merge
---
modules/hypernetwork.py | 103 -------------------------
modules/hypernetwork/hypernetwork.py | 74 +++++++++++-------
modules/hypernetwork/ui.py | 10 +--
modules/sd_hijack_optimizations.py | 3 +-
modules/shared.py | 13 +++-
modules/textual_inversion/textual_inversion.py | 12 +--
modules/ui.py | 5 +-
scripts/xy_grid.py | 3 +-
webui.py | 15 +---
9 files changed, 78 insertions(+), 160 deletions(-)
delete mode 100644 modules/hypernetwork.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
deleted file mode 100644
index 7bbc443e..00000000
--- a/modules/hypernetwork.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- self.load_state_dict(state_dict, strict=True)
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, filename):
- self.filename = filename
- self.name = os.path.splitext(os.path.basename(filename))[0]
- self.layers = {}
-
- state_dict = torch.load(filename, map_location='cpu')
- for size, sd in state_dict.items():
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def apply_hypernetwork(hypernetwork, context):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is None:
- return context, context
-
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
- return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py
index a3d6a47e..aa701bda 100644
--- a/modules/hypernetwork/hypernetwork.py
+++ b/modules/hypernetwork/hypernetwork.py
@@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module):
if state_dict is not None:
self.load_state_dict(state_dict, strict=True)
else:
- self.linear1.weight.data.fill_(0.0001)
- self.linear1.bias.data.fill_(0.0001)
- self.linear2.weight.data.fill_(0.0001)
- self.linear2.bias.data.fill_(0.0001)
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
self.to(devices.device)
@@ -92,41 +93,54 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
-def load_hypernetworks(path):
+def list_hypernetworks(path):
res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
- for filename in glob.iglob(path + '**/*.pt', recursive=True):
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
try:
- hn = Hypernetwork()
- hn.load(filename)
- res[hn.name] = hn
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
except Exception:
- print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
- return res
+ shared.loaded_hypernetwork = None
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- q = self.to_q(x)
- context = default(context, x)
+ if hypernetwork_layers is None:
+ return context, context
- hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
- if hypernetwork_layers is not None:
- hypernetwork_k, hypernetwork_v = hypernetwork_layers
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
- self.hypernetwork_k = hypernetwork_k
- self.hypernetwork_v = hypernetwork_v
- context_k = hypernetwork_k(context)
- context_v = hypernetwork_v(context)
- else:
- context_k = context
- context_v = context
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
assert hypernetwork_name, 'embedding not selected'
- shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
@@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
- hypernetwork = shared.hypernetworks[hypernetwork_name]
+ hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
@@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text) in pbar:
hypernetwork.step = i + ititial_step
diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py
index 525f978c..f6d1d0a3 100644
--- a/modules/hypernetwork/ui.py
+++ b/modules/hypernetwork/ui.py
@@ -6,24 +6,24 @@ import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
+from modules.hypernetwork import hypernetwork
def create_hypernetwork(name):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
- hypernetwork.save(fn)
+ hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+ hypernet.save(fn)
shared.reload_hypernetworks()
- shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
def train_hypernetwork(*args):
- initial_hypernetwork = shared.hypernetwork
+ initial_hypernetwork = shared.loaded_hypernetwork
try:
sd_hijack.undo_optimizations()
@@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
- shared.hypernetwork = initial_hypernetwork
+ shared.loaded_hypernetwork = initial_hypernetwork
sd_hijack.apply_optimizations()
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 25cb67a4..27e571fc 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -8,7 +8,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, hypernetwork
+from modules import shared
+from modules.hypernetwork import hypernetwork
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
diff --git a/modules/shared.py b/modules/shared.py
index 14b40d70..8753015e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,8 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetwork import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
class State:
skipped = False
interrupted = False
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..d6977950 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
+ prompt=preview_text,
steps=20,
- height=training_height,
- width=training_width,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
@@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
diff --git a/modules/ui.py b/modules/ui.py
index 10b1ee3a..df653059 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Group():
gr.HTML(value="Create a new hypernetwork
")
@@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_hypernetwork = gr.Button(value="Create", variant='primary')
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
with gr.Group():
gr.HTML(value="Preprocess images
")
@@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every,
save_embedding_every,
template_file,
+ preview_image_prompt,
],
outputs=[
ti_output,
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 42e1489c..0af5993c 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,7 +10,8 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, hypernetwork
+from modules import images
+from modules.hypernetwork import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
diff --git a/webui.py b/webui.py
index 7c200551..ba2156c8 100644
--- a/webui.py
+++ b/webui.py
@@ -29,6 +29,7 @@ from modules import devices
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
+import modules.hypernetwork.hypernetwork
modelloader.cleanup_models()
modules.sd_models.setup_model()
@@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
-def set_hypernetwork():
- shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
-
-
-shared.reload_hypernetworks()
-shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
-set_hypernetwork()
-
-
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
def webui():
@@ -117,7 +108,7 @@ def webui():
prevent_thread_lock=True
)
- app.add_middleware(GZipMiddleware,minimum_size=1000)
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1:
time.sleep(0.5)
--
cgit v1.2.3
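The commit above also replaces the ad-hoc set_hypernetwork() helper in webui.py with a single opts.onchange callback, so the hypernetwork is reloaded whenever the sd_hypernetwork setting changes. A minimal sketch of that observer pattern, with hypothetical names standing in for the real Options class:

    class Options:
        """Settings store that notifies a listener when a key changes (sketch)."""
        def __init__(self):
            self.data = {}
            self.callbacks = {}

        def onchange(self, key, func):
            self.callbacks[key] = func

        def set(self, key, value):
            self.data[key] = value
            if key in self.callbacks:
                self.callbacks[key]()  # e.g. load_hypernetwork(self.data[key])

    opts = Options()
    opts.onchange("sd_hypernetwork", lambda: print("reloading", opts.data["sd_hypernetwork"]))
    opts.set("sd_hypernetwork", "my_hypernet")  # triggers the reload callback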
From d4ea5f4d8631f778d11efcde397e4a5b8801d43b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 19:03:08 +0300
Subject: add an option to unload models during hypernetwork training to save
VRAM
---
modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++-------
modules/hypernetworks/ui.py | 4 +++-
modules/shared.py | 4 ++++
modules/textual_inversion/dataset.py | 29 ++++++++++++++++++--------
modules/textual_inversion/textual_inversion.py | 2 +-
5 files changed, 46 insertions(+), 18 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b081f14e..4700e1ec 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
@@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
else:
images_dir = None
- cond_model = shared.sd_model.cond_stage_model
-
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
@@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
return hypernetwork, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text) in pbar:
+ for i, (x, text, cond) in pbar:
hypernetwork.step = i + ititial_step
if hypernetwork.step > steps:
@@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break
with torch.autocast("cuda"):
- c = cond_model([text])
-
+ cond = cond.to(devices.device)
x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
+ del cond
losses[hypernetwork.step % losses.shape[0]] = loss.item()
@@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=preview_text,
@@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
processed = processing.process_images(p)
image = processed.images[0]
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
shared.state.current_image = image
image.save(last_saved_image)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 3541a388..c67facbb 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -5,7 +5,7 @@ import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
+from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
@@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)}
raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
diff --git a/modules/shared.py b/modules/shared.py
index 20b45f23..c1092ff7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
+}))
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 4d006366..f61f40d3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,14 +8,14 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
self.placeholder_token = placeholder_token
@@ -32,6 +32,8 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
@@ -53,7 +55,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ if include_cond:
+ text = self.create_text(filename_tokens)
+ cond = cond_model([text]).to(devices.cpu)
+ else:
+ cond = None
+
+ self.dataset.append((init_latent, filename_tokens, cond))
self.length = len(self.dataset) * repeats
@@ -64,6 +72,12 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ def create_text(self, filename_tokens):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+ return text
+
def __len__(self):
return self.length
@@ -72,10 +86,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
-
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ x, filename_tokens, cond = self.dataset[index]
- return x, text
+ text = self.create_text(filename_tokens)
+ return x, text, cond
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index bb05cdc6..35f4bd9e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
return embedding, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
+ for i, (x, text, _) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
--
cgit v1.2.3
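The VRAM saving above works because the dataset now precomputes the text conditionings (include_cond=True), so during training only the hypernetwork weights need the GPU; the frozen text encoder (cond_stage_model) and VAE (first_stage_model) can sit in system RAM and are moved back only to render preview images. A generic sketch of that offloading pattern, with stand-in modules rather than the webui's models:

    import torch

    def offload(modules, device):
        # move frozen submodules between VRAM and system RAM
        for m in modules:
            m.to(device)

    # stand-ins for the frozen CLIP text encoder and VAE
    frozen = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
    cpu = torch.device("cpu")
    gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    offload(frozen, cpu)  # free VRAM for the training steps
    # ... run training steps on the GPU-resident trainable weights ...
    offload(frozen, gpu)  # bring them back to generate a preview image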
From 61788c0538415fa9ca1dd1b306519c116b18bd2c Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Tue, 11 Oct 2022 19:50:50 +0100
Subject: shift embedding logic out of textual_inversion
---
modules/textual_inversion/textual_inversion.py | 125 ++-----------------------
1 file changed, 6 insertions(+), 119 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8c66aeb5..22b4ae7f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,124 +7,11 @@ import tqdm
import html
import datetime
-from PIL import Image,PngImagePlugin,ImageDraw
-from ..images import captionImageOverlay
-import numpy as np
-import base64
-import json
-import zlib
+from PIL import Image,PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-class EmbeddingEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, torch.Tensor):
- return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
- return json.JSONEncoder.default(self, obj)
-
-class EmbeddingDecoder(json.JSONDecoder):
- def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
- def object_hook(self, d):
- if 'TORCHTENSOR' in d:
- return torch.from_numpy(np.array(d['TORCHTENSOR']))
- return d
-
-def embeddingToB64(data):
- d = json.dumps(data,cls=EmbeddingEncoder)
- return base64.b64encode(d.encode())
-
-def embeddingFromB64(data):
- d = base64.b64decode(data)
- return json.loads(d,cls=EmbeddingDecoder)
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
- while True:
- seed = (a * seed + c) % m
- yield seed
-
-def xorBlock(block):
- g = lcg()
- randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
- return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F)
-
-def styleBlock(block,sequence):
- im = Image.new('RGB',(block.shape[1],block.shape[0]))
- draw = ImageDraw.Draw(im)
- i=0
- for x in range(-6,im.size[0],8):
- for yi,y in enumerate(range(-6,im.size[1],8)):
- offset=0
- if yi%2==0:
- offset=4
- shade = sequence[i%len(sequence)]
- i+=1
- draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) )
-
- fg = np.array(im).astype(np.uint8) & 0xF0
- return block ^ fg
-
-def insertImageDataEmbed(image,data):
- d = 3
- data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
- dnp = np.frombuffer(data_compressed,np.uint8).copy()
- dnphigh = dnp >> 4
- dnplow = dnp & 0x0F
-
- h = image.size[1]
- next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h))
- next_size = next_size + ((h*d)-(next_size%(h*d)))
-
- dnplow.resize(next_size)
- dnplow = dnplow.reshape((h,-1,d))
-
- dnphigh.resize(next_size)
- dnphigh = dnphigh.reshape((h,-1,d))
-
- edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
- edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8)
-
- dnplow = styleBlock(dnplow,sequence=edgeStyleWeights)
- dnplow = xorBlock(dnplow)
- dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1])
- dnphigh = xorBlock(dnphigh)
-
- imlow = Image.fromarray(dnplow,mode='RGB')
- imhigh = Image.fromarray(dnphigh,mode='RGB')
-
- background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0))
- background.paste(imlow,(0,0))
- background.paste(image,(imlow.size[0]+1,0))
- background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0))
-
- return background
-
-def crop_black(img,tol=0):
- mask = (img>tol).all(2)
- mask0,mask1 = mask.any(0),mask.any(1)
- col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
- row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
- return img[row_start:row_end,col_start:col_end]
-
-def extractImageDataEmbed(image):
- d=3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
- blackCols = np.where( np.sum(outarr, axis=(0,2))==0)
- if blackCols[0].shape[0] < 2:
- print('No Image data blocks found.')
- return None
-
- dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8)
- dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8)
-
- dataBlocklower = xorBlock(dataBlocklower)
- dataBlockupper = xorBlock(dataBlockupper)
-
- dataBlock = (dataBlockupper << 4) | (dataBlocklower)
- dataBlock = dataBlock.flatten().tobytes()
- data = zlib.decompress(dataBlock)
- return json.loads(data,cls=EmbeddingDecoder)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -199,10 +86,10 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text:
- data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
else:
- data = extractImageDataEmbed(embed_image)
+ data = extract_image_data_embed(embed_image)
name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embeddingToB64(data))
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name','???'))
checkpoint = sd_models.select_checkpoint()
@@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step)
- captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
- captioned_image = insertImageDataEmbed(captioned_image,data)
+ captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
+ captioned_image = insert_image_data_embed(captioned_image,data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
--
cgit v1.2.3
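The helpers removed above now live in modules/textual_inversion/image_embedding.py under snake_case names. The core serialization trick is unchanged: torch tensors are tagged and flattened to lists so the whole embedding dict survives a JSON/base64 round trip. A self-contained sketch mirroring the removed embeddingToB64/embeddingFromB64 pair:

    import base64, json
    import numpy as np
    import torch

    class EmbeddingEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, torch.Tensor):
                # tag tensors so the decoder can reconstruct them
                return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
            return json.JSONEncoder.default(self, obj)

    class EmbeddingDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, d):
            if 'TORCHTENSOR' in d:
                return torch.from_numpy(np.array(d['TORCHTENSOR']))
            return d

    data = {'string_to_param': {'*': torch.randn(2, 768)}}
    b64 = base64.b64encode(json.dumps(data, cls=EmbeddingEncoder).encode())
    restored = json.loads(base64.b64decode(b64), cls=EmbeddingDecoder)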
From d6fcc6b87bc00fcdecea276fe5b7c7945f7a8b14 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 22:03:05 +0300
Subject: apply lr schedule to hypernets
---
modules/hypernetworks/hypernetwork.py | 19 ++++++++---
modules/textual_inversion/learn_schedule.py | 34 ++++++++++++++++++++
modules/textual_inversion/textual_inversion.py | 44 +++-----------------------
modules/ui.py | 2 +-
4 files changed, 54 insertions(+), 45 deletions(-)
create mode 100644 modules/textual_inversion/learn_schedule.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5608e799..470659df 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,6 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class HypernetworkModule(torch.nn.Module):
@@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
for weight in weights:
weight.requires_grad = True
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = ""
@@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > steps:
- break
+ if hypernetwork.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except Exception:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..db720271
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,34 @@
+
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 47a27faf..7717837d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,6 +10,7 @@ import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding:
@@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
- scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(scheduleIter)
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
@@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > end_step:
try:
- (learn_rate, end_step) = next(scheduleIter)
+ (learn_rate, end_step) = next(schedules)
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
@@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}
embedding.save(filename)
return embedding, filename
-
-class LearnSchedule:
- def __init__(self, learn_rate, max_steps, cur_step=0):
- pairs = learn_rate.split(',')
- self.rates = []
- self.it = 0
- self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
- return
- elif step == -1:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.it < self.maxit:
- self.it += 1
- return self.rates[self.it - 1]
- else:
- raise StopIteration
diff --git a/modules/ui.py b/modules/ui.py
index 2b688e32..1204eef7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
--
cgit v1.2.3
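The schedule string parsed by LearnSchedule is a comma-separated list of rate:step pairs; each pair sets the learning rate until the given step, and a trailing rate without a step (or with step -1) runs to max_steps. A usage sketch against the class as committed above (it is renamed LearnScheduleIterator later in this series):

    from modules.textual_inversion.learn_schedule import LearnSchedule

    # lr 0.005 until step 100, then 1e-4 until step 1000, then 1e-5 to the end
    for learn_rate, end_step in LearnSchedule("0.005:100, 1e-4:1000, 1e-5", max_steps=10000):
        print(f"rate {learn_rate} until step {end_step}")
    # rate 0.005 until step 100
    # rate 0.0001 until step 1000
    # rate 1e-05 until step 10000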
From aa75d5cfe8c84768b0f5d16f977ddba298677379 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Tue, 11 Oct 2022 20:06:13 +0100
Subject: correct conflict resolution typo
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 22b4ae7f..789383ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -169,7 +169,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt)
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
--
cgit v1.2.3
From 91d7ee0d097a7ea203d261b570cd2b834837d9e2 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Tue, 11 Oct 2022 20:09:10 +0100
Subject: update imports
---
modules/textual_inversion/textual_inversion.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 789383ce..ff0a62b3 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,6 +12,9 @@ from PIL import Image,PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64,
+ insert_image_data_embed,extract_image_data_embed,
+ caption_image_overlay )
class Embedding:
def __init__(self, vec, name, step=None):
--
cgit v1.2.3
From 5f3317376bb7952bc5145f05f16c1bbd466efc85 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Tue, 11 Oct 2022 20:09:49 +0100
Subject: spacing
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff0a62b3..485ef46c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,7 +12,7 @@ from PIL import Image,PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64,
+from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64,
insert_image_data_embed,extract_image_data_embed,
caption_image_overlay )
--
cgit v1.2.3
From 10a2de644f8ea4cfade88e85d768da3480f4c9f0 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 12 Oct 2022 13:15:35 +0100
Subject: formatting
---
modules/textual_inversion/textual_inversion.py | 22 +++++++++++-----------
1 file changed, 11 insertions(+), 11 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 485ef46c..b072d745 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,14 +7,14 @@ import tqdm
import html
import datetime
-from PIL import Image,PngImagePlugin
+from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64,
- insert_image_data_embed,extract_image_data_embed,
- caption_image_overlay )
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -90,10 +90,10 @@ class EmbeddingDatabase:
embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name',name)
+ name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
- name = data.get('name',name)
+ name = data.get('name', name)
else:
data = torch.load(path, map_location="cpu")
@@ -278,24 +278,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
-
+
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name','???'))
+ title = "<{}>".format(data.get('name', '???'))
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step)
- captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
- captioned_image = insert_image_data_embed(captioned_image,data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
-
+
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
--
cgit v1.2.3
From c3c8eef9fd5a0c8b26319e32ca4a19b56204e6df Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 12 Oct 2022 20:49:47 +0300
Subject: train: change filename processing to be more simple and configurable
train: make it possible to make text files with prompts
train: rework scheduler so that there's less repeating code in textual inversion and hypernets
train: move epochs setting to options
---
javascript/hints.js | 3 ++
modules/hypernetworks/hypernetwork.py | 40 +++++++++-------------
modules/shared.py | 3 ++
modules/textual_inversion/dataset.py | 47 +++++++++++++++++++-------
modules/textual_inversion/learn_schedule.py | 37 +++++++++++++++++++-
modules/textual_inversion/textual_inversion.py | 35 +++++++------------
modules/ui.py | 2 --
7 files changed, 105 insertions(+), 62 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/javascript/hints.js b/javascript/hints.js
index b81c181b..d51ee14c 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -81,6 +81,9 @@ titles = {
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
+
+ "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
+ "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
}
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8314450a..b6c06d49 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,7 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
@@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text, cond) in pbar:
+ for i, entry in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except Exception:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- cond = cond.to(devices.device)
- x = x.to(devices.device)
+ cond = entry.cond.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
)
processed = processing.process_images(p)
- image = processed.images[0]
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
diff --git a/modules/shared.py b/modules/shared.py
index 42e99741..e64e69fc 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index f61f40d3..67e90afe 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception:
continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+
if include_cond:
- text = self.create_text(filename_tokens)
- cond = cond_model([text]).to(devices.cpu)
- else:
- cond = None
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens, cond))
+ self.dataset.append(entry)
self.length = len(self.dataset) * repeats
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
- def create_text(self, filename_tokens):
+ def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens, cond = self.dataset[index]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
- text = self.create_text(filename_tokens)
- return x, text, cond
+ return entry
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index db720271..2062726a 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -1,6 +1,12 @@
+import tqdm
-class LearnSchedule:
+
+class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
+ """
+
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
@@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1]
else:
raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c5153e4a..fa0e33a2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
@@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text, _) in pbar:
+ for i, entry in pbar:
embedding.step = i + ititial_step
- if embedding.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
+ c = cond_model([entry.cond_text])
- x = x.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
@@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
diff --git a/modules/ui.py b/modules/ui.py
index 2b332267..c42535c8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1098,7 +1098,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
- num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@@ -1176,7 +1175,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width,
training_height,
steps,
- num_repeats,
create_image_every,
save_embedding_every,
template_file,
--
cgit v1.2.3
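With the options added above, a training image's label text is derived from a sibling .txt file if one exists, otherwise from its filename: leading numbers are stripped, then the optional 'Filename word regex' extracts tokens that are rejoined with the 'Filename join string'. A standalone sketch of that pipeline; the default regex and join string here are illustrative, in the webui they come from opts.dataset_filename_word_regex and opts.dataset_filename_join_string:

    import re

    re_numbers_at_start = re.compile(r"^[-\d]+\s*")

    def filename_to_label(filename_text, word_regex=r"[a-zA-Z]+", join_string=" "):
        filename_text = re.sub(re_numbers_at_start, '', filename_text)  # drop "00042-" prefixes
        if word_regex:
            tokens = re.compile(word_regex).findall(filename_text)
            filename_text = join_string.join(tokens)
        return filename_text

    print(filename_to_label("00042-portrait_of_a_cat"))  # -> "portrait of a cat"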
From 1cfc2a18981ee56bdb69a2de7b463a11ad05e329 Mon Sep 17 00:00:00 2001
From: Melan
Date: Wed, 12 Oct 2022 23:36:29 +0200
Subject: Save a csv containing the loss while training
---
modules/hypernetworks/hypernetwork.py | 17 ++++++++++++++++-
modules/textual_inversion/textual_inversion.py | 17 ++++++++++++++++-
modules/ui.py | 3 +++
3 files changed, 35 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b6c06d49..6522078f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -5,6 +5,7 @@ import os
import sys
import traceback
import tqdm
+import csv
import torch
@@ -174,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, write_csv_every, template_file, preview_image_prompt):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -256,6 +257,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
+ print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}")
+ if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
+ write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
+
+ with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
+
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ csv_writer.writerow({"step": hypernetwork.step,
+ "loss": f"{losses.mean():.7f}"})
+
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..25038a89 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
+import csv
from PIL import Image, PngImagePlugin
@@ -172,7 +173,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -256,6 +257,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
+
+ with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
+
+ csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ csv_writer.writerow({"epoch": epoch_num + 1,
+ "epoch_step": epoch_step - 1,
+ "loss": f"{losses.mean():.7f}"})
+
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
diff --git a/modules/ui.py b/modules/ui.py
index e07ee0e1..1195c2f1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1096,6 +1096,7 @@ def create_ui(wrap_gradio_gpu_call):
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
+ write_csv_every = gr.Number(label='Save a csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
@@ -1174,6 +1175,7 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
+ write_csv_every,
template_file,
save_image_with_stored_embedding,
preview_image_prompt,
@@ -1195,6 +1197,7 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
+ write_csv_every,
template_file,
preview_image_prompt,
],
--
cgit v1.2.3
From 8636b50aea83f9c743f005722d9f3f8ee9303e00 Mon Sep 17 00:00:00 2001
From: Melan
Date: Thu, 13 Oct 2022 12:37:58 +0200
Subject: Add learn_rate to csv and removed a left-over debug statement
---
modules/hypernetworks/hypernetwork.py | 6 +++---
modules/textual_inversion/textual_inversion.py | 5 +++--
2 files changed, 6 insertions(+), 5 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6522078f..2751a8c8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -257,19 +257,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
- print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}")
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"])
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"step": hypernetwork.step,
- "loss": f"{losses.mean():.7f}"})
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 25038a89..b83df079 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -262,14 +262,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
- csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"])
+ csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"epoch": epoch_num + 1,
"epoch_step": epoch_step - 1,
- "loss": f"{losses.mean():.7f}"})
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
--
cgit v1.2.3
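
Appending learn_rate to the fieldnames has one quiet consequence: a loss CSV started before this commit keeps its old header, while rows written afterwards carry an extra value the header does not declare. A hedged sketch of a reader that tolerates both generations of the file; the restkey mechanism is standard csv module behavior, and the file name comes from the patch:

import csv

def read_loss_log(path):
    rows = []
    with open(path, newline="") as fin:
        # Values beyond the declared header (e.g. a learn_rate column
        # appended after the header was written) land under "extra".
        reader = csv.DictReader(fin, restkey="extra")
        for row in reader:
            rows.append(row)
    return rows

for row in read_loss_log("textual_inversion_loss.csv"):
    print(row.get("loss"), row.get("learn_rate") or row.get("extra"))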
From c344ba3b325459abbf9b0df2c1b18f7bf99805b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 14 Oct 2022 20:31:49 +0300
Subject: add option to read generation params for learning previews from
txt2img
---
modules/hypernetworks/hypernetwork.py | 21 ++++++++++++++++-----
modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++-------
modules/ui.py | 20 +++++++++++++++++---
3 files changed, 51 insertions(+), 15 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f1248bb7..e5cb1817 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -180,7 +180,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -265,20 +265,31 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entry.cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..3d835358 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -172,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -259,18 +259,29 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
- height=training_height,
- width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entry.cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0]
diff --git a/modules/ui.py b/modules/ui.py
index 828bfeea..4a04c2cc 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -711,6 +711,18 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
]
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
@@ -1162,7 +1174,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
- preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1240,7 +1252,8 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every,
template_file,
save_image_with_stored_embedding,
- preview_image_prompt,
+ preview_from_txt2img,
+ *txt2img_preview_params,
],
outputs=[
ti_output,
@@ -1260,7 +1273,8 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every,
save_embedding_every,
template_file,
- preview_image_prompt,
+ preview_from_txt2img,
+ *txt2img_preview_params,
],
outputs=[
ti_output,
--
cgit v1.2.3
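
The refactor above replaces the single preview-prompt textbox with a branch: the StableDiffusionProcessingTxt2Img object is constructed with only the invariant flags, and the variable fields are then assigned from either the txt2img widgets or the training defaults. A reduced sketch of that branch structure; PreviewParams and make_preview_params are illustrative stand-ins, not webui names:

from dataclasses import dataclass

@dataclass
class PreviewParams:
    prompt: str = ""
    steps: int = 20
    width: int = 512
    height: int = 512

def make_preview_params(preview_from_txt2img, txt2img_values, cond_text,
                        training_width, training_height):
    p = PreviewParams()
    if preview_from_txt2img:
        # Mirror whatever the user last set on the txt2img tab.
        p.prompt = txt2img_values["prompt"]
        p.steps = txt2img_values["steps"]
        p.width = txt2img_values["width"]
        p.height = txt2img_values["height"]
    else:
        # Fall back to the caption of the last training entry and the
        # training resolution, as the embedding branch of the patch does.
        p.prompt = cond_text
        p.width = training_width
        p.height = training_height
    return p

On the UI side the same idea shows up as *txt2img_preview_params: the eight txt2img components are gathered into one list and splatted into both trainers' input lists, so a future ninth preview parameter means touching one list instead of two call sites.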
From 03d62538aebeff51713619fe808c953bdb70193d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 14 Oct 2022 22:43:55 +0300
Subject: remove duplicate loss-logging code, add step, make it read from
 options rather than gradio input
---
modules/hypernetworks/hypernetwork.py | 20 ++++--------
modules/shared.py | 3 +-
modules/textual_inversion/textual_inversion.py | 44 ++++++++++++++++++--------
modules/ui.py | 3 --
4 files changed, 38 insertions(+), 32 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index edb8cba1..59c7ac6e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -15,6 +15,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -210,7 +211,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -263,19 +264,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
- if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
- write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
-
- with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
-
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
-
- if write_csv_header:
- csv_writer.writeheader()
-
- csv_writer.writerow({"step": hypernetwork.step,
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate})
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
diff --git a/modules/shared.py b/modules/shared.py
index 695d29b6..d41a7ab3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -236,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
- "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1f5ace6f..da0d77a0 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -173,6 +173,32 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
@@ -257,20 +283,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
- if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
- write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
-
- with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
-
- csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
-
- if write_csv_header:
- csv_writer.writeheader()
-
- csv_writer.writerow({"epoch": epoch_num + 1,
- "epoch_step": epoch_step - 1,
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate})
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
diff --git a/modules/ui.py b/modules/ui.py
index be4a43a7..a08ffc9b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1172,7 +1172,6 @@ def create_ui(wrap_gradio_gpu_call):
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
- write_csv_every = gr.Number(label='Save a csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
@@ -1251,7 +1250,6 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
- write_csv_every,
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
@@ -1274,7 +1272,6 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
- write_csv_every,
template_file,
preview_from_txt2img,
*txt2img_preview_params,
--
cgit v1.2.3
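
The new write_loss helper also changes where the epoch bookkeeping comes from: instead of threading epoch_num and epoch_step through the training loop, it derives both from the global step and the dataset length, and reports all three values 1-based. A quick standalone check of that arithmetic:

def epoch_position(step, epoch_len):
    # epoch_len is len(ds): one pass over the dataset is one epoch.
    epoch = step // epoch_len
    epoch_step = step - epoch * epoch_len
    return step + 1, epoch + 1, epoch_step + 1

# With a 40-image dataset, global step 39 is the last step of epoch 1
# and step 40 is the first step of epoch 2.
print(epoch_position(39, 40))   # (40, 1, 40)
print(epoch_position(40, 40))   # (41, 2, 1)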
From c7a86f7fe9c0b8967a87e8d709f507d2f44400d8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 15 Oct 2022 09:24:59 +0300
Subject: add option to use batch size for training
---
modules/hypernetworks/hypernetwork.py | 33 +++++++++++++++++++-------
modules/textual_inversion/dataset.py | 31 ++++++++++++++----------
modules/textual_inversion/textual_inversion.py | 17 +++++++------
modules/ui.py | 3 +++
4 files changed, 54 insertions(+), 30 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 59c7ac6e..a2b3bc0a 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step)
@@ -246,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break
with torch.autocast("cuda"):
- cond = entry.cond.to(devices.device)
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
- del cond
+ del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
@@ -292,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
p.width = preview_width
p.height = preview_height
else:
- p.prompt = entry.cond_text
+ p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
@@ -315,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..bd99c0cb 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
+ self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -78,13 +79,13 @@ class PersonalizedBase(Dataset):
if include_cond:
entry.cond_text = self.create_text(filename_text)
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
- self.length = len(self.dataset) * repeats
+ self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
@@ -101,13 +102,19 @@ class PersonalizedBase(Dataset):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
+ res = []
- index = self.indexes[i % len(self.indexes)]
- entry = self.dataset[index]
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
- if entry.cond is None:
- entry.cond_text = self.create_text(entry.filename_text)
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
- return entry
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
+
+ return res
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index da0d77a0..e754747e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
@@ -251,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
embedding.step = i + ititial_step
scheduler.apply(optimizer, embedding.step)
@@ -262,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
-
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
@@ -307,7 +306,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p.width = preview_width
p.height = preview_height
else:
- p.prompt = entry.cond_text
+ p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
@@ -348,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
diff --git a/modules/ui.py b/modules/ui.py
index 1bc919c7..45550ea8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1166,6 +1166,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1244,6 +1245,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_embedding_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
training_width,
@@ -1268,6 +1270,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_hypernetwork_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
steps,
--
cgit v1.2.3
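
The interesting piece of the batch-size commit is stack_conds: prompts that tokenize to different lengths produce conditioning tensors with different token counts, so before torch.stack can build a [batch, tokens, dim] tensor the shorter entries are padded by repeating their last vector, the same trick the in-code comment attributes to reconstruct_multicond_batch. A standalone sketch of that padding, with shapes chosen only for illustration:

import torch

def pad_and_stack(conds):
    # conds: list of [tokens, dim] tensors with possibly different token counts.
    token_count = max(c.shape[0] for c in conds)
    padded = []
    for c in conds:
        if c.shape[0] < token_count:
            # Repeat the final vector until this entry matches the longest one.
            filler = c[-1:].repeat(token_count - c.shape[0], 1)
            c = torch.vstack([c, filler])
        padded.append(c)
    return torch.stack(padded)  # [batch, token_count, dim]

batch = pad_and_stack([torch.randn(77, 768), torch.randn(154, 768)])
print(batch.shape)  # torch.Size([2, 154, 768])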