From 5841990b0df04906da7321beef6f7f7902b7d57b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 05:38:38 +0100 Subject: Update textual_inversion.py --- modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..f6316020 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,6 +7,9 @@ import tqdm import html import datetime +from PIL import Image, PngImagePlugin +import base64 +from io import BytesIO from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -80,7 +83,15 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = torch.load(path, map_location="cpu") + data = [] + + if filename.upper().endswith('.PNG'): + embed_image = Image.open(path) + if 'sd-embedding' in embed_image.text: + embeddingData = base64.b64decode(embed_image.text['sd-embedding']) + data = torch.load(BytesIO(embeddingData), map_location="cpu") + else: + data = torch.load(path, map_location="cpu") # textual inversion embeddings if 'string_to_param' in data: @@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, image = processed.images[0] shared.state.current_image = image - image.save(last_saved_image) + + if save_image_with_stored_embedding: + info = PngImagePlugin.PngInfo() + info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) + image.save(last_saved_image, "PNG", pnginfo=info) + else: + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From 03694e1f9915e34cf7d9a31073f1a1a9def2909f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 21:58:14 +0100 Subject: add embedding load and save from b64 json --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f6316020..1b7f8906 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,9 +7,11 @@ import tqdm import html import datetime -from PIL import Image, PngImagePlugin +from PIL import Image,PngImagePlugin +from ..images import captionImge +import numpy as np import base64 -from io import BytesIO +import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -87,9 +89,9 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) - if 'sd-embedding' in embed_image.text: - embeddingData = base64.b64decode(embed_image.text['sd-embedding']) - data = torch.load(BytesIO(embeddingData), map_location="cpu") + if 'sd-ti-embedding' in embed_image.text: + data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if save_image_with_stored_embedding: info = PngImagePlugin.PngInfo() - info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) - image.save(last_saved_image, "PNG", pnginfo=info) + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embeddingToB64(data)) + + pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + + caption_checkpoint_hash = data.get('sd_checkpoint','UNK') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_stepcount = data.get('step',0) + caption_stepcount = caption_stepcount if caption_stepcount else 0 + + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, + caption_stepcount))] + captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) + captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: image.save(last_saved_image) - - last_saved_image += f", prompt: {text}" shared.state.job_no = embedding.step -- cgit v1.2.3 From 969bd8256e5b4f1007d3cc653723d4ad50a92528 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:02:28 +0100 Subject: add alternate checkpoint hash source --- modules/textual_inversion/textual_inversion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1b7f8906..d7813084 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -265,8 +265,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint','UNK') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_checkpoint_hash = data.get('sd_checkpoint') + if caption_checkpoint_hash is None: + caption_checkpoint_hash = data.get('hash') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 5d12ec82d3e13f5ff4c55db2930e4e10aed7015a Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:05:09 +0100 Subject: add encoder and decoder classes --- modules/textual_inversion/textual_inversion.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d7813084..44d4e08b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -16,6 +16,27 @@ import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, o) + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): + if 'EMBEDDINGTENSOR' in d: + return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + return d + +def embeddingToB64(data): + d = json.dumps(data,cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + +def EmbeddingFromB64(data): + d = base64.b64decode(data) + return json.loads(d,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From d0184b8f76ce492da699f1926f34b57cd095242e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:12 +0100 Subject: change json tensor key name --- modules/textual_inversion/textual_inversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44d4e08b..ae8d207d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -19,15 +19,15 @@ import modules.textual_inversion.dataset class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): - return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} return json.JSONEncoder.default(self, o) class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) def object_hook(self, d): - if 'EMBEDDINGTENSOR' in d: - return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) return d def embeddingToB64(data): -- cgit v1.2.3 From 66846105103cfc282434d0dc2102910160b7a633 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:42 +0100 Subject: correct case on embeddingFromB64 --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ae8d207d..d2b95fa3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -34,7 +34,7 @@ def embeddingToB64(data): d = json.dumps(data,cls=EmbeddingEncoder) return base64.b64encode(d.encode()) -def EmbeddingFromB64(data): +def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) -- cgit v1.2.3 From 96f1e6be59316ec640cab2435fa95b3688194906 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:14:50 +0100 Subject: source checkpoint hash from current checkpoint --- modules/textual_inversion/textual_inversion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d2b95fa3..b16fa84e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -286,10 +286,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint') - if caption_checkpoint_hash is None: - caption_checkpoint_hash = data.get('hash') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + checkpoint = sd_models.select_checkpoint() + caption_checkpoint_hash = checkpoint.hash caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 01fd9cf0d28d8b71a113ab1aa62accfe7f0d9c51 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:17:02 +0100 Subject: change source of step count --- modules/textual_inversion/textual_inversion.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b16fa84e..e4f339b8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -285,15 +285,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, info.add_text("sd-ti-embedding", embeddingToB64(data)) pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - checkpoint = sd_models.select_checkpoint() - caption_checkpoint_hash = checkpoint.hash - - caption_stepcount = data.get('step',0) - caption_stepcount = caption_stepcount if caption_stepcount else 0 - - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, - caption_stepcount))] + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, + embedding.step))] captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: -- cgit v1.2.3 From d6a599ef9ba18a66ae79b50f2945af5788fdda8f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:07:52 +0100 Subject: change caption method --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e4f339b8..21596e78 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -8,7 +8,7 @@ import html import datetime from PIL import Image,PngImagePlugin -from ..images import captionImge +from ..images import captionImageOverlay import numpy as np import base64 import json @@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, else: images_dir = None + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + cond_model = shared.sd_model.cond_stage_model shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.current_image = image - if save_image_with_stored_embedding: + if save_image_with_stored_embedding and os.path.exists(last_saved_file): + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embeddingToB64(data)) - pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, - embedding.step))] - captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) - captioned_image.save(last_saved_image, "PNG", pnginfo=info) - else: - image.save(last_saved_image) + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '[{}]'.format(embedding.step) + + captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + + image.save(last_saved_image) last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From e2c2925eb4d634b186de2c76798162ec56e2f869 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:12:53 +0100 Subject: remove braces from steps --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 21596e78..9a18ee5c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -297,7 +297,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '[{}]'.format(embedding.step) + footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) -- cgit v1.2.3 From 707a431100362645e914042bb344d08439f48ac8 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 15:34:49 +0100 Subject: add pixel data footer --- modules/textual_inversion/textual_inversion.py | 48 ++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7a24192e..6fb64691 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,6 +12,7 @@ from ..images import captionImageOverlay import numpy as np import base64 import json +import zlib from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, o) + return json.JSONEncoder.default(self, obj) class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): @@ -38,6 +39,45 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) +def appendImageDataFooter(image,data): + d = 3 + data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) + dnp = np.frombuffer(data_compressed,np.uint8).copy() + w = image.size[0] + next_size = dnp.shape[0] + (w-(dnp.shape[0]%w)) + next_size = next_size + ((w*d)-(next_size%(w*d))) + dnp.resize(next_size) + dnp = dnp.reshape((-1,w,d)) + print(dnp.shape) + im = Image.fromarray(dnp,mode='RGB') + background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0)) + background.paste(image,(0,0)) + background.paste(im,(0,image.size[1]+1)) + return background + +def crop_black(img,tol=0): + mask = (img>tol).all(2) + mask0,mask1 = mask.any(0),mask.any(1) + col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() + row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end,col_start:col_end] + +def extractImageDataFooter(image): + d=3 + outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) + lastRow = np.where( np.sum(outarr, axis=(1,2))==0) + if lastRow[0].shape[0] == 0: + print('Image data block not found.') + return None + lastRow = lastRow[0] + + lastRow = lastRow.max() + + dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes() + print(lastRow) + data = zlib.decompress(dataBlock) + return json.loads(data,cls=EmbeddingDecoder) + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -113,6 +153,9 @@ class EmbeddingDatabase: if 'sd-ti-embedding' in embed_image.text: data = embeddingFromB64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) + else: + data = extractImageDataFooter(embed_image) + name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) + captioned_image = appendImageDataFooter(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From df6d0d9286279c41c4c67460c3158fa268697524 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 15:43:09 +0100 Subject: convert back to rgb as some hosts add alpha --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6fb64691..667a7cf2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -64,7 +64,7 @@ def crop_black(img,tol=0): def extractImageDataFooter(image): d=3 - outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) lastRow = np.where( np.sum(outarr, axis=(1,2))==0) if lastRow[0].shape[0] == 0: print('Image data block not found.') -- cgit v1.2.3 From 315d5a8ed975c88f670bc484f40a23fbf3a77b63 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:14:44 +0100 Subject: update data dis[play style --- modules/textual_inversion/textual_inversion.py | 88 +++++++++++++++++++------- 1 file changed, 65 insertions(+), 23 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 667a7cf2..95eebea7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -39,20 +39,59 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) -def appendImageDataFooter(image,data): +def xorBlock(block): + return np.bitwise_xor(block.astype(np.uint8), + ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F ) + +def styleBlock(block,sequence): + im = Image.new('RGB',(block.shape[1],block.shape[0])) + draw = ImageDraw.Draw(im) + i=0 + for x in range(-6,im.size[0],8): + for yi,y in enumerate(range(-6,im.size[1],8)): + offset=0 + if yi%2==0: + offset=4 + shade = sequence[i%len(sequence)] + i+=1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) + + fg = np.array(im).astype(np.uint8) & 0xF0 + return block ^ fg + +def insertImageDataEmbed(image,data): d = 3 data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) dnp = np.frombuffer(data_compressed,np.uint8).copy() - w = image.size[0] - next_size = dnp.shape[0] + (w-(dnp.shape[0]%w)) - next_size = next_size + ((w*d)-(next_size%(w*d))) - dnp.resize(next_size) - dnp = dnp.reshape((-1,w,d)) - print(dnp.shape) - im = Image.fromarray(dnp,mode='RGB') - background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0)) - background.paste(image,(0,0)) - background.paste(im,(0,image.size[1]+1)) + dnphigh = dnp >> 4 + dnplow = dnp & 0x0F + + h = image.size[1] + next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h)) + next_size = next_size + ((h*d)-(next_size%(h*d))) + + dnplow.resize(next_size) + dnplow = dnplow.reshape((h,-1,d)) + + dnphigh.resize(next_size) + dnphigh = dnphigh.reshape((h,-1,d)) + + edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8) + + dnplow = styleBlock(dnplow,sequence=edgeStyleWeights) + dnplow = xorBlock(dnplow) + dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1]) + dnphigh = xorBlock(dnphigh) + + imlow = Image.fromarray(dnplow,mode='RGB') + imhigh = Image.fromarray(dnphigh,mode='RGB') + + background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0)) + background.paste(imlow,(0,0)) + background.paste(image,(imlow.size[0]+1,0)) + background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0)) + return background def crop_black(img,tol=0): @@ -62,19 +101,22 @@ def crop_black(img,tol=0): row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() return img[row_start:row_end,col_start:col_end] -def extractImageDataFooter(image): +def extractImageDataEmbed(image): d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) - lastRow = np.where( np.sum(outarr, axis=(1,2))==0) - if lastRow[0].shape[0] == 0: - print('Image data block not found.') + outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + blackCols = np.where( np.sum(outarr, axis=(0,2))==0) + if blackCols[0].shape[0] < 2: + print('No Image data blocks found.') return None - lastRow = lastRow[0] - - lastRow = lastRow.max() - dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes() - print(lastRow) + dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8) + dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8) + + dataBlocklower = xorBlock(dataBlocklower) + dataBlockupper = xorBlock(dataBlockupper) + + dataBlock = (dataBlockupper << 4) | (dataBlocklower) + dataBlock = dataBlock.flatten().tobytes() data = zlib.decompress(dataBlock) return json.loads(data,cls=EmbeddingDecoder) @@ -154,7 +196,7 @@ class EmbeddingDatabase: data = embeddingFromB64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) else: - data = extractImageDataFooter(embed_image) + data = extractImageDataEmbed(embed_image) name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -351,7 +393,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = appendImageDataFooter(captioned_image,data) + captioned_image = insertImageDataEmbed(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From 767202a4c324f9b49f63ab4dabbb5736fe9df6e5 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:20:52 +0100 Subject: add dependency --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 95eebea7..f3cacaa0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,7 +7,7 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin +from PIL import Image,PngImagePlugin,ImageDraw from ..images import captionImageOverlay import numpy as np import base64 -- cgit v1.2.3 From e0fbe6d27e7b4505766c8cb5a4264e1114cf3721 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:26:24 +0100 Subject: colour depth conversion fix --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f3cacaa0..ae807268 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -103,7 +103,7 @@ def crop_black(img,tol=0): def extractImageDataEmbed(image): d=3 - outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F blackCols = np.where( np.sum(outarr, axis=(0,2))==0) if blackCols[0].shape[0] < 2: print('No Image data blocks found.') -- cgit v1.2.3 From 7aa8fcac1e45c3ad9c6a40df0e44a346afcd5032 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 04:17:36 +0100 Subject: use simple lcg in xor --- modules/textual_inversion/textual_inversion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ae807268..13416a08 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -39,9 +39,15 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed + def xorBlock(block): - return np.bitwise_xor(block.astype(np.uint8), - ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F ) + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) def styleBlock(block,sequence): im = Image.new('RGB',(block.shape[1],block.shape[0])) -- cgit v1.2.3 From 61788c0538415fa9ca1dd1b306519c116b18bd2c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:50:50 +0100 Subject: shift embedding logic out of textual_inversion --- modules/textual_inversion/textual_inversion.py | 125 ++----------------------- 1 file changed, 6 insertions(+), 119 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8c66aeb5..22b4ae7f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,124 +7,11 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin,ImageDraw -from ..images import captionImageOverlay -import numpy as np -import base64 -import json -import zlib +from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -class EmbeddingEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, torch.Tensor): - return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, obj) - -class EmbeddingDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) - def object_hook(self, d): - if 'TORCHTENSOR' in d: - return torch.from_numpy(np.array(d['TORCHTENSOR'])) - return d - -def embeddingToB64(data): - d = json.dumps(data,cls=EmbeddingEncoder) - return base64.b64encode(d.encode()) - -def embeddingFromB64(data): - d = base64.b64decode(data) - return json.loads(d,cls=EmbeddingDecoder) - -def lcg(m=2**32, a=1664525, c=1013904223, seed=0): - while True: - seed = (a * seed + c) % m - yield seed - -def xorBlock(block): - g = lcg() - randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) - return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) - -def styleBlock(block,sequence): - im = Image.new('RGB',(block.shape[1],block.shape[0])) - draw = ImageDraw.Draw(im) - i=0 - for x in range(-6,im.size[0],8): - for yi,y in enumerate(range(-6,im.size[1],8)): - offset=0 - if yi%2==0: - offset=4 - shade = sequence[i%len(sequence)] - i+=1 - draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) - - fg = np.array(im).astype(np.uint8) & 0xF0 - return block ^ fg - -def insertImageDataEmbed(image,data): - d = 3 - data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) - dnp = np.frombuffer(data_compressed,np.uint8).copy() - dnphigh = dnp >> 4 - dnplow = dnp & 0x0F - - h = image.size[1] - next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h)) - next_size = next_size + ((h*d)-(next_size%(h*d))) - - dnplow.resize(next_size) - dnplow = dnplow.reshape((h,-1,d)) - - dnphigh.resize(next_size) - dnphigh = dnphigh.reshape((h,-1,d)) - - edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] - edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8) - - dnplow = styleBlock(dnplow,sequence=edgeStyleWeights) - dnplow = xorBlock(dnplow) - dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1]) - dnphigh = xorBlock(dnphigh) - - imlow = Image.fromarray(dnplow,mode='RGB') - imhigh = Image.fromarray(dnphigh,mode='RGB') - - background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0)) - background.paste(imlow,(0,0)) - background.paste(image,(imlow.size[0]+1,0)) - background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0)) - - return background - -def crop_black(img,tol=0): - mask = (img>tol).all(2) - mask0,mask1 = mask.any(0),mask.any(1) - col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() - row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() - return img[row_start:row_end,col_start:col_end] - -def extractImageDataEmbed(image): - d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F - blackCols = np.where( np.sum(outarr, axis=(0,2))==0) - if blackCols[0].shape[0] < 2: - print('No Image data blocks found.') - return None - - dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8) - dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8) - - dataBlocklower = xorBlock(dataBlocklower) - dataBlockupper = xorBlock(dataBlockupper) - - dataBlock = (dataBlockupper << 4) | (dataBlocklower) - dataBlock = dataBlock.flatten().tobytes() - data = zlib.decompress(dataBlock) - return json.loads(data,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): @@ -199,10 +86,10 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) if 'sd-ti-embedding' in embed_image.text: - data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) else: - data = extractImageDataEmbed(embed_image) + data = extract_image_data_embed(embed_image) name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embeddingToB64(data)) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() @@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}'.format(embedding.step) - captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = insertImageDataEmbed(captioned_image,data) + captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) + captioned_image = insert_image_data_embed(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From aa75d5cfe8c84768b0f5d16f977ddba298677379 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:06:13 +0100 Subject: correct conflict resolution typo --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 22b4ae7f..789383ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -169,7 +169,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt) +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." -- cgit v1.2.3 From 91d7ee0d097a7ea203d261b570cd2b834837d9e2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:10 +0100 Subject: update imports --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 789383ce..ff0a62b3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,6 +12,9 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, + insert_image_data_embed,extract_image_data_embed, + caption_image_overlay ) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From 5f3317376bb7952bc5145f05f16c1bbd466efc85 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:49 +0100 Subject: spacing --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff0a62b3..485ef46c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,7 +12,7 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, +from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, insert_image_data_embed,extract_image_data_embed, caption_image_overlay ) -- cgit v1.2.3 From 10a2de644f8ea4cfade88e85d768da3480f4c9f0 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 12 Oct 2022 13:15:35 +0100 Subject: formatting --- modules/textual_inversion/textual_inversion.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 485ef46c..b072d745 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,14 +7,14 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin +from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, - insert_image_data_embed,extract_image_data_embed, - caption_image_overlay ) +from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, + insert_image_data_embed, extract_image_data_embed, + caption_image_overlay) class Embedding: def __init__(self, vec, name, step=None): @@ -90,10 +90,10 @@ class EmbeddingDatabase: embed_image = Image.open(path) if 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name',name) + name = data.get('name', name) else: data = extract_image_data_embed(embed_image) - name = data.get('name',name) + name = data.get('name', name) else: data = torch.load(path, map_location="cpu") @@ -278,24 +278,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image if save_image_with_stored_embedding and os.path.exists(last_saved_file): - + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name','???')) + title = "<{}>".format(data.get('name', '???')) checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}'.format(embedding.step) - captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = insert_image_data_embed(captioned_image,data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - + image.save(last_saved_image) last_saved_image += f", prompt: {preview_text}" -- cgit v1.2.3 From c3c8eef9fd5a0c8b26319e32ca4a19b56204e6df Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 20:49:47 +0300 Subject: train: change filename processing to be more simple and configurable train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options --- javascript/hints.js | 3 ++ modules/hypernetworks/hypernetwork.py | 40 +++++++++------------- modules/shared.py | 3 ++ modules/textual_inversion/dataset.py | 47 +++++++++++++++++++------- modules/textual_inversion/learn_schedule.py | 37 +++++++++++++++++++- modules/textual_inversion/textual_inversion.py | 35 +++++++------------ modules/ui.py | 2 -- 7 files changed, 105 insertions(+), 62 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/javascript/hints.js b/javascript/hints.js index b81c181b..d51ee14c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,6 +81,9 @@ titles = { "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", + + "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", + "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", } diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8314450a..b6c06d49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,7 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): @@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text, cond) in pbar: + for i, entry in pbar: hypernetwork.step = i + ititial_step - if hypernetwork.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except Exception: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - cond = cond.to(devices.device) - x = x.to(devices.device) + cond = entry.cond.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x del cond @@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ) processed = processing.process_images(p) - image = processed.images[0] + image = processed.images[0] if len(processed.images)>0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}