| author | DepFA <35278260+dfaker@users.noreply.github.com> | 2022-10-10 14:34:49 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-10 14:34:49 +0000 |
| commit | 707a431100362645e914042bb344d08439f48ac8 (patch) | |
| tree | c0840a792c28a005dadeef07e6f71e43e73c5872 /modules/textual_inversion/textual_inversion.py | |
| parent | ce2d7f7eaccbd1843835ca2d048d78ba5cb1ea13 (diff) | |
add pixel data footer
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 48 |
1 file changed, 46 insertions(+), 2 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7a24192e..6fb64691 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@ from ..images import captionImageOverlay
import numpy as np
import base64
import json
+import zlib
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
-        return json.JSONEncoder.default(self, o)
+        return json.JSONEncoder.default(self, obj)

class EmbeddingDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
@@ -38,6 +39,45 @@ def embeddingFromB64(data):
    d = base64.b64decode(data)
    return json.loads(d,cls=EmbeddingDecoder)
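For reference, a minimal round-trip sketch of the encoder/decoder pair patched above (a sketch, assuming `EmbeddingEncoder` and `EmbeddingDecoder` are imported from this module; the payload values are hypothetical):

```python
import base64
import json
import torch

# Tensors serialize as {'TORCHTENSOR': nested_list}; the decoder's object_hook
# revives them (dtype may widen to float64 through the list round-trip).
payload = {'name': 'demo', 'string_to_param': {'*': torch.rand(2, 768)}}
b64 = base64.b64encode(json.dumps(payload, cls=EmbeddingEncoder).encode())
restored = json.loads(base64.b64decode(b64), cls=EmbeddingDecoder)
print(restored['string_to_param']['*'].shape)  # torch.Size([2, 768])
```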
+def appendImageDataFooter(image,data):
+    d = 3
+    data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
+    dnp = np.frombuffer(data_compressed,np.uint8).copy()
+    w = image.size[0]
+    next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
+    next_size = next_size + ((w*d)-(next_size%(w*d)))
+    dnp.resize(next_size)
+    dnp = dnp.reshape((-1,w,d))
+    print(dnp.shape)
+    im = Image.fromarray(dnp,mode='RGB')
+    background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
+    background.paste(image,(0,0))
+    background.paste(im,(0,image.size[1]+1))
+    return background
+
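The padding arithmetic above rounds the compressed payload up to whole image rows before it becomes pixels; an isolated sketch of just that step (width and payload are hypothetical):

```python
import json
import zlib
import numpy as np

w, d = 512, 3                                              # image width, RGB channels
blob = zlib.compress(json.dumps({'name': 'demo'}).encode(), 9)
dnp = np.frombuffer(blob, np.uint8).copy()
next_size = dnp.shape[0] + (w - (dnp.shape[0] % w))        # round up to a multiple of w
next_size = next_size + ((w * d) - (next_size % (w * d)))  # then to a multiple of w*d
dnp.resize(next_size)                                      # zero-pads the tail in place
print(dnp.reshape((-1, w, d)).shape)                       # (rows, 512, 3): whole RGB rows
```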
+def crop_black(img,tol=0):
+    mask = (img>tol).all(2)
+    mask0,mask1 = mask.any(0),mask.any(1)
+    col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
+    row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
+    return img[row_start:row_end,col_start:col_end]
+
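`crop_black` trims all-black rows and columns from the array's edges, which is how the extractor below strips the zero padding around the footer; a quick behavioral sketch on synthetic data:

```python
import numpy as np

img = np.zeros((5, 5, 3), np.uint8)
img[1:4, 2:4] = 255           # bright block surrounded by a black border
print(crop_black(img).shape)  # (3, 2, 3): black edges removed on all four sides
```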
+def extractImageDataFooter(image):
+    d = 3
+    outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
+    lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
+    if lastRow[0].shape[0] == 0:
+        print('Image data block not found.')
+        return None
+    lastRow = lastRow[0]
+
+    lastRow = lastRow.max()
+
+    dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
+    print(lastRow)
+    data = zlib.decompress(dataBlock)
+    return json.loads(data,cls=EmbeddingDecoder)
+
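Putting both halves together, a hedged end-to-end sketch (assumes the two functions above are importable from this module; the footer lives in the pixel data itself, so it only survives lossless formats such as PNG):

```python
import io
from PIL import Image

cover = Image.new('RGB', (512, 512), (32, 64, 96))   # any non-black base image
stamped = appendImageDataFooter(cover, {'name': 'demo'})

buf = io.BytesIO()
stamped.save(buf, 'PNG')                             # lossless, so footer bytes survive
recovered = extractImageDataFooter(Image.open(buf))
print(recovered['name'])                             # 'demo'
```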
class Embedding:
    def __init__(self, vec, name, step=None):
        self.vec = vec
@@ -113,6 +153,9 @@ class EmbeddingDatabase:
                if 'sd-ti-embedding' in embed_image.text:
                    data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
                    name = data.get('name',name)
+                else:
+                    data = extractImageDataFooter(embed_image)
+                    name = data.get('name',name)
            else:
                data = torch.load(path, map_location="cpu")
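Note that `extractImageDataFooter` returns `None` when no all-zero separator row is found, in which case the `data.get('name',name)` call above would raise; a guarded variant of this loading path could look like the following sketch (not the committed code):

```python
if 'sd-ti-embedding' in embed_image.text:
    data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
else:
    data = extractImageDataFooter(embed_image)  # None if no footer is present
if data is not None:
    name = data.get('name', name)
```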
@@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
    return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
    assert embedding_name, 'embedding not selected'

    shared.state.textinfo = "Initializing textual inversion training..."
@@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
            footer_right = '{}'.format(embedding.step)
            captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+            captioned_image = appendImageDataFooter(captioned_image,data)
            captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
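Caller-side, the new flag sits at the end of the signature; a hypothetical keyword-argument invocation (all values illustrative, not taken from the repo's UI wiring):

```python
train_embedding(
    'my-embedding',
    learn_rate=5e-3,
    data_root='/path/to/processed/images',
    log_directory='textual_inversion',
    training_width=512,
    training_height=512,
    steps=10000,
    num_repeats=1,
    create_image_every=500,
    save_embedding_every=500,
    template_file='style_filewords.txt',
    save_image_with_stored_embedding=True,  # new parameter introduced by this commit
)
```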