From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 15:03:39 +0300
Subject: initial support for training textual inversion
---
modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++++++++++
1 file changed, 258 insertions(+)
create mode 100644 modules/textual_inversion/textual_inversion.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..c0baaace
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+
+from modules import shared, devices, sd_hijack, processing
+import modules.textual_inversion.dataset
+
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id].append((ids, embedding))
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding
+
+ return None
+
+
+
+def create_embedding(name, num_vectors_per_token):
+ init_text = '*'
+
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = ""
+ last_saved_image = ""
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ return embedding, filename
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, (x, text) in pbar:
+ embedding.step = i + ititial_step
+
+ if embedding.step > steps:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([text])
+ loss = shared.sd_model(x.unsqueeze(0), c)[0]
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ embedding.save(last_saved_file)
+
+ if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+
+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+
+ embedding.cached_checksum = None
+ embedding.save(filename)
+
+ return embedding, filename
+
--
cgit v1.2.3
From 88ec0cf5571883d84abd09196652b3679e359f2e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 19:40:51 +0300
Subject: fix for incorrect embedding token length calculation (will break
seeds that use embeddings, you're welcome!) add option to input
initialization text for embeddings
---
modules/sd_hijack.py | 8 ++++----
modules/textual_inversion/textual_inversion.py | 13 +++++--------
modules/textual_inversion/ui.py | 4 ++--
modules/ui.py | 2 ++
4 files changed, 13 insertions(+), 14 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fd57e5c5..3fa06242 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if embedding is None:
remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..0c50161d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
- return None
+ return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
- return embedding
+ return embedding, len(ids)
- return None
+ return None, None
-
-def create_embedding(name, num_vectors_per_token):
- init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index ce3677a9..66c43ffb 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
-def create_embedding(name, nvpt):
- filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+ filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
diff --git a/modules/ui.py b/modules/ui.py
index 3b81a4f7..eca50df0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="Create a new embedding
")
new_embedding_name = gr.Textbox(label="Name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
+ initialization_text,
nvpt,
],
outputs=[
--
cgit v1.2.3
From 71fe7fa49f5eb1a2c89932a9d217ed153c12fc8b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 19:56:37 +0300
Subject: fix using aaaa-100 embedding when the prompt has aaaa-10000 and you
have both aaaa-100 and aaaa-10000 in the directory with embeddings.
---
modules/textual_inversion/textual_inversion.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0c50161d..9d2241ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -57,7 +57,8 @@ class EmbeddingDatabase:
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, embedding))
+
+ self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
--
cgit v1.2.3
From 4ec4af6e0b7addeee5221a03f32d117ccdc875d9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 20:15:25 +0300
Subject: add checkpoint info to saved embeddings
---
modules/textual_inversion/textual_inversion.py | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9d2241ce..1183aab7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,7 +7,7 @@ import tqdm
import html
import datetime
-from modules import shared, devices, sd_hijack, processing
+from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -17,6 +17,8 @@ class Embedding:
self.name = name
self.step = step
self.cached_checksum = None
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
@@ -24,6 +26,8 @@ class Embedding:
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
+ "sd_checkpoint": self.sd_checkpoint,
+ "sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
@@ -41,6 +45,7 @@ class Embedding:
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
@@ -96,6 +101,8 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
@@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}
"""
+ checkpoint = sd_models.select_checkpoint()
+
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
--
cgit v1.2.3
From c7543d4940da672d970124ae8f2fec9de7bdc1da Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:41:21 +0300
Subject: preprocessing for textual inversion added
---
modules/interrogate.py | 1 +
modules/textual_inversion/preprocess.py | 75 ++++++++++++++++++++++++++
modules/textual_inversion/textual_inversion.py | 1 +
modules/textual_inversion/ui.py | 14 +++--
modules/ui.py | 36 +++++++++++++
5 files changed, 124 insertions(+), 3 deletions(-)
create mode 100644 modules/textual_inversion/preprocess.py
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..eed87144 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
new file mode 100644
index 00000000..209e928f
--- /dev/null
+++ b/modules/textual_inversion/preprocess.py
@@ -0,0 +1,75 @@
+import os
+from PIL import Image, ImageOps
+import tqdm
+
+from modules import shared, images
+
+
+def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
+ size = 512
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+
+ assert src != dst, 'same directory specified as source and desitnation'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = os.listdir(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ if process_caption:
+ shared.interrogator.load()
+
+ def save_pic_with_caption(image, index):
+ if process_caption:
+ caption = "-" + shared.interrogator.generate_caption(image)
+ else:
+ caption = ""
+
+ image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+ subindex[0] += 1
+
+ def save_pic(image, index):
+ save_pic_with_caption(image, index)
+
+ if process_flip:
+ save_pic_with_caption(ImageOps.mirror(image), index)
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ subindex = [0]
+ filename = os.path.join(src, imagefile)
+ img = Image.open(filename).convert("RGB")
+
+ if shared.state.interrupted:
+ break
+
+ ratio = img.height / img.width
+ is_tall = ratio > 1.35
+ is_wide = ratio < 1 / 1.35
+
+ if process_split and is_tall:
+ img = img.resize((size, size * img.height // img.width))
+
+ top = img.crop((0, 0, size, size))
+ save_pic(top, index)
+
+ bot = img.crop((0, img.height - size, size, img.height))
+ save_pic(bot, index)
+ elif process_split and is_wide:
+ img = img.resize((size * img.width // img.height, size))
+
+ left = img.crop((0, 0, size, size))
+ save_pic(left, index)
+
+ right = img.crop((img.width - size, 0, img.width, size))
+ save_pic(right, index)
+ else:
+ img = images.resize_image(1, img, size, size)
+ save_pic(img, index)
+
+ shared.state.nextjob()
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1183aab7..d4e250d8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,7 @@ import tqdm
import html
import datetime
+
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 633037d8..f19ac5e0 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -2,24 +2,31 @@ import html
import gradio as gr
-import modules.textual_inversion.textual_inversion as ti
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
- filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+
+ return "Preprocessing finished.", ""
+
+
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
- embedding, filename = ti.train_embedding(*args)
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
@@ -30,3 +37,4 @@ Embedding saved to {html.escape(filename)}
raise
finally:
sd_hijack.apply_optimizations()
+
diff --git a/modules/ui.py b/modules/ui.py
index 8912deff..e7bde53b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -961,6 +961,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
+ gr.HTML(value="See wiki for detailed explanation.
")
+
gr.HTML(value="Create a new embedding
")
new_embedding_name = gr.Textbox(label="Name")
@@ -974,6 +976,24 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
+ with gr.Group():
+ gr.HTML(value="Preprocess images
")
+
+ process_src = gr.Textbox(label='Source directory')
+ process_dst = gr.Textbox(label='Destination directory')
+
+ with gr.Row():
+ process_flip = gr.Checkbox(label='Flip')
+ process_split = gr.Checkbox(label='Split into two')
+ process_caption = gr.Checkbox(label='Add caption')
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
+
with gr.Group():
gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
@@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ process_src,
+ process_dst,
+ process_flip,
+ process_split,
+ process_caption,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
--
cgit v1.2.3
From 6785331e22d6a488fbf5905fab56d7fec867e038 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 22:59:01 +0300
Subject: keep textual inversion dataset latents in CPU memory to save a bit of
VRAM
---
modules/textual_inversion/dataset.py | 2 ++
modules/textual_inversion/textual_inversion.py | 3 +++
modules/ui.py | 4 ++--
3 files changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7e134a08..e8394ff6 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,6 +8,7 @@ from torchvision import transforms
import random
import tqdm
+from modules import devices
class PersonalizedBase(Dataset):
@@ -47,6 +48,7 @@ class PersonalizedBase(Dataset):
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+ init_latent = init_latent.to(devices.cpu)
self.dataset.append((init_latent, filename_tokens))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d4e250d8..8686f534 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -212,7 +212,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
with torch.autocast("cuda"):
c = cond_model([text])
+
+ x = x.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ del x
losses[embedding.step % losses.shape[0]] = loss.item()
diff --git a/modules/ui.py b/modules/ui.py
index e7bde53b..d9d02ece 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1002,8 +1002,8 @@ def create_ui(wrap_gradio_gpu_call):
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
with gr.Row():
with gr.Column(scale=2):
--
cgit v1.2.3
From 2865ef4b9ab16d56326cc805541bebcf01d099bc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 3 Oct 2022 13:10:03 +0300
Subject: fix broken date in TI
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8686f534..cd9f3498 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
--
cgit v1.2.3