From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 2 Oct 2022 15:03:39 +0300
Subject: initial support for training textual inversion
---
modules/textual_inversion/dataset.py | 76 ++++++++
modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++++++++++
modules/textual_inversion/ui.py | 32 +++
3 files changed, 366 insertions(+)
create mode 100644 modules/textual_inversion/dataset.py
create mode 100644 modules/textual_inversion/textual_inversion.py
create mode 100644 modules/textual_inversion/ui.py
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
new file mode 100644
index 00000000..7e134a08
--- /dev/null
+++ b/modules/textual_inversion/dataset.py
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+import tqdm
+
+
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+
+ self.placeholder_token = placeholder_token
+
+ self.size = size
+ self.width = width
+ self.height = height
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ self.dataset = []
+
+ with open(template_file, "r") as file:
+ lines = [x.strip() for x in file.readlines()]
+
+ self.lines = lines
+
+ assert data_root, 'dataset directory not specified'
+
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ print("Preparing dataset...")
+ for path in tqdm.tqdm(self.image_paths):
+ image = Image.open(path)
+ image = image.convert('RGB')
+ image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+
+ filename = os.path.basename(path)
+ filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
+ filename_tokens = [token for token in filename_tokens if token.isalpha()]
+
+ npimage = np.array(image).astype(np.uint8)
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+ torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+ torchdata = torch.moveaxis(torchdata, 2, 0)
+
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+
+ self.dataset.append((init_latent, filename_tokens))
+
+ self.length = len(self.dataset) * repeats
+
+ self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.indexes = None
+ self.shuffle()
+
+ def shuffle(self):
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ if i % len(self.dataset) == 0:
+ self.shuffle()
+
+ index = self.indexes[i % len(self.indexes)]
+ x, filename_tokens = self.dataset[index]
+
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+
+ return x, text
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..c0baaace
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+
+from modules import shared, devices, sd_hijack, processing
+import modules.textual_inversion.dataset
+
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id].append((ids, embedding))
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding
+
+ return None
+
+
+
+def create_embedding(name, num_vectors_per_token):
+ init_text = '*'
+
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "
+Loss: {losses.mean():.7f} Create a new embedding
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
See wiki for detailed explanation.
") + gr.HTML(value="Create a new embedding
") new_embedding_name = gr.Textbox(label="Name") @@ -974,6 +976,24 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create", variant='primary') + with gr.Group(): + gr.HTML(value="Preprocess images
") + + process_src = gr.Textbox(label='Source directory') + process_dst = gr.Textbox(label='Destination directory') + + with gr.Row(): + process_flip = gr.Checkbox(label='Flip') + process_split = gr.Checkbox(label='Split into two') + process_caption = gr.Checkbox(label='Add caption') + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + run_preprocess = gr.Button(value="Preprocess", variant='primary') + with gr.Group(): gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) @@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call): ] ) + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + train_embedding.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", -- cgit v1.2.3 From 6785331e22d6a488fbf5905fab56d7fec867e038 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 22:59:01 +0300 Subject: keep textual inversion dataset latents in CPU memory to save a bit of VRAM --- modules/textual_inversion/dataset.py | 2 ++ modules/textual_inversion/textual_inversion.py | 3 +++ modules/ui.py | 4 ++-- 3 files changed, 7 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7e134a08..e8394ff6 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,6 +8,7 @@ from torchvision import transforms import random import tqdm +from modules import devices class PersonalizedBase(Dataset): @@ -47,6 +48,7 @@ class PersonalizedBase(Dataset): torchdata = torch.moveaxis(torchdata, 2, 0) init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + init_latent = init_latent.to(devices.cpu) self.dataset.append((init_latent, filename_tokens)) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d4e250d8..8686f534 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -212,7 +212,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, with torch.autocast("cuda"): c = cond_model([text]) + + x = x.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x losses[embedding.step % losses.shape[0]] = loss.item() diff --git a/modules/ui.py b/modules/ui.py index e7bde53b..d9d02ece 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,8 +1002,8 @@ def create_ui(wrap_gradio_gpu_call): log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) with gr.Row(): with gr.Column(scale=2): -- cgit v1.2.3 From 2865ef4b9ab16d56326cc805541bebcf01d099bc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 13:10:03 +0300 Subject: fix broken date in TI --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8686f534..cd9f3498 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) if save_embedding_every > 0: embedding_dir = os.path.join(log_directory, "embeddings") -- cgit v1.2.3 From 5ef0baf5eaec7f21a1666af424405cbee19f3764 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 08:52:11 +0300 Subject: add support for gelbooru tags in filenames for textual inversion --- modules/textual_inversion/dataset.py | 7 +++++-- modules/textual_inversion/preprocess.py | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e8394ff6..7c44ea5b 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -9,6 +9,9 @@ from torchvision import transforms import random import tqdm from modules import devices +import re + +re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): @@ -38,8 +41,8 @@ class PersonalizedBase(Dataset): image = image.resize((self.width, self.height), PIL.Image.BICUBIC) filename = os.path.basename(path) - filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') - filename_tokens = [token for token in filename_tokens if token.isalpha()] + filename_tokens = os.path.splitext(filename)[0] + filename_tokens = re_tag.findall(filename_tokens) npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 209e928f..f545a993 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca if process_caption: caption = "-" + shared.interrogator.generate_caption(image) else: - caption = "" + caption = filename + caption = os.path.splitext(caption)[0] + caption = os.path.basename(caption) image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) subindex[0] += 1 -- cgit v1.2.3 From 2499fb4e1910d31ff12c24110f161b20641b8835 Mon Sep 17 00:00:00 2001 From: Raphael Stoeckli
+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
Create a new hypernetwork
") + + new_hypernetwork_name = gr.Textbox(label="Name") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create", variant='primary') + with gr.Group(): gr.HTML(value="Preprocess images
") @@ -986,6 +999,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -993,15 +1007,12 @@ def create_ui(wrap_gradio_gpu_call): steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + preview_image_prompt = gr.Textbox(label='Preview prompt', value="") with gr.Row(): - with gr.Column(scale=2): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_embedding = gr.Button(value="Train", variant='primary') + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") @@ -1027,6 +1038,18 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_hypernetwork.click( + fn=modules.hypernetwork.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + run_preprocess.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1062,12 +1085,33 @@ def create_ui(wrap_gradio_gpu_call): ] ) + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + preview_image_prompt, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + interrupt_training.click( fn=lambda: shared.state.interrupt(), inputs=[], outputs=[], ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c0c364df..5b504de6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -78,8 +78,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + shared.hypernetwork = shared.hypernetworks.get(x, None) def format_value_add_label(p, opt, x): @@ -199,7 +198,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 - initial_hn = opts.sd_hypernetwork + initial_hn = shared.hypernetwork def process_axis(opt, vals): if opt.label == 'Nothing': @@ -308,6 +307,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + shared.hypernetwork = initial_hn return processed diff --git a/textual_inversion_templates/hypernetwork.txt b/textual_inversion_templates/hypernetwork.txt new file mode 100644 index 00000000..91e06890 --- /dev/null +++ b/textual_inversion_templates/hypernetwork.txt @@ -0,0 +1,27 @@ +a photo of a [filewords] +a rendering of a [filewords] +a cropped photo of the [filewords] +the photo of a [filewords] +a photo of a clean [filewords] +a photo of a dirty [filewords] +a dark photo of the [filewords] +a photo of my [filewords] +a photo of the cool [filewords] +a close-up photo of a [filewords] +a bright photo of the [filewords] +a cropped photo of a [filewords] +a photo of the [filewords] +a good photo of the [filewords] +a photo of one [filewords] +a close-up photo of the [filewords] +a rendition of the [filewords] +a photo of the clean [filewords] +a rendition of a [filewords] +a photo of a nice [filewords] +a good photo of a [filewords] +a photo of the nice [filewords] +a photo of the small [filewords] +a photo of the weird [filewords] +a photo of the large [filewords] +a photo of a cool [filewords] +a photo of a small [filewords] diff --git a/textual_inversion_templates/none.txt b/textual_inversion_templates/none.txt new file mode 100644 index 00000000..f77af461 --- /dev/null +++ b/textual_inversion_templates/none.txt @@ -0,0 +1 @@ +picture diff --git a/webui.py b/webui.py index 480360fe..60f9061f 100644 --- a/webui.py +++ b/webui.py @@ -74,6 +74,15 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) +def set_hypernetwork(): + shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) + + +shared.reload_hypernetworks() +shared.opts.onchange("sd_hypernetwork", set_hypernetwork) +set_hypernetwork() + + modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() -- cgit v1.2.3 From 5841990b0df04906da7321beef6f7f7902b7d57b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 05:38:38 +0100 Subject: Update textual_inversion.py --- modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..f6316020 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,6 +7,9 @@ import tqdm import html import datetime +from PIL import Image, PngImagePlugin +import base64 +from io import BytesIO from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -80,7 +83,15 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = torch.load(path, map_location="cpu") + data = [] + + if filename.upper().endswith('.PNG'): + embed_image = Image.open(path) + if 'sd-embedding' in embed_image.text: + embeddingData = base64.b64decode(embed_image.text['sd-embedding']) + data = torch.load(BytesIO(embeddingData), map_location="cpu") + else: + data = torch.load(path, map_location="cpu") # textual inversion embeddings if 'string_to_param' in data: @@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, image = processed.images[0] shared.state.current_image = image - image.save(last_saved_image) + + if save_image_with_stored_embedding: + info = PngImagePlugin.PngInfo() + info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) + image.save(last_saved_image, "PNG", pnginfo=info) + else: + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From 03694e1f9915e34cf7d9a31073f1a1a9def2909f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 21:58:14 +0100 Subject: add embedding load and save from b64 json --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f6316020..1b7f8906 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,9 +7,11 @@ import tqdm import html import datetime -from PIL import Image, PngImagePlugin +from PIL import Image,PngImagePlugin +from ..images import captionImge +import numpy as np import base64 -from io import BytesIO +import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -87,9 +89,9 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) - if 'sd-embedding' in embed_image.text: - embeddingData = base64.b64decode(embed_image.text['sd-embedding']) - data = torch.load(BytesIO(embeddingData), map_location="cpu") + if 'sd-ti-embedding' in embed_image.text: + data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if save_image_with_stored_embedding: info = PngImagePlugin.PngInfo() - info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) - image.save(last_saved_image, "PNG", pnginfo=info) + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embeddingToB64(data)) + + pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + + caption_checkpoint_hash = data.get('sd_checkpoint','UNK') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_stepcount = data.get('step',0) + caption_stepcount = caption_stepcount if caption_stepcount else 0 + + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, + caption_stepcount))] + captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) + captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: image.save(last_saved_image) - - last_saved_image += f", prompt: {text}" shared.state.job_no = embedding.step -- cgit v1.2.3 From 969bd8256e5b4f1007d3cc653723d4ad50a92528 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:02:28 +0100 Subject: add alternate checkpoint hash source --- modules/textual_inversion/textual_inversion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1b7f8906..d7813084 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -265,8 +265,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint','UNK') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_checkpoint_hash = data.get('sd_checkpoint') + if caption_checkpoint_hash is None: + caption_checkpoint_hash = data.get('hash') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 5d12ec82d3e13f5ff4c55db2930e4e10aed7015a Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:05:09 +0100 Subject: add encoder and decoder classes --- modules/textual_inversion/textual_inversion.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d7813084..44d4e08b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -16,6 +16,27 @@ import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, o) + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): + if 'EMBEDDINGTENSOR' in d: + return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + return d + +def embeddingToB64(data): + d = json.dumps(data,cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + +def EmbeddingFromB64(data): + d = base64.b64decode(data) + return json.loads(d,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From d0184b8f76ce492da699f1926f34b57cd095242e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:12 +0100 Subject: change json tensor key name --- modules/textual_inversion/textual_inversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44d4e08b..ae8d207d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -19,15 +19,15 @@ import modules.textual_inversion.dataset class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): - return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} return json.JSONEncoder.default(self, o) class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) def object_hook(self, d): - if 'EMBEDDINGTENSOR' in d: - return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) return d def embeddingToB64(data): -- cgit v1.2.3 From 66846105103cfc282434d0dc2102910160b7a633 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:42 +0100 Subject: correct case on embeddingFromB64 --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ae8d207d..d2b95fa3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -34,7 +34,7 @@ def embeddingToB64(data): d = json.dumps(data,cls=EmbeddingEncoder) return base64.b64encode(d.encode()) -def EmbeddingFromB64(data): +def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) -- cgit v1.2.3 From 96f1e6be59316ec640cab2435fa95b3688194906 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:14:50 +0100 Subject: source checkpoint hash from current checkpoint --- modules/textual_inversion/textual_inversion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d2b95fa3..b16fa84e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -286,10 +286,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint') - if caption_checkpoint_hash is None: - caption_checkpoint_hash = data.get('hash') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + checkpoint = sd_models.select_checkpoint() + caption_checkpoint_hash = checkpoint.hash caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 01fd9cf0d28d8b71a113ab1aa62accfe7f0d9c51 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:17:02 +0100 Subject: change source of step count --- modules/textual_inversion/textual_inversion.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b16fa84e..e4f339b8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -285,15 +285,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, info.add_text("sd-ti-embedding", embeddingToB64(data)) pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - checkpoint = sd_models.select_checkpoint() - caption_checkpoint_hash = checkpoint.hash - - caption_stepcount = data.get('step',0) - caption_stepcount = caption_stepcount if caption_stepcount else 0 - - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, - caption_stepcount))] + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, + embedding.step))] captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: -- cgit v1.2.3 From d6a599ef9ba18a66ae79b50f2945af5788fdda8f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:07:52 +0100 Subject: change caption method --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e4f339b8..21596e78 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -8,7 +8,7 @@ import html import datetime from PIL import Image,PngImagePlugin -from ..images import captionImge +from ..images import captionImageOverlay import numpy as np import base64 import json @@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, else: images_dir = None + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + cond_model = shared.sd_model.cond_stage_model shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.current_image = image - if save_image_with_stored_embedding: + if save_image_with_stored_embedding and os.path.exists(last_saved_file): + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embeddingToB64(data)) - pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, - embedding.step))] - captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) - captioned_image.save(last_saved_image, "PNG", pnginfo=info) - else: - image.save(last_saved_image) + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '[{}]'.format(embedding.step) + + captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + + image.save(last_saved_image) last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From e2c2925eb4d634b186de2c76798162ec56e2f869 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:12:53 +0100 Subject: remove braces from steps --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 21596e78..9a18ee5c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -297,7 +297,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '[{}]'.format(embedding.step) + footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) -- cgit v1.2.3 From 1f92336be768d235c18a82acb2195b7135101ae7 Mon Sep 17 00:00:00 2001 From: JC_ArrayTrain an embedding; must specify a directory with a set of 512x512 images
") + gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, + process_size, process_flip, process_split, process_caption, @@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, + training_size, steps, + num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 4ee7519fc2e459ce8eff1f61f1655afba393357c Mon Sep 17 00:00:00 2001 From: alg-wikiTrain an embedding; must specify a directory with a set of 512x512 images
") + gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, + process_size, process_flip, process_split, process_caption, @@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, + training_size, steps, + num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 6ad3a53e368d36535de1a4fca73b3bb78fd40654 Mon Sep 17 00:00:00 2001 From: alg-wikiTrain an embedding; must specify a directory with a set of 1:1 ratio images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) -- cgit v1.2.3 From 907a88b2d0be320575c2129d8d6a1d4f3a68f9eb Mon Sep 17 00:00:00 2001 From: alg-wikiCreate a new hypernetwork
") @@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="Preprocess images
") @@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, ], outputs=[ ti_output, diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..0af5993c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,8 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, hypernetwork +from modules import images +from modules.hypernetwork import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index 7c200551..ba2156c8 100644 --- a/webui.py +++ b/webui.py @@ -29,6 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts +import modules.hypernetwork.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() @@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) -def set_hypernetwork(): - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) - - -shared.reload_hypernetworks() -shared.opts.onchange("sd_hypernetwork", set_hypernetwork) -set_hypernetwork() - - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): @@ -117,7 +108,7 @@ def webui(): prevent_thread_lock=True ) - app.add_middleware(GZipMiddleware,minimum_size=1000) + app.add_middleware(GZipMiddleware, minimum_size=1000) while 1: time.sleep(0.5) -- cgit v1.2.3 From 6d09b8d1df3a96e1380bb1650f5961781630af96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 18:33:57 +0300 Subject: produce error when training with medvram/lowvram enabled --- modules/hypernetworks/ui.py | 2 ++ modules/textual_inversion/ui.py | 3 +++ 2 files changed, 5 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index cdddcce1..3541a388 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -25,6 +25,8 @@ def train_hypernetwork(*args): initial_hypernetwork = shared.loaded_hypernetwork + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + try: sd_hijack.undo_optimizations() diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index c57de1f9..70f47343 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -22,6 +22,9 @@ def preprocess(*args): def train_embedding(*args): + + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + try: sd_hijack.undo_optimizations() -- cgit v1.2.3 From d4ea5f4d8631f778d11efcde397e4a5b8801d43b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 19:03:08 +0300 Subject: add an option to unload models during hypernetwork training to save VRAM --- modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++------- modules/hypernetworks/ui.py | 4 +++- modules/shared.py | 4 ++++ modules/textual_inversion/dataset.py | 29 ++++++++++++++++++-------- modules/textual_inversion/textual_inversion.py | 2 +- 5 files changed, 46 insertions(+), 18 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b081f14e..4700e1ec 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training if save_hypernetwork_every > 0: hypernetwork_dir = os.path.join(log_directory, "hypernetworks") @@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, else: images_dir = None - cond_model = shared.sd_model.cond_stage_model - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() @@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, return hypernetwork, filename pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: + for i, (x, text, cond) in pbar: hypernetwork.step = i + ititial_step if hypernetwork.step > steps: @@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - c = cond_model([text]) - + cond = cond.to(devices.device) x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] + loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x + del cond losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, preview_text = text if preview_image_prompt == "" else preview_image_prompt + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=preview_text, @@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, processed = processing.process_images(p) image = processed.images[0] + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + shared.state.current_image = image image.save(last_saved_image) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 3541a388..c67facbb 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -5,7 +5,7 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared +from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork @@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)} raise finally: shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/shared.py b/modules/shared.py index 20b45f23..c1092ff7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), +})) + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 4d006366..f61f40d3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,14 +8,14 @@ from torchvision import transforms import random import tqdm -from modules import devices +from modules import devices, shared import re re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): self.placeholder_token = placeholder_token @@ -32,6 +32,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' + cond_model = shared.sd_model.cond_stage_model + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -53,7 +55,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) - self.dataset.append((init_latent, filename_tokens)) + if include_cond: + text = self.create_text(filename_tokens) + cond = cond_model([text]).to(devices.cpu) + else: + cond = None + + self.dataset.append((init_latent, filename_tokens, cond)) self.length = len(self.dataset) * repeats @@ -64,6 +72,12 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + def create_text(self, filename_tokens): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + return text + def __len__(self): return self.length @@ -72,10 +86,7 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens = self.dataset[index] - - text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + x, filename_tokens, cond = self.dataset[index] - return x, text + text = self.create_text(filename_tokens) + return x, text, cond diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb05cdc6..35f4bd9e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini return embedding, filename pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text) in pbar: + for i, (x, text, _) in pbar: embedding.step = i + ititial_step if embedding.step > steps: -- cgit v1.2.3 From c080f52ceae73b893155eff7de577aaf1a982a2f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:37:58 +0100 Subject: move embedding logic to separate file --- modules/textual_inversion/image_embedding.py | 234 +++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 modules/textual_inversion/image_embedding.py (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py new file mode 100644 index 00000000..6ad39602 --- /dev/null +++ b/modules/textual_inversion/image_embedding.py @@ -0,0 +1,234 @@ +import base64 +import json +import numpy as np +import zlib +from PIL import Image,PngImagePlugin,ImageDraw,ImageFont +from fonts.ttf import Roboto +import torch + +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, obj) + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) + return d + +def embedding_to_b64(data): + d = json.dumps(data,cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + +def embedding_from_b64(data): + d = base64.b64decode(data) + return json.loads(d,cls=EmbeddingDecoder) + +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed%255 + +def xor_block(block): + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) + +def style_block(block,sequence): + im = Image.new('RGB',(block.shape[1],block.shape[0])) + draw = ImageDraw.Draw(im) + i=0 + for x in range(-6,im.size[0],8): + for yi,y in enumerate(range(-6,im.size[1],8)): + offset=0 + if yi%2==0: + offset=4 + shade = sequence[i%len(sequence)] + i+=1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) + + fg = np.array(im).astype(np.uint8) & 0xF0 + + return block ^ fg + +def insert_image_data_embed(image,data): + d = 3 + data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) + data_np_ = np.frombuffer(data_compressed,np.uint8).copy() + data_np_high = data_np_ >> 4 + data_np_low = data_np_ & 0x0F + + h = image.size[1] + next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0]%h)) + next_size = next_size + ((h*d)-(next_size%(h*d))) + + data_np_low.resize(next_size) + data_np_low = data_np_low.reshape((h,-1,d)) + + data_np_high.resize(next_size) + data_np_high = data_np_high.reshape((h,-1,d)) + + edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) + + data_np_low = style_block(data_np_low,sequence=edge_style) + data_np_low = xor_block(data_np_low) + data_np_high = style_block(data_np_high,sequence=edge_style[::-1]) + data_np_high = xor_block(data_np_high) + + im_low = Image.fromarray(data_np_low,mode='RGB') + im_high = Image.fromarray(data_np_high,mode='RGB') + + background = Image.new('RGB',(image.size[0]+im_low.size[0]+im_high.size[0]+2,image.size[1]),(0,0,0)) + background.paste(im_low,(0,0)) + background.paste(image,(im_low.size[0]+1,0)) + background.paste(im_high,(im_low.size[0]+1+image.size[0]+1,0)) + + return background + +def crop_black(img,tol=0): + mask = (img>tol).all(2) + mask0,mask1 = mask.any(0),mask.any(1) + col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() + row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end,col_start:col_end] + +def extract_image_data_embed(image): + d=3 + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + black_cols = np.where( np.sum(outarr, axis=(0,2))==0) + if black_cols[0].shape[0] < 2: + print('No Image data blocks found.') + return None + + data_block_lower = outarr[:,:black_cols[0].min(),:].astype(np.uint8) + data_block_upper = outarr[:,black_cols[0].max()+1:,:].astype(np.uint8) + + data_block_lower = xor_block(data_block_lower) + data_block_upper = xor_block(data_block_upper) + + data_block = (data_block_upper << 4) | (data_block_lower) + data_block = data_block.flatten().tobytes() + + data = zlib.decompress(data_block) + return json.loads(data,cls=EmbeddingDecoder) + +def addCaptionLines(lines,image,initialx,textfont): + draw = ImageDraw.Draw(image) + hstart =initialx + for fill,line in lines: + fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) + _,_,w, h = draw.textbbox((0,0),line,font=font) + fontsize = min( int(fontsize * ((image.size[0]-35)/w) ), 28) + font = ImageFont.truetype(textfont, fontsize) + _,_,w,h = draw.textbbox((0,0),line,font=font) + draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) + hstart += h + return hstart + +def caption_image(image,prelines,postlines,background=(51, 51, 51),font=None): + if font is None: + try: + font = ImageFont.truetype(opts.font or Roboto, fontsize) + font = opts.font or Roboto + except Exception: + font = Roboto + + sample_image = image + background = Image.new("RGBA", (sample_image.size[0],sample_image.size[1]+1024), background) + hoffset = addCaptionLines(prelines,background,5,font)+16 + background.paste(sample_image,(0,hoffset)) + hoffset = hoffset+sample_image.size[1]+8 + hoffset = addCaptionLines(postlines,background,hoffset,font) + background = background.crop((0,0,sample_image.size[0],hoffset+8)) + return background + +def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): + from math import cos + + image = srcimage.copy() + + if textfont is None: + try: + textfont = ImageFont.truetype(opts.font or Roboto, fontsize) + textfont = opts.font or Roboto + except Exception: + textfont = Roboto + + factor = 1.5 + gradient = Image.new('RGBA', (1,image.size[1]), color=(0,0,0,0)) + for y in range(image.size[1]): + mag = 1-cos(y/image.size[1]*factor) + mag = max(mag,1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0,0,0,int(mag*255))) + image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) + + draw = ImageDraw.Draw(image) + fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) + padding = 10 + + _,_,w, h = draw.textbbox((0,0),title,font=font) + fontsize = min( int(fontsize * (((image.size[0]*0.75)-(padding*4))/w) ), 72) + font = ImageFont.truetype(textfont, fontsize) + _,_,w,h = draw.textbbox((0,0),title,font=font) + draw.text((padding,padding), title, anchor='lt', font=font, fill=(255,255,255,230)) + + _,_,w, h = draw.textbbox((0,0),footerLeft,font=font) + fontsize_left = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerMid,font=font) + fontsize_mid = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerRight,font=font) + fontsize_right = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + + font = ImageFont.truetype(textfont, min(fontsize_left,fontsize_mid,fontsize_right)) + + draw.text((padding,image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]/2,image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]-padding,image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255,255,255,230)) + + return image + +if __name__ == '__main__': + + image = Image.new('RGBA',(512,512),(255,255,200,255)) + caption_image(image,[((255,255,255),'line a'),((255,255,255),'line b')], + [((255,255,255),'line c'),((255,255,255),'line d')]) + + image = Image.new('RGBA',(512,512),(255,255,200,255)) + cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') + + test_embed = {'string_to_param':{'*':torch.from_numpy(np.random.random((2, 4096)))}} + + embedded_image = insert_image_data_embed(cap_image, test_embed) + + retrived_embed = extract_image_data_embed(embedded_image) + + assert str(retrived_embed) == str(test_embed) + + embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) + + assert embedded_image == embedded_image2 + + g = lcg() + shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist() + + reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, + 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, + 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, + 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, + 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, + 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, + 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, + 204, 86, 73, 222, 44, 198, 118, 240, 97] + + assert shared_random == reference_random + + hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) + + assert 12731374 == hunna_kay_random_sum \ No newline at end of file -- cgit v1.2.3 From 61788c0538415fa9ca1dd1b306519c116b18bd2c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:50:50 +0100 Subject: shift embedding logic out of textual_inversion --- modules/textual_inversion/textual_inversion.py | 125 ++----------------------- 1 file changed, 6 insertions(+), 119 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8c66aeb5..22b4ae7f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,124 +7,11 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin,ImageDraw -from ..images import captionImageOverlay -import numpy as np -import base64 -import json -import zlib +from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -class EmbeddingEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, torch.Tensor): - return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, obj) - -class EmbeddingDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) - def object_hook(self, d): - if 'TORCHTENSOR' in d: - return torch.from_numpy(np.array(d['TORCHTENSOR'])) - return d - -def embeddingToB64(data): - d = json.dumps(data,cls=EmbeddingEncoder) - return base64.b64encode(d.encode()) - -def embeddingFromB64(data): - d = base64.b64decode(data) - return json.loads(d,cls=EmbeddingDecoder) - -def lcg(m=2**32, a=1664525, c=1013904223, seed=0): - while True: - seed = (a * seed + c) % m - yield seed - -def xorBlock(block): - g = lcg() - randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) - return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) - -def styleBlock(block,sequence): - im = Image.new('RGB',(block.shape[1],block.shape[0])) - draw = ImageDraw.Draw(im) - i=0 - for x in range(-6,im.size[0],8): - for yi,y in enumerate(range(-6,im.size[1],8)): - offset=0 - if yi%2==0: - offset=4 - shade = sequence[i%len(sequence)] - i+=1 - draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) - - fg = np.array(im).astype(np.uint8) & 0xF0 - return block ^ fg - -def insertImageDataEmbed(image,data): - d = 3 - data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) - dnp = np.frombuffer(data_compressed,np.uint8).copy() - dnphigh = dnp >> 4 - dnplow = dnp & 0x0F - - h = image.size[1] - next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h)) - next_size = next_size + ((h*d)-(next_size%(h*d))) - - dnplow.resize(next_size) - dnplow = dnplow.reshape((h,-1,d)) - - dnphigh.resize(next_size) - dnphigh = dnphigh.reshape((h,-1,d)) - - edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] - edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8) - - dnplow = styleBlock(dnplow,sequence=edgeStyleWeights) - dnplow = xorBlock(dnplow) - dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1]) - dnphigh = xorBlock(dnphigh) - - imlow = Image.fromarray(dnplow,mode='RGB') - imhigh = Image.fromarray(dnphigh,mode='RGB') - - background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0)) - background.paste(imlow,(0,0)) - background.paste(image,(imlow.size[0]+1,0)) - background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0)) - - return background - -def crop_black(img,tol=0): - mask = (img>tol).all(2) - mask0,mask1 = mask.any(0),mask.any(1) - col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() - row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() - return img[row_start:row_end,col_start:col_end] - -def extractImageDataEmbed(image): - d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F - blackCols = np.where( np.sum(outarr, axis=(0,2))==0) - if blackCols[0].shape[0] < 2: - print('No Image data blocks found.') - return None - - dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8) - dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8) - - dataBlocklower = xorBlock(dataBlocklower) - dataBlockupper = xorBlock(dataBlockupper) - - dataBlock = (dataBlockupper << 4) | (dataBlocklower) - dataBlock = dataBlock.flatten().tobytes() - data = zlib.decompress(dataBlock) - return json.loads(data,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): @@ -199,10 +86,10 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) if 'sd-ti-embedding' in embed_image.text: - data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) else: - data = extractImageDataEmbed(embed_image) + data = extract_image_data_embed(embed_image) name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embeddingToB64(data)) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() @@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}'.format(embedding.step) - captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = insertImageDataEmbed(captioned_image,data) + captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) + captioned_image = insert_image_data_embed(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From db71290d2659d3b58ff9b57a82e4721a9eab9229 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:55:54 +0100 Subject: remove old caption method --- modules/textual_inversion/image_embedding.py | 39 ++-------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 6ad39602..c67028a5 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -117,37 +117,6 @@ def extract_image_data_embed(image): data = zlib.decompress(data_block) return json.loads(data,cls=EmbeddingDecoder) -def addCaptionLines(lines,image,initialx,textfont): - draw = ImageDraw.Draw(image) - hstart =initialx - for fill,line in lines: - fontsize = 32 - font = ImageFont.truetype(textfont, fontsize) - _,_,w, h = draw.textbbox((0,0),line,font=font) - fontsize = min( int(fontsize * ((image.size[0]-35)/w) ), 28) - font = ImageFont.truetype(textfont, fontsize) - _,_,w,h = draw.textbbox((0,0),line,font=font) - draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) - hstart += h - return hstart - -def caption_image(image,prelines,postlines,background=(51, 51, 51),font=None): - if font is None: - try: - font = ImageFont.truetype(opts.font or Roboto, fontsize) - font = opts.font or Roboto - except Exception: - font = Roboto - - sample_image = image - background = Image.new("RGBA", (sample_image.size[0],sample_image.size[1]+1024), background) - hoffset = addCaptionLines(prelines,background,5,font)+16 - background.paste(sample_image,(0,hoffset)) - hoffset = hoffset+sample_image.size[1]+8 - hoffset = addCaptionLines(postlines,background,hoffset,font) - background = background.crop((0,0,sample_image.size[0],hoffset+8)) - return background - def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): from math import cos @@ -195,11 +164,7 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo return image if __name__ == '__main__': - - image = Image.new('RGBA',(512,512),(255,255,200,255)) - caption_image(image,[((255,255,255),'line a'),((255,255,255),'line b')], - [((255,255,255),'line c'),((255,255,255),'line d')]) - + image = Image.new('RGBA',(512,512),(255,255,200,255)) cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') @@ -231,4 +196,4 @@ if __name__ == '__main__': hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) - assert 12731374 == hunna_kay_random_sum \ No newline at end of file + assert 12731374 == hunna_kay_random_sum -- cgit v1.2.3 From d6fcc6b87bc00fcdecea276fe5b7c7945f7a8b14 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 22:03:05 +0300 Subject: apply lr schedule to hypernets --- modules/hypernetworks/hypernetwork.py | 19 ++++++++--- modules/textual_inversion/learn_schedule.py | 34 ++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 44 +++----------------------- modules/ui.py | 2 +- 4 files changed, 54 insertions(+), 45 deletions(-) create mode 100644 modules/textual_inversion/learn_schedule.py (limited to 'modules/textual_inversion') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5608e799..470659df 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,6 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnSchedule class HypernetworkModule(torch.nn.Module): @@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, for weight in weights: weight.requires_grad = True - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - losses = torch.zeros((32,)) last_saved_file = "Train an embedding; must specify a directory with a set of 1:1 ratio images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) - learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") + learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) -- cgit v1.2.3 From aa75d5cfe8c84768b0f5d16f977ddba298677379 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:06:13 +0100 Subject: correct conflict resolution typo --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 22b4ae7f..789383ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -169,7 +169,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt) +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." -- cgit v1.2.3 From 91d7ee0d097a7ea203d261b570cd2b834837d9e2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:10 +0100 Subject: update imports --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 789383ce..ff0a62b3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,6 +12,9 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, + insert_image_data_embed,extract_image_data_embed, + caption_image_overlay ) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From 5f3317376bb7952bc5145f05f16c1bbd466efc85 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:49 +0100 Subject: spacing --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff0a62b3..485ef46c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,7 +12,7 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, +from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, insert_image_data_embed,extract_image_data_embed, caption_image_overlay ) -- cgit v1.2.3 From 7e6a6e00ad6f3b7ef43c8120db9ecac6e8d6bea5 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:20:46 +0100 Subject: Add files via upload --- modules/textual_inversion/test_embedding.png | Bin 0 -> 489220 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 modules/textual_inversion/test_embedding.png (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png new file mode 100644 index 00000000..07e2d9af Binary files /dev/null and b/modules/textual_inversion/test_embedding.png differ -- cgit v1.2.3 From 66ec505975aaa305a217fc27281ce368cbaef281 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:21:30 +0100 Subject: add file based test --- modules/textual_inversion/image_embedding.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index c67028a5..1224fb42 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -164,6 +164,14 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo return image if __name__ == '__main__': + + testEmbed = Image.open('test_embedding.png') + + data = extract_image_data_embed(testEmbed) + assert data is not None + + data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) + assert data is not None image = Image.new('RGBA',(512,512),(255,255,200,255)) cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') -- cgit v1.2.3 From 6be32b31d181e42c639dad3451229aa7b9cfd1cf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 23:07:09 +0300 Subject: reports that training with medvram is possible. --- modules/hypernetworks/ui.py | 2 +- modules/textual_inversion/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index c67facbb..dfa599af 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -25,7 +25,7 @@ def train_hypernetwork(*args): initial_hypernetwork = shared.loaded_hypernetwork - assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' try: sd_hijack.undo_optimizations() diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 70f47343..36881e7a 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -23,7 +23,7 @@ def preprocess(*args): def train_embedding(*args): - assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' try: sd_hijack.undo_optimizations() -- cgit v1.2.3 From f53f703aebc801c4204182d52bb1e0bef9808e1f Mon Sep 17 00:00:00 2001 From: JC_Array
Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Train an embedding; must specify a directory with a set of 1:1 ratio images
") with gr.Row(): @@ -1327,7 +1337,9 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, ], outputs=[ ti_output, -- cgit v1.2.3 From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001 From: guaneecTrain an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") with gr.Row(): @@ -1368,7 +1380,11 @@ def create_ui(wrap_gradio_gpu_call): process_caption_deepbooru, process_split_threshold, process_overlap_ratio, - process_entropy_focus, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, ], outputs=[ ti_output, -- cgit v1.2.3 From 54f0c1482427a5b3f2248b97be55878e742cbcb1 Mon Sep 17 00:00:00 2001 From: captin411
Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -512,14 +510,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}