From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:18:26 -0700 Subject: Add auto focal point cropping to Preprocess images This algorithm plots a bunch of points of interest on the source image and averages their locations to find a center. Most points come from OpenCV. One point comes from an entropy model. OpenCV points account for 50% of the weight and the entropy based point is the other 50%. The center of all weighted points is calculated and a bounding box is drawn as close to centered over that point as possible. --- modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..168bfb09 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,7 @@ import os -from PIL import Image, ImageOps +import cv2 +import numpy as np +from PIL import Image, ImageOps, ImageDraw import platform import sys import tqdm @@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus) finally: @@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro is_tall = ratio > 1.35 is_wide = ratio < 1 / 1.35 + processing_option_ran = False + if process_split and is_tall: img = img.resize((width, height * img.height // img.width)) @@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) + + processing_option_ran = True elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) @@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) - else: + + processing_option_ran = True + + if process_entropy_focus and (is_tall or is_wide): + if is_tall: + img = img.resize((width, height * img.height // img.width)) + else: + img = img.resize((width * img.width // img.height, height)) + + x_focal_center, y_focal_center = image_central_focal_point(img, width, height) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(height / 2) + x_half = int(width / 2) + + x1 = x_focal_center - x_half + if x1 < 0: + x1 = 0 + elif x1 + width > img.width: + x1 = img.width - width + + y1 = y_focal_center - y_half + if y1 < 0: + y1 = 0 + elif y1 + height > img.height: + y1 = img.height - height + + x2 = x1 + width + y2 = y1 + height + + crop = [x1, y1, x2, y2] + + focal = img.crop(tuple(crop)) + save_pic(focal, index) + + processing_option_ran = True + + if not processing_option_ran: img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() + + +def image_central_focal_point(im, target_width, target_height): + focal_points = [] + + focal_points.extend( + image_focal_points(im) + ) + + fp_entropy = image_entropy_point(im, target_width, target_height) + fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy + + focal_points.append(fp_entropy) + + weight = 0.0 + x = 0.0 + y = 0.0 + for focal_point in focal_points: + weight += focal_point['weight'] + x += focal_point['x'] * focal_point['weight'] + y += focal_point['y'] * focal_point['weight'] + avg_x = round(x // weight) + avg_y = round(y // weight) + + return avg_x, avg_y + + +def image_focal_points(im): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=50, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.05, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append({ + 'x': x, + 'y': y, + 'weight': 1.0 + }) + + return focal_points + + +def image_entropy_point(im, crop_width, crop_height): + img = im.copy() + # just make it easier to slide the test crop with images oriented the same way + if (img.size[0] < img.size[1]): + portrait = True + img = img.rotate(90, expand=1) + + e_max = 0 + crop_current = [0, 0, crop_width, crop_height] + crop_best = crop_current + while crop_current[2] < img.size[0]: + crop = img.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e_max < e): + e_max = e + crop_best = list(crop_current) + + crop_current[0] += 4 + crop_current[2] += 4 + + x_mid = int((crop_best[2] - crop_best[0])/2) + y_mid = int((crop_best[3] - crop_best[1])/2) + + return { + 'x': x_mid, + 'y': y_mid, + 'weight': 1.0 + } + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("L")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + -- cgit v1.2.3 From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 13:44:59 -0700 Subject: fix entropy point calculation --- modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 168bfb09..7c1a594e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -196,9 +196,9 @@ def image_focal_points(im): points = cv2.goodFeaturesToTrack( np_im, - maxCorners=50, + maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.05, + minDistance=min(grayscale.width, grayscale.height)*0.07, useHarrisDetector=False, ) @@ -218,28 +218,32 @@ def image_focal_points(im): def image_entropy_point(im, crop_width, crop_height): - img = im.copy() - # just make it easier to slide the test crop with images oriented the same way - if (img.size[0] < img.size[1]): - portrait = True - img = img.rotate(90, expand=1) + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] e_max = 0 crop_current = [0, 0, crop_width, crop_height] crop_best = crop_current - while crop_current[2] < img.size[0]: - crop = img.crop(tuple(crop_current)) + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) e = image_entropy(crop) - if (e_max < e): + if (e > e_max): e_max = e crop_best = list(crop_current) - crop_current[0] += 4 - crop_current[2] += 4 + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + crop_width/2) + y_mid = int(crop_best[1] + crop_height/2) - x_mid = int((crop_best[2] - crop_best[0])/2) - y_mid = int((crop_best[3] - crop_best[1])/2) return { 'x': x_mid, @@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1")) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -- cgit v1.2.3 From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 17:19:02 -0700 Subject: face detection algo, configurability, reusability Try to move the crop in the direction of a face if it is present More internal configuration options for choosing weights of each of the algorithm's findings Move logic into its module --- modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++ modules/textual_inversion/preprocess.py | 150 +++------------------- 2 files changed, 230 insertions(+), 136 deletions(-) create mode 100644 modules/textual_inversion/autocrop.py (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py new file mode 100644 index 00000000..f858a958 --- /dev/null +++ b/modules/textual_inversion/autocrop.py @@ -0,0 +1,216 @@ +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + if im.height > im.width: + im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) + else: + im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + + average_point = poi_average(pois, settings, im=im) + + if settings.annotate_image: + d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') + + minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side + faces = classifier.detectMultiScale(gray, scaleFactor=1.05, + minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + + if len(faces) == 0: + return [] + + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + if settings.annotate_image: + for f in rects: + d = ImageDraw.Draw(im) + d.rectangle(f, outline=RED) + + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid)] + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("1")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings, im=None): + weight = 0.0 + x = 0.0 + y = 0.0 + for pois in pois: + if settings.annotate_image and im is not None: + w = 4 * 0.5 * sqrt(pois.weight) + d = ImageDraw.Draw(im) + d.ellipse([ + pois.x - w, pois.y - w, + pois.x + w, pois.y + w ], fill=BLUE) + weight += pois.weight + x += pois.x * pois.weight + y += pois.y * pois.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0): + self.x = x + self.y = y + self.weight = weight + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False \ No newline at end of file diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 7c1a594e..0c79f012 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,7 +1,5 @@ import os -import cv2 -import numpy as np -from PIL import Image, ImageOps, ImageDraw +from PIL import Image, ImageOps import platform import sys import tqdm @@ -9,6 +7,7 @@ import time from modules import shared, images from modules.shared import opts, cmd_opts +from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru @@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro processing_option_ran = True - if process_entropy_focus and (is_tall or is_wide): - if is_tall: - img = img.resize((width, height * img.height // img.width)) - else: - img = img.resize((width * img.width // img.height, height)) - - x_focal_center, y_focal_center = image_central_focal_point(img, width, height) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(height / 2) - x_half = int(width / 2) - - x1 = x_focal_center - x_half - if x1 < 0: - x1 = 0 - elif x1 + width > img.width: - x1 = img.width - width - - y1 = y_focal_center - y_half - if y1 < 0: - y1 = 0 - elif y1 + height > img.height: - y1 = img.height - height - - x2 = x1 + width - y2 = y1 + height - - crop = [x1, y1, x2, y2] - - focal = img.crop(tuple(crop)) + if process_entropy_focus and img.height != img.width: + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = 0.9, + entropy_points_weight = 0.7, + corner_points_weight = 0.5, + annotate_image = False + ) + focal = autocrop.crop_image(img, autocrop_settings) save_pic(focal, index) processing_option_ran = True @@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = images.resize_image(1, img, width, height) save_pic(img, index) - shared.state.nextjob() - - -def image_central_focal_point(im, target_width, target_height): - focal_points = [] - - focal_points.extend( - image_focal_points(im) - ) - - fp_entropy = image_entropy_point(im, target_width, target_height) - fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy - - focal_points.append(fp_entropy) - - weight = 0.0 - x = 0.0 - y = 0.0 - for focal_point in focal_points: - weight += focal_point['weight'] - x += focal_point['x'] * focal_point['weight'] - y += focal_point['y'] * focal_point['weight'] - avg_x = round(x // weight) - avg_y = round(y // weight) - - return avg_x, avg_y - - -def image_focal_points(im): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append({ - 'x': x, - 'y': y, - 'weight': 1.0 - }) - - return focal_points - - -def image_entropy_point(im, crop_width, crop_height): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - - e_max = 0 - crop_current = [0, 0, crop_width, crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + crop_width/2) - y_mid = int(crop_best[1] + crop_height/2) - - - return { - 'x': x_mid, - 'y': y_mid, - 'weight': 1.0 - } - - -def image_entropy(im): - # greyscale image entropy - band = np.asarray(im.convert("1")) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - + shared.state.nextjob() \ No newline at end of file -- cgit v1.2.3 From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001 From: captin411 Date: Thu, 20 Oct 2022 00:34:55 -0700 Subject: improve face detection a lot --- modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 37 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index f858a958..5a551c25 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -8,12 +8,18 @@ GREEN = "#0F0" BLUE = "#00F" RED = "#F00" + def crop_image(im, settings): """ Intelligently crop an image to the subject matter """ if im.height > im.width: im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - else: + elif im.width > im.height: im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + else: + im = im.resize((settings.crop_width, settings.crop_height)) + + if im.height == im.width: + return im focus = focal_point(im, settings) @@ -78,13 +84,18 @@ def focal_point(im, settings): [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] ) - if settings.annotate_image: - d = ImageDraw.Draw(im) - - average_point = poi_average(pois, settings, im=im) + average_point = poi_average(pois, settings) if settings.annotate_image: - d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) return average_point @@ -92,22 +103,32 @@ def focal_point(im, settings): def image_face_points(im, settings): np_im = np.array(im) gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') - - minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side - faces = classifier.detectMultiScale(gray, scaleFactor=1.05, - minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - if len(faces) == 0: - return [] - - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - if settings.annotate_image: - for f in rects: - d = ImageDraw.Draw(im) - d.rectangle(f, outline=RED) - - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] def image_corner_points(im, settings): @@ -132,8 +153,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y)) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) return focal_points @@ -167,31 +188,26 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid)] + return [PointOfInterest(x_mid, y_mid, size=25)] def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("1")) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -def poi_average(pois, settings, im=None): +def poi_average(pois, settings): weight = 0.0 x = 0.0 y = 0.0 - for pois in pois: - if settings.annotate_image and im is not None: - w = 4 * 0.5 * sqrt(pois.weight) - d = ImageDraw.Draw(im) - d.ellipse([ - pois.x - w, pois.y - w, - pois.x + w, pois.y + w ], fill=BLUE) - weight += pois.weight - x += pois.x * pois.weight - y += pois.y * pois.weight + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight avg_x = round(x / weight) avg_y = round(y / weight) @@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None): class PointOfInterest: - def __init__(self, x, y, weight=1.0): + def __init__(self, x, y, weight=1.0, size=10): self.x = x self.y = y self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] class Settings: -- cgit v1.2.3 From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001 From: captin411 Date: Sun, 23 Oct 2022 04:11:07 -0700 Subject: auto cropping now works with non square crops --- modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++---------------- 1 file changed, 269 insertions(+), 240 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 5a551c25..b2f9241c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,241 +1,270 @@ -import cv2 -from collections import defaultdict -from math import log, sqrt -import numpy as np -from PIL import Image, ImageDraw - -GREEN = "#0F0" -BLUE = "#00F" -RED = "#F00" - - -def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - if im.height > im.width: - im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - elif im.width > im.height: - im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) - else: - im = im.resize((settings.crop_width, settings.crop_height)) - - if im.height == im.width: - return im - - focus = focal_point(im, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - - return im.crop(tuple(crop)) - -def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) - - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things - pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) - - average_point = poi_average(pois, settings) - - if settings.annotate_image: - d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) - - return average_point - - -def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] - return [] - - -def image_corner_points(im, settings): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) - - return focal_points - - -def image_entropy_points(im, settings): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - else: - return [] - - e_max = 0 - crop_current = [0, 0, settings.crop_width, settings.crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) - - return [PointOfInterest(x_mid, y_mid, size=25)] - - -def image_entropy(im): - # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - - -def poi_average(pois, settings): - weight = 0.0 - x = 0.0 - y = 0.0 - for poi in pois: - weight += poi.weight - x += poi.x * poi.weight - y += poi.y * poi.weight - avg_x = round(x / weight) - avg_y = round(y / weight) - - return PointOfInterest(avg_x, avg_y) - - -class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size - - def bounding(self, size): - return [ - self.x - size//2, - self.y - size//2, - self.x + size//2, - self.y + size//2 - ] - - -class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight - self.annotate_image = annotate_image +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + + if im.width == settings.crop_width and im.height == settings.crop_height: + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = [0, 0, im.width, im.height] + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + return im + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid, size=25)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w == h + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image self.destop_view_image = False \ No newline at end of file -- cgit v1.2.3 From 3e6c2420c1177e9e79f2b566a5a7795b7416e34a Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 13:10:58 -0700 Subject: improve debug markers, fix algo weighting --- modules/textual_inversion/autocrop.py | 207 +++++++++++++++++++++------------- 1 file changed, 129 insertions(+), 78 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index b2f9241c..caaf18c8 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,4 +1,5 @@ import cv2 +import os from collections import defaultdict from math import log, sqrt import numpy as np @@ -26,19 +27,9 @@ def crop_image(im, settings): scale_by = settings.crop_height / im.height im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + im_debug = im.copy() - if im.width == settings.crop_width and im.height == settings.crop_height: - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = [0, 0, im.width, im.height] - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - return im - - focus = focal_point(im, settings) + focus = focal_point(im_debug, settings) # take the focal point and turn it into crop coordinates that try to center over the focal # point but then get adjusted back into the frame @@ -62,89 +53,143 @@ def crop_image(im, settings): crop = [x1, y1, x2, y2] + results = [] + + results.append(im.crop(tuple(crop))) + if settings.annotate_image: - d = ImageDraw.Draw(im) + d = ImageDraw.Draw(im_debug) rect = list(crop) rect[2] -= 1 rect[3] -= 1 d.rectangle(rect, outline=GREEN) + results.append(im_debug) if settings.destop_view_image: - im.show() + im_debug.show() - return im.crop(tuple(crop)) + return results def focal_point(im, settings): corner_points = image_corner_points(im, settings) entropy_points = image_entropy_points(im, settings) face_points = image_face_points(im, settings) - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) + weight_pref_total = 0 + if len(corner_points) > 0: + weight_pref_total += settings.corner_points_weight + if len(entropy_points) > 0: + weight_pref_total += settings.entropy_points_weight + if len(face_points) > 0: + weight_pref_total += settings.face_points_weight + + corner_centroid = None + if len(corner_points) > 0: + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) + + entropy_centroid = None + if len(entropy_points) > 0: + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) + + face_centroid = None + if len(face_points) > 0: + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) average_point = poi_average(pois, settings) if settings.annotate_image: d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) return average_point def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + if settings.dnn_model_path is not None: + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.8, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0)), # face focus up/down is close to the top of the head + size = w, + weight = 1/len(faces[1]) + ) + ) + return results + else: + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] @@ -161,7 +206,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, + minDistance=min(grayscale.width, grayscale.height)*0.03, useHarrisDetector=False, ) @@ -171,7 +216,7 @@ def image_corner_points(im, settings): focal_points = [] for point in points: x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) + focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) return focal_points @@ -205,17 +250,22 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid, size=25)] + return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] def image_entropy(im): # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) + band = np.asarray(im.convert("L")) + # band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() +def centroid(pois): + x = [poi.x for poi in pois] + y = [poi.y for poi in pois] + return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) + def poi_average(pois, settings): weight = 0.0 @@ -260,11 +310,12 @@ class PointOfInterest: class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): self.crop_width = crop_width self.crop_height = crop_height self.corner_points_weight = corner_points_weight self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight + self.face_points_weight = face_points_weight self.annotate_image = annotate_image - self.destop_view_image = False \ No newline at end of file + self.destop_view_image = False + self.dnn_model_path = dnn_model_path \ No newline at end of file -- cgit v1.2.3 From db8ed5fe5cd6e967d12d43d96b7f83083e58626c Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 15:22:29 -0700 Subject: Focal crop UI elements --- modules/textual_inversion/preprocess.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a8c17c6f..1e4d4de8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -13,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() @@ -23,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_entropy_focus) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug) finally: @@ -35,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce -def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -139,27 +139,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ratio = (img.height * width) / (img.width * height) inverse_xy = True - processing_option_ran = False + process_default_resize = True if process_split and ratio < 1.0 and ratio <= split_threshold: for splitted in split_pic(img, inverse_xy): save_pic(splitted, index, existing_caption=existing_caption) - processing_option_ran = True + process_default_resize = False if process_entropy_focus and img.height != img.width: autocrop_settings = autocrop.Settings( crop_width = width, crop_height = height, - face_points_weight = 0.9, - entropy_points_weight = 0.7, - corner_points_weight = 0.5, - annotate_image = False + face_points_weight = process_focal_crop_face_weight, + entropy_points_weight = process_focal_crop_entropy_weight, + corner_points_weight = process_focal_crop_edges_weight, + annotate_image = process_focal_crop_debug ) - focal = autocrop.crop_image(img, autocrop_settings) - save_pic(focal, index, existing_caption=existing_caption) - processing_option_ran = True + for focal in autocrop.crop_image(img, autocrop_settings): + save_pic(focal, index, existing_caption=existing_caption) + process_default_resize = False - if not processing_option_ran: + if process_default_resize: img = images.resize_image(1, img, width, height) save_pic(img, index, existing_caption=existing_caption) -- cgit v1.2.3 From 54f0c1482427a5b3f2248b97be55878e742cbcb1 Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 16:14:13 -0700 Subject: download better face detection module dynamically --- modules/textual_inversion/autocrop.py | 20 ++++++++++++++++++++ modules/textual_inversion/preprocess.py | 13 +++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index caaf18c8..01a92b12 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,4 +1,5 @@ import cv2 +import requests import os from collections import defaultdict from math import log, sqrt @@ -293,6 +294,25 @@ def is_square(w, h): return w == h +def download_and_cache_models(dirname): + download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' + model_file_name = 'face_detection_yunet.onnx' + + if not os.path.exists(dirname): + os.makedirs(dirname) + + cache_file = os.path.join(dirname, model_file_name) + if not os.path.exists(cache_file): + print(f"downloading face detection model from '{download_url}' to '{cache_file}'") + response = requests.get(download_url) + with open(cache_file, "wb") as f: + f.write(response.content) + + if os.path.exists(cache_file): + return cache_file + return None + + class PointOfInterest: def __init__(self, x, y, weight=1.0, size=10): self.x = x diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 1e4d4de8..e13b1894 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,6 +7,7 @@ import tqdm import time from modules import shared, images +from modules.paths import models_path from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: @@ -146,14 +147,22 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre save_pic(splitted, index, existing_caption=existing_caption) process_default_resize = False - if process_entropy_focus and img.height != img.width: + if process_focal_crop and img.height != img.width: + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + except Exception as e: + print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) + autocrop_settings = autocrop.Settings( crop_width = width, crop_height = height, face_points_weight = process_focal_crop_face_weight, entropy_points_weight = process_focal_crop_entropy_weight, corner_points_weight = process_focal_crop_edges_weight, - annotate_image = process_focal_crop_debug + annotate_image = process_focal_crop_debug, + dnn_model_path = dnn_model_path, ) for focal in autocrop.crop_image(img, autocrop_settings): save_pic(focal, index, existing_caption=existing_caption) -- cgit v1.2.3 From df0c5ea29d7f0c682ac81f184f3e482a6450d018 Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 17:06:59 -0700 Subject: update default weights --- modules/textual_inversion/autocrop.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 01a92b12..9859974a 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -71,9 +71,9 @@ def crop_image(im, settings): return results def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) + corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] + entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] + face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] pois = [] @@ -144,7 +144,7 @@ def image_face_points(im, settings): settings.dnn_model_path, "", (im.width, im.height), - 0.8, # score threshold + 0.9, # score threshold 0.3, # nms threshold 5000 # keep top k before nms ) @@ -159,7 +159,7 @@ def image_face_points(im, settings): results.append( PointOfInterest( int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0)), # face focus up/down is close to the top of the head + int(y + (h * 0.33)), # face focus up/down is close to the top of the head size = w, weight = 1/len(faces[1]) ) @@ -207,7 +207,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.03, + minDistance=min(grayscale.width, grayscale.height)*0.06, useHarrisDetector=False, ) @@ -256,8 +256,8 @@ def image_entropy_points(im, settings): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) - # band = np.asarray(im.convert("1"), dtype=np.uint8) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -- cgit v1.2.3 From cbb857b675cf0f169b21515c29da492b513cc8c4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 26 Oct 2022 09:44:02 +0300 Subject: enable creating embedding with --medvram --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 529ed3e2..647ffe3e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + with devices.autocast(): + cond_model([""]) # will send cond model to GPU if lowvram/medvram is active + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) -- cgit v1.2.3 From c2dc9bfa89070b8e1d857f8773a790b752f1b709 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:22:58 -0700 Subject: Implement PR #3189 but for embeddings. --- modules/textual_inversion/textual_inversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 647ffe3e..22c7b54b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,7 +10,7 @@ import csv from PIL import Image, PngImagePlugin -from modules import shared, devices, sd_hijack, processing, sd_models +from modules import shared, devices, sd_hijack, processing, sd_models, images import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc last_saved_file = "" last_saved_image = "" + forced_filename = "" embedding_yet_to_be_embedded = False ititial_step = embedding.step or 0 @@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc }) if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - + forced_filename = f'{embedding_name}-{embedding.step}' + last_saved_image = os.path.join(images_dir, forced_filename) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, do_not_save_grid=True, @@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) embedding_yet_to_be_embedded = False - image.save(last_saved_image) - + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step -- cgit v1.2.3 From 4875a6c217df5cc06ee2bf11fb645b172c7156a8 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:38:07 -0700 Subject: Implement PR #3309 but for embeddings. --- modules/textual_inversion/textual_inversion.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 22c7b54b..4921bd01 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): for i in range(num_vectors_per_token): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" @@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: - last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{embedding.step}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') embedding.save(last_saved_file) embedding_yet_to_be_embedded = True @@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}
embedding.sd_checkpoint = checkpoint.hash embedding.sd_checkpoint_name = checkpoint.model_name embedding.cached_checksum = None + # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). + embedding.name = embedding_name + filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt') embedding.save(filename) return embedding, filename -- cgit v1.2.3 From f4e14642173a04723200b131deb417c6c79cab17 Mon Sep 17 00:00:00 2001 From: timntorres Date: Tue, 25 Oct 2022 00:04:25 -0700 Subject: Implement PR #3625 but for embeddings. --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4921bd01..4fcebe74 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -358,7 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step -- cgit v1.2.3 From 737eb28faca8be2bb996ee0930ec77d1f7ebd939 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 26 Oct 2022 14:45:33 +0100 Subject: typo: cmd_opts.embedding_dir to cmd_opts.embeddings_dir --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4fcebe74..ff002d3e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}
embedding.cached_checksum = None # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt') + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') embedding.save(filename) return embedding, filename -- cgit v1.2.3 From a0a7024c679056dd66beb1832e52041b10143130 Mon Sep 17 00:00:00 2001 From: FlameLaw <116745066+FlameLaw@users.noreply.github.com> Date: Fri, 28 Oct 2022 02:13:48 +0900 Subject: Fix random dataset shuffle on TI --- modules/textual_inversion/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 5b1c5002..8bb00d27 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -86,12 +86,12 @@ class PersonalizedBase(Dataset): assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size - self.initial_indexes = np.arange(len(self.dataset)) + self.dataset_length = len(self.dataset) self.indexes = None self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] + self.indexes = np.random.permutation(self.dataset_length) def create_text(self, filename_text): text = random.choice(self.lines) -- cgit v1.2.3 From 9ceef81f77ecce89f0c8f412c4d849210d852e82 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 20:48:08 +0700 Subject: Fix log off by 1 --- modules/textual_inversion/learn_schedule.py | 2 +- modules/textual_inversion/textual_inversion.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 2062726a..3a736065 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -52,7 +52,7 @@ class LearnRateScheduler: self.finished = False def apply(self, optimizer, step_number): - if step_number <= self.end_step: + if step_number < self.end_step: return try: diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff002d3e..17dfb223 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return - if step % shared.opts.training_write_csv_every != 0: + if (step + 1) % shared.opts.training_write_csv_every != 0: return - write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True with open(os.path.join(log_directory, filename), "a+", newline='') as fout: @@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values): csv_writer.writeheader() epoch = step // epoch_len - epoch_step = step - epoch * epoch_len + epoch_step = step % epoch_len csv_writer.writerow({ "step": step + 1, - "epoch": epoch + 1, + "epoch": epoch, "epoch_step": epoch_step + 1, **values, }) @@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc loss.backward() optimizer.step() + steps_done = embedding.step + 1 epoch_num = embedding.step // len(ds) - epoch_step = embedding.step - (epoch_num * len(ds)) + 1 + epoch_step = embedding.step % len(ds) - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{embedding.step}' + embedding.name = f'{embedding_name}-{steps_done}' last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') embedding.save(last_saved_file) embedding_yet_to_be_embedded = True @@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc "learn_rate": scheduler.learn_rate }) - if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: - forced_filename = f'{embedding_name}-{embedding.step}' + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) @@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, embedding.step) + footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) -- cgit v1.2.3 From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:37:24 +0700 Subject: Allow trailing comma in learning rate --- modules/textual_inversion/learn_schedule.py | 33 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 13 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 3a736065..76e611b6 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -11,23 +11,30 @@ class LearnScheduleIterator: self.rates = [] self.it = 0 self.maxit = 0 - for i, pair in enumerate(pairs): - tmp = pair.split(':') - if len(tmp) == 2: - step = int(tmp[1]) - if step > cur_step: - self.rates.append((float(tmp[0]), min(step, max_steps))) - self.maxit += 1 - if step > max_steps: + try: + for i, pair in enumerate(pairs): + if not pair.strip(): + continue + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 return - elif step == -1: + else: self.rates.append((float(tmp[0]), max_steps)) self.maxit += 1 return - else: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return + assert self.rates + except (ValueError, AssertionError): + raise Exception("Invalid learning rate schedule") + def __iter__(self): return self -- cgit v1.2.3 From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:42:51 +0700 Subject: Improve lr schedule error message --- modules/textual_inversion/learn_schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 76e611b6..dd0c0ad1 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -4,7 +4,7 @@ import tqdm class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): """ - specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 """ pairs = learn_rate.split(',') @@ -33,7 +33,7 @@ class LearnScheduleIterator: return assert self.rates except (ValueError, AssertionError): - raise Exception("Invalid learning rate schedule") + raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') def __iter__(self): -- cgit v1.2.3 From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 18:09:17 +0700 Subject: Add input validations before loading dataset for training --- modules/textual_inversion/textual_inversion.py | 48 +++++++++++++++++++------- 1 file changed, 36 insertions(+), 12 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 17dfb223..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_file, "Prompt template file is empty" + assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0 , "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0 , "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - assert embedding_name, 'embedding not selected' + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc os.makedirs(images_embeds_dir, exist_ok=True) else: images_embeds_dir = None - + cond_model = shared.sd_model.cond_stage_model + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + + ititial_step = embedding.step or 0 + if ititial_step > steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return embedding, filename + + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) - hijack = sd_hijack.model_hijack - - embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) losses = torch.zeros((32,)) @@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc forced_filename = "" embedding_yet_to_be_embedded = False - ititial_step = embedding.step or 0 - if ititial_step > steps: - return embedding, filename - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step -- cgit v1.2.3 From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:43:21 +0700 Subject: Add cleanup after training --- modules/textual_inversion/textual_inversion.py | 185 +++++++++++++------------ 1 file changed, 95 insertions(+), 90 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..fd7f0897 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - - losses[embedding.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - steps_done = embedding.step + 1 - - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) - - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True - - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + try: + for i, entries in pbar: + embedding.step = i + ititial_step + + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] + processed = processing.process_images(p) + image = processed.images[0] - shared.state.current_image = image + shared.state.current_image = image - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name', '???')) + title = "<{}>".format(data.get('name', '???')) - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = embedding.step + shared.state.job_no = embedding.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" + finally: + if embedding and embedding.vec is not None: + embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() -- cgit v1.2.3 From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:44:05 +0700 Subject: Additional assert on dataset --- modules/textual_inversion/dataset.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 8bb00d27..ad726577 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -42,6 +42,8 @@ class PersonalizedBase(Dataset): self.lines = lines assert data_root, 'dataset directory not specified' + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" cond_model = shared.sd_model.cond_stage_model -- cgit v1.2.3 From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:32:02 +0700 Subject: Revert "Add cleanup after training" This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1. --- modules/textual_inversion/textual_inversion.py | 185 ++++++++++++------------- 1 file changed, 90 insertions(+), 95 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd7f0897..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, entries in pbar: + embedding.step = i + ititial_step - try: - for i, entries in pbar: - embedding.step = i + ititial_step - - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - - losses[embedding.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - steps_done = embedding.step + 1 - - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) - - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True - - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] + processed = processing.process_images(p) + image = processed.images[0] - shared.state.current_image = image + shared.state.current_image = image - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name', '???')) + title = "<{}>".format(data.get('name', '???')) - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = embedding.step + shared.state.job_no = embedding.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - finally: - if embedding and embedding.vec is not None: - embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() -- cgit v1.2.3 From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:49:29 +0700 Subject: Add missing info on hypernetwork/embedding model log Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513 Also group the saving into one --- modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..ee9917ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -119,7 +119,7 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) @@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 if ititial_step > steps: @@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { @@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}

""" - checkpoint = sd_models.select_checkpoint() - - embedding.sd_checkpoint = checkpoint.hash - embedding.sd_checkpoint_name = checkpoint.model_name - embedding.cached_checksum = None - # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). - embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') - embedding.save(filename) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) return embedding, filename + +def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise -- cgit v1.2.3 From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:54:59 +0700 Subject: Fix dataset still being loaded even when training will be skipped --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ee9917ce..e0babb46 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 - if ititial_step > steps: + if ititial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename -- cgit v1.2.3