From 1618df41bad092e068c61bf510b1e20856821ad5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 10:31:27 +0700
Subject: Gradient clipping for textual embedding
---
modules/textual_inversion/textual_inversion.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff002d3e..7bad73a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -206,7 +206,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -256,6 +256,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if ititial_step > steps:
return embedding, filename
+ clip_grad_mode_value = clip_grad_mode == "value"
+ clip_grad_mode_norm = clip_grad_mode == "norm"
+
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -280,6 +283,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
+
+ if clip_grad_mode_value:
+ torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
+ elif clip_grad_mode_norm:
+ torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
+
optimizer.step()
--
cgit v1.2.3
From 16451ca573220e49f2eaaab97580b6b91287c8c4 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 17:16:23 +0700
Subject: Learning rate sched syntax support for grad clipping
---
modules/hypernetworks/hypernetwork.py | 13 ++++++++++---
modules/textual_inversion/learn_schedule.py | 11 ++++++++---
modules/textual_inversion/textual_inversion.py | 12 +++++++++---
modules/ui.py | 7 +++----
4 files changed, 30 insertions(+), 13 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c5d60654..86532063 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -383,11 +383,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
-
+
clip_grad_mode_value = clip_grad_mode == "value"
clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
@@ -407,6 +411,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(hypernetwork.step)
+
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
@@ -430,9 +437,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value)
+ torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value)
+ torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..ffec3e1b 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -51,14 +51,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number <= self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7bad73a6..6b00c6a1 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
-
+
clip_grad_mode_value = clip_grad_mode == "value"
clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(embedding.step)
+
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
@@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
+ torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
+ torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/ui.py b/modules/ui.py
index 97de7da2..47d16429 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1305,7 +1305,9 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
-
+ with gr.Row():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="1.0", show_label=False)
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1313,9 +1315,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
- with gr.Row():
- clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
- clip_grad_value = gr.Number(value=1.0, show_label=False)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
--
cgit v1.2.3
From cffc240a7327ae60671ff533469fc4ed4bf605de Mon Sep 17 00:00:00 2001
From: Nerogar
Date: Sun, 23 Oct 2022 14:05:25 +0200
Subject: fixed textual inversion training with inpainting models
---
modules/textual_inversion/textual_inversion.py | 27 +++++++++++++++++++++++++-
1 file changed, 26 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0aeb0459..2630c7c9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,6 +224,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+def create_dummy_mask(x, width=None, height=None):
+ if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -286,6 +306,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
+ img_c = None
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
@@ -299,8 +320,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ loss = shared.sd_model(x, cond)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
--
cgit v1.2.3
From bb832d7725187f8a8ab44faa6ee1b38cb5f600aa Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 5 Nov 2022 11:48:38 +0700
Subject: Simplify grad clip
---
modules/hypernetworks/hypernetwork.py | 16 +++++++---------
modules/textual_inversion/textual_inversion.py | 16 +++++++---------
2 files changed, 14 insertions(+), 18 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f4c2668f..02b624e1 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
@@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(hypernetwork.step)
with torch.autocast("cuda"):
@@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
optimizer.step()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c567ec3f..687d97bb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(embedding.step)
with torch.autocast("cuda"):
@@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
optimizer.step()
--
cgit v1.2.3
From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sat, 31 Dec 2022 11:27:02 -0500
Subject: validate textual inversion embeddings
---
modules/sd_models.py | 3 ++
modules/textual_inversion/textual_inversion.py | 43 +++++++++++++++++++++++---
modules/ui.py | 2 --
3 files changed, 41 insertions(+), 7 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ecdd91c5..ebd4dff7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -325,6 +325,9 @@ def load_model(checkpoint_info=None):
script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.")
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model
+
return sd_model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f6112578..103ace60 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -23,6 +23,8 @@ class Embedding:
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
@@ -57,8 +59,10 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
+ self.skipped_embeddings = []
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
+ self.expected_shape = -1
def register_embedding(self, embedding, model):
@@ -75,14 +79,35 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
+ def get_expected_shape(self):
+ expected_shape = -1 # initialize with unknown
+ idx = torch.tensor(0).to(shared.device)
+ if expected_shape == -1:
+ try: # matches sd15 signature
+ first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
+ expected_shape = first_embedding.shape[0]
+ except:
+ pass
+ if expected_shape == -1:
+ try: # matches sd20 signature
+ first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
+ expected_shape = first_embedding.shape[0]
+ except:
+ pass
+ if expected_shape == -1:
+ print('Could not determine expected embeddings shape from model')
+ return expected_shape
+
+ def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
+ if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
+ self.skipped_embeddings = []
+ self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
@@ -122,7 +147,14 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- self.register_embedding(embedding, shared.sd_model)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings.append(name)
+ # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir):
try:
@@ -137,8 +169,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
+ if (len(self.skipped_embeddings) > 0):
+ print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
diff --git a/modules/ui.py b/modules/ui.py
index 57ee0465..397dd804 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1157,8 +1157,6 @@ def create_ui():
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="See wiki for detailed explanation.
")
--
cgit v1.2.3
From bdbe09827b39be63c9c0b3636132ca58da38ebf6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 31 Dec 2022 22:49:09 +0300
Subject: changed embedding accepted shape detection to use existing code and
support the new alt-diffusion model, and reformatted messages a bit #6149
---
modules/textual_inversion/textual_inversion.py | 30 ++++++--------------------
1 file changed, 6 insertions(+), 24 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 103ace60..66f40367 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -80,23 +80,8 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
- expected_shape = -1 # initialize with unknown
- idx = torch.tensor(0).to(shared.device)
- if expected_shape == -1:
- try: # matches sd15 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- try: # matches sd20 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- print('Could not determine expected embeddings shape from model')
- return expected_shape
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
@@ -112,8 +97,6 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = []
-
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@@ -150,11 +133,10 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
- if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings.append(name)
- # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir):
try:
@@ -169,9 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
- print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
- if (len(self.skipped_embeddings) > 0):
- print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
--
cgit v1.2.3
From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 00:38:09 +0300
Subject: fix the issue with training on SD2.0
---
modules/sd_models.py | 2 ++
modules/textual_inversion/textual_inversion.py | 3 +--
2 files changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ebd4dff7..bff8d6c9 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ model.logvar = model.logvar.to(devices.device) # fix for training
+
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 66f40367..1e5722e7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- # dataset loading may take a while, so input validations and early returns should be done before this
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
-
last_saved_file = ""
last_saved_image = ""
forced_filename = ""
--
cgit v1.2.3
From c65909ad16a1962129114c6251de092f49479b06 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Mon, 2 Jan 2023 12:21:22 +1100
Subject: feat(api): return more data for embeddings
---
modules/api/api.py | 17 +++++++++++++++--
modules/api/models.py | 11 +++++++++--
modules/textual_inversion/textual_inversion.py | 8 ++++----
3 files changed, 28 insertions(+), 8 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 30bf3dac..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -330,9 +330,22 @@ class Api:
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
return {
- "loaded": sorted(db.word_embeddings.keys()),
- "skipped": sorted(db.skipped_embeddings),
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
diff --git a/modules/api/models.py b/modules/api/models.py
index a8472dc9..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
class EmbeddingsResponse(BaseModel):
- loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1e5722e7..fd253477 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
- self.skipped_embeddings = []
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
- self.skipped_embeddings = []
+ self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
- self.skipped_embeddings.append(name)
+ self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
--
cgit v1.2.3
From bddebe09edeb6a18f2c06986d5658a7be3a563ea Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Tue, 3 Jan 2023 10:26:37 +0100
Subject: Save Optimizer next to TI embedding
Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
---
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 40 ++++++++++++++++++++------
2 files changed, 33 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..c541d18c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..16176e90 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
--
cgit v1.2.3
From 192ddc04d6de0d780f73aa5fbaa8c66cd4642e1c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 10:34:51 -0500
Subject: add job info to modules
---
modules/extras.py | 17 +++++++++++++----
modules/hypernetworks/hypernetwork.py | 1 +
modules/textual_inversion/preprocess.py | 1 +
modules/textual_inversion/textual_inversion.py | 1 +
4 files changed, 16 insertions(+), 4 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/extras.py b/modules/extras.py
index 7e222313..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -58,6 +58,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +97,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +108,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +119,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +186,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -193,6 +202,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
basename = ''
+ if opts.enable_pnginfo: # append info before save
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +216,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
- if opts.enable_pnginfo:
- image.info = existing_pnginfo
- image.info["extras"] = info
-
if extras_mode != 2 or show_extras_results :
outputs.append(image)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 109e8078..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -417,6 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..2c1251d6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -245,6 +245,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
--
cgit v1.2.3
From 184e670126f5fc50ba56fa0fedcf0cf60e45ed7e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:45:01 +0300
Subject: fix the merge
---
modules/textual_inversion/textual_inversion.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5421a758..8731ea5d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -251,6 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+
def create_dummy_mask(x, width=None, height=None):
if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
@@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
break
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
-
-
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
-
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3
From 525cea924562afd676f55470095268a0f6fca59e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:58:07 +0300
Subject: use shared function from processing for creating dummy mask when
training inpainting model
---
modules/processing.py | 39 +++++++++++++-------------
modules/textual_inversion/textual_inversion.py | 33 ++++++----------------
2 files changed, 29 insertions(+), 43 deletions(-)
(limited to 'modules/textual_inversion/textual_inversion.py')
diff --git a/modules/processing.py b/modules/processing.py
index c03e77e7..c7264aff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -139,26 +157,9 @@ class StableDiffusionProcessing():
self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8731ea5d..2250e41b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def create_dummy_mask(x, width=None, height=None):
- if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
-
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
print("No saved optimizer exists in checkpoint")
-
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = ""
embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3