From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 03:18:26 -0700
Subject: Add auto focal point cropping to Preprocess images
This algorithm plots a set of points of interest on the source
image and averages their locations to find a focal center.
Most points come from OpenCV corner detection; one point comes from
an entropy model. The OpenCV points account for 50% of the total
weight and the entropy-based point for the other 50%.
The weighted center of all points is calculated, and a crop bounding
box is placed as close to centered over that point as the image
bounds allow.
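Roughly, the weighted average works like the sketch below (coordinates and
weights are illustrative; in the patch the entropy point's weight is set to
the number of OpenCV points plus one so it carries roughly half of the total):

    # illustrative sketch; the patch stores each point as a dict of x, y and weight
    points = [
        {'x': 120, 'y':  80, 'weight': 1.0},   # OpenCV corner
        {'x': 300, 'y': 210, 'weight': 1.0},   # OpenCV corner
        {'x': 256, 'y': 256, 'weight': 2.0},   # entropy point, roughly half the total
    ]
    total = sum(p['weight'] for p in points)
    center_x = round(sum(p['x'] * p['weight'] for p in points) / total)
    center_y = round(sum(p['y'] * p['weight'] for p in points) / total)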
---
modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++--
1 file changed, 146 insertions(+), 5 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..168bfb09 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,7 @@
import os
-from PIL import Image, ImageOps
+import cv2
+import numpy as np
+from PIL import Image, ImageOps, ImageDraw
import platform
import sys
import tqdm
@@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
try:
if process_caption:
shared.interrogator.load()
@@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus)
finally:
@@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
is_tall = ratio > 1.35
is_wide = ratio < 1 / 1.35
+ processing_option_ran = False
+
if process_split and is_tall:
img = img.resize((width, height * img.height // img.width))
@@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
+
+ processing_option_ran = True
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
@@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
- else:
+
+ processing_option_ran = True
+
+ if process_entropy_focus and (is_tall or is_wide):
+ if is_tall:
+ img = img.resize((width, height * img.height // img.width))
+ else:
+ img = img.resize((width * img.width // img.height, height))
+
+ x_focal_center, y_focal_center = image_central_focal_point(img, width, height)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(height / 2)
+ x_half = int(width / 2)
+
+ x1 = x_focal_center - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + width > img.width:
+ x1 = img.width - width
+
+ y1 = y_focal_center - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + height > img.height:
+ y1 = img.height - height
+
+ x2 = x1 + width
+ y2 = y1 + height
+
+ crop = [x1, y1, x2, y2]
+
+ focal = img.crop(tuple(crop))
+ save_pic(focal, index)
+
+ processing_option_ran = True
+
+ if not processing_option_ran:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
+
+
+def image_central_focal_point(im, target_width, target_height):
+ focal_points = []
+
+ focal_points.extend(
+ image_focal_points(im)
+ )
+
+ fp_entropy = image_entropy_point(im, target_width, target_height)
+ fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy
+
+ focal_points.append(fp_entropy)
+
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for focal_point in focal_points:
+ weight += focal_point['weight']
+ x += focal_point['x'] * focal_point['weight']
+ y += focal_point['y'] * focal_point['weight']
+ avg_x = round(x // weight)
+ avg_y = round(y // weight)
+
+ return avg_x, avg_y
+
+
+def image_focal_points(im):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=50,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.05,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append({
+ 'x': x,
+ 'y': y,
+ 'weight': 1.0
+ })
+
+ return focal_points
+
+
+def image_entropy_point(im, crop_width, crop_height):
+ img = im.copy()
+ # just make it easier to slide the test crop with images oriented the same way
+ if (img.size[0] < img.size[1]):
+ portrait = True
+ img = img.rotate(90, expand=1)
+
+ e_max = 0
+ crop_current = [0, 0, crop_width, crop_height]
+ crop_best = crop_current
+ while crop_current[2] < img.size[0]:
+ crop = img.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e_max < e):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[0] += 4
+ crop_current[2] += 4
+
+ x_mid = int((crop_best[2] - crop_best[0])/2)
+ y_mid = int((crop_best[3] - crop_best[1])/2)
+
+ return {
+ 'x': x_mid,
+ 'y': y_mid,
+ 'weight': 1.0
+ }
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ band = np.asarray(im.convert("L"))
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
--
cgit v1.2.3
From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 13:44:59 -0700
Subject: fix entropy point calculation
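The midpoint of the highest-entropy window was previously computed from the
window's size instead of its position, so it always landed at half the crop
size; it is now anchored at the window's offset. The test window also slides
along the long axis of the original image instead of rotating portrait images
first. A small illustration of the midpoint fix (numbers are made up):

    crop_width = 512
    crop_best = [128, 0, 640, 512]                      # best window found while sliding
    x_mid_old = int((crop_best[2] - crop_best[0]) / 2)  # 256: ignores the window offset
    x_mid_new = int(crop_best[0] + crop_width / 2)      # 384: tracks the window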
---
modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++---------------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 168bfb09..7c1a594e 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -196,9 +196,9 @@ def image_focal_points(im):
points = cv2.goodFeaturesToTrack(
np_im,
- maxCorners=50,
+ maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.05,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
useHarrisDetector=False,
)
@@ -218,28 +218,32 @@ def image_focal_points(im):
def image_entropy_point(im, crop_width, crop_height):
- img = im.copy()
- # just make it easier to slide the test crop with images oriented the same way
- if (img.size[0] < img.size[1]):
- portrait = True
- img = img.rotate(90, expand=1)
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
e_max = 0
crop_current = [0, 0, crop_width, crop_height]
crop_best = crop_current
- while crop_current[2] < img.size[0]:
- crop = img.crop(tuple(crop_current))
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
e = image_entropy(crop)
- if (e_max < e):
+ if (e > e_max):
e_max = e
crop_best = list(crop_current)
- crop_current[0] += 4
- crop_current[2] += 4
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + crop_width/2)
+ y_mid = int(crop_best[1] + crop_height/2)
- x_mid = int((crop_best[2] - crop_best[0])/2)
- y_mid = int((crop_best[3] - crop_best[1])/2)
return {
'x': x_mid,
@@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height):
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"))
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
--
cgit v1.2.3
From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 17:19:02 -0700
Subject: face detection algo, configurability, reusability
Try to move the crop in the direction of a face if one is present.
Add more internal configuration options for choosing the weight of each of
the algorithm's findings.
Move the logic into its own module.
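A minimal usage sketch of the new module (the weights mirror the values this
patch wires into preprocess.py; the image path is a placeholder):

    from PIL import Image
    from modules.textual_inversion import autocrop

    settings = autocrop.Settings(
        crop_width = 512,
        crop_height = 512,
        face_points_weight = 0.9,       # pull the crop toward detected faces
        entropy_points_weight = 0.7,
        corner_points_weight = 0.5,
        annotate_image = False,
    )
    focal = autocrop.crop_image(Image.open("example.png"), settings)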
---
modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++
modules/textual_inversion/preprocess.py | 150 +++-------------------
2 files changed, 230 insertions(+), 136 deletions(-)
create mode 100644 modules/textual_inversion/autocrop.py
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..f858a958
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,216 @@
+import cv2
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+ if im.height > im.width:
+ im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
+ else:
+ im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
+
+ focus = focal_point(im, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+
+ return im.crop(tuple(crop))
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings)
+ entropy_points = image_entropy_points(im, settings)
+ face_points = image_face_points(im, settings)
+
+ total_points = len(corner_points) + len(entropy_points) + len(face_points)
+
+ corner_weight = settings.corner_points_weight
+ entropy_weight = settings.entropy_points_weight
+ face_weight = settings.face_points_weight
+
+ weight_pref_total = corner_weight + entropy_weight + face_weight
+
+ # weight things
+ pois = []
+ if weight_pref_total == 0 or total_points == 0:
+ return pois
+
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
+ )
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+
+ average_point = poi_average(pois, settings, im=im)
+
+ if settings.annotate_image:
+ d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+ classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml')
+
+ minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.05,
+ minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+
+ if len(faces) == 0:
+ return []
+
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ if settings.annotate_image:
+ for f in rects:
+ d = ImageDraw.Draw(im)
+ d.rectangle(f, outline=RED)
+
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects]
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ band = np.asarray(im.convert("1"))
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+
+def poi_average(pois, settings, im=None):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for pois in pois:
+ if settings.annotate_image and im is not None:
+ w = 4 * 0.5 * sqrt(pois.weight)
+ d = ImageDraw.Draw(im)
+ d.ellipse([
+ pois.x - w, pois.y - w,
+ pois.x + w, pois.y + w ], fill=BLUE)
+ weight += pois.weight
+ x += pois.x * pois.weight
+ y += pois.y * pois.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0):
+ self.x = x
+ self.y = y
+ self.weight = weight
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = entropy_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
\ No newline at end of file
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 7c1a594e..0c79f012 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,7 +1,5 @@
import os
-import cv2
-import numpy as np
-from PIL import Image, ImageOps, ImageDraw
+from PIL import Image, ImageOps
import platform
import sys
import tqdm
@@ -9,6 +7,7 @@ import time
from modules import shared, images
from modules.shared import opts, cmd_opts
+from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
@@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
processing_option_ran = True
- if process_entropy_focus and (is_tall or is_wide):
- if is_tall:
- img = img.resize((width, height * img.height // img.width))
- else:
- img = img.resize((width * img.width // img.height, height))
-
- x_focal_center, y_focal_center = image_central_focal_point(img, width, height)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(height / 2)
- x_half = int(width / 2)
-
- x1 = x_focal_center - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + width > img.width:
- x1 = img.width - width
-
- y1 = y_focal_center - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + height > img.height:
- y1 = img.height - height
-
- x2 = x1 + width
- y2 = y1 + height
-
- crop = [x1, y1, x2, y2]
-
- focal = img.crop(tuple(crop))
+ if process_entropy_focus and img.height != img.width:
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = 0.9,
+ entropy_points_weight = 0.7,
+ corner_points_weight = 0.5,
+ annotate_image = False
+ )
+ focal = autocrop.crop_image(img, autocrop_settings)
save_pic(focal, index)
processing_option_ran = True
@@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = images.resize_image(1, img, width, height)
save_pic(img, index)
- shared.state.nextjob()
-
-
-def image_central_focal_point(im, target_width, target_height):
- focal_points = []
-
- focal_points.extend(
- image_focal_points(im)
- )
-
- fp_entropy = image_entropy_point(im, target_width, target_height)
- fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy
-
- focal_points.append(fp_entropy)
-
- weight = 0.0
- x = 0.0
- y = 0.0
- for focal_point in focal_points:
- weight += focal_point['weight']
- x += focal_point['x'] * focal_point['weight']
- y += focal_point['y'] * focal_point['weight']
- avg_x = round(x // weight)
- avg_y = round(y // weight)
-
- return avg_x, avg_y
-
-
-def image_focal_points(im):
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append({
- 'x': x,
- 'y': y,
- 'weight': 1.0
- })
-
- return focal_points
-
-
-def image_entropy_point(im, crop_width, crop_height):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
-
- e_max = 0
- crop_current = [0, 0, crop_width, crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + crop_width/2)
- y_mid = int(crop_best[1] + crop_height/2)
-
-
- return {
- 'x': x_mid,
- 'y': y_mid,
- 'weight': 1.0
- }
-
-
-def image_entropy(im):
- # greyscale image entropy
- band = np.asarray(im.convert("1"))
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
+ shared.state.nextjob()
\ No newline at end of file
--
cgit v1.2.3
From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001
From: captin411
Date: Thu, 20 Oct 2022 00:34:55 -0700
Subject: improve face detection a lot
---
modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++-------------
1 file changed, 62 insertions(+), 37 deletions(-)
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index f858a958..5a551c25 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -8,12 +8,18 @@ GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"
+
def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """
if im.height > im.width:
im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
- else:
+ elif im.width > im.height:
im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
+ else:
+ im = im.resize((settings.crop_width, settings.crop_height))
+
+ if im.height == im.width:
+ return im
focus = focal_point(im, settings)
@@ -78,13 +84,18 @@ def focal_point(im, settings):
[ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
)
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
-
- average_point = poi_average(pois, settings, im=im)
+ average_point = poi_average(pois, settings)
if settings.annotate_image:
- d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN)
+ d = ImageDraw.Draw(im)
+ for f in face_points:
+ d.rectangle(f.bounding(f.size), outline=RED)
+ for f in entropy_points:
+ d.rectangle(f.bounding(30), outline=BLUE)
+ for poi in pois:
+ w = max(4, 4 * 0.5 * sqrt(poi.weight))
+ d.ellipse(poi.bounding(w), fill=BLUE)
+ d.ellipse(average_point.bounding(25), outline=GREEN)
return average_point
@@ -92,22 +103,32 @@ def focal_point(im, settings):
def image_face_points(im, settings):
np_im = np.array(im)
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
- classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml')
-
- minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side
- faces = classifier.detectMultiScale(gray, scaleFactor=1.05,
- minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- if len(faces) == 0:
- return []
-
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- if settings.annotate_image:
- for f in rects:
- d = ImageDraw.Draw(im)
- d.rectangle(f, outline=RED)
-
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects]
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+
+ for t in tries:
+ # print(t[0])
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ return []
def image_corner_points(im, settings):
@@ -132,8 +153,8 @@ def image_corner_points(im, settings):
focal_points = []
for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y))
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4))
return focal_points
@@ -167,31 +188,26 @@ def image_entropy_points(im, settings):
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
- return [PointOfInterest(x_mid, y_mid)]
+ return [PointOfInterest(x_mid, y_mid, size=25)]
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("1"))
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
-def poi_average(pois, settings, im=None):
+def poi_average(pois, settings):
weight = 0.0
x = 0.0
y = 0.0
- for pois in pois:
- if settings.annotate_image and im is not None:
- w = 4 * 0.5 * sqrt(pois.weight)
- d = ImageDraw.Draw(im)
- d.ellipse([
- pois.x - w, pois.y - w,
- pois.x + w, pois.y + w ], fill=BLUE)
- weight += pois.weight
- x += pois.x * pois.weight
- y += pois.y * pois.weight
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
avg_x = round(x / weight)
avg_y = round(y / weight)
@@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None):
class PointOfInterest:
- def __init__(self, x, y, weight=1.0):
+ def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
class Settings:
--
cgit v1.2.3
From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001
From: captin411
Date: Sun, 23 Oct 2022 04:11:07 -0700
Subject: auto cropping now works with non-square crops
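The rescale step now picks a scale factor from the orientation of both the
source image and the requested crop, so the sliding window always fits; an
equivalent standalone sketch (the function name is illustrative; crop_image in
the patch inlines this logic with the is_landscape/is_portrait/is_square
helpers it adds):

    def scale_for_crop(src_w, src_h, crop_w, crop_h):
        # fill the axis that binds so the crop window can slide along the other one
        if src_w > src_h:                       # landscape source
            return crop_h / src_h
        if src_h > src_w:                       # portrait source
            return crop_w / src_w
        # square source: scale so the crop's longer side spans the image,
        # leaving room to slide along the other axis
        return crop_w / src_w if crop_w >= crop_h else crop_h / src_h

    print(scale_for_crop(1000, 1500, 512, 768))  # 0.512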
---
modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++----------------
1 file changed, 269 insertions(+), 240 deletions(-)
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 5a551c25..b2f9241c 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,241 +1,270 @@
-import cv2
-from collections import defaultdict
-from math import log, sqrt
-import numpy as np
-from PIL import Image, ImageDraw
-
-GREEN = "#0F0"
-BLUE = "#00F"
-RED = "#F00"
-
-
-def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
- if im.height > im.width:
- im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
- elif im.width > im.height:
- im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
- else:
- im = im.resize((settings.crop_width, settings.crop_height))
-
- if im.height == im.width:
- return im
-
- focus = focal_point(im, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- if settings.destop_view_image:
- im.show()
-
- return im.crop(tuple(crop))
-
-def focal_point(im, settings):
- corner_points = image_corner_points(im, settings)
- entropy_points = image_entropy_points(im, settings)
- face_points = image_face_points(im, settings)
-
- total_points = len(corner_points) + len(entropy_points) + len(face_points)
-
- corner_weight = settings.corner_points_weight
- entropy_weight = settings.entropy_points_weight
- face_weight = settings.face_points_weight
-
- weight_pref_total = corner_weight + entropy_weight + face_weight
-
- # weight things
- pois = []
- if weight_pref_total == 0 or total_points == 0:
- return pois
-
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
- )
-
- average_point = poi_average(pois, settings)
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- for f in face_points:
- d.rectangle(f.bounding(f.size), outline=RED)
- for f in entropy_points:
- d.rectangle(f.bounding(30), outline=BLUE)
- for poi in pois:
- w = max(4, 4 * 0.5 * sqrt(poi.weight))
- d.ellipse(poi.bounding(w), fill=BLUE)
- d.ellipse(average_point.bounding(25), outline=GREEN)
-
- return average_point
-
-
-def image_face_points(im, settings):
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
-
- for t in tries:
- # print(t[0])
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
- return []
-
-
-def image_corner_points(im, settings):
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4))
-
- return focal_points
-
-
-def image_entropy_points(im, settings):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
- else:
- return []
-
- e_max = 0
- crop_current = [0, 0, settings.crop_width, settings.crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + settings.crop_width/2)
- y_mid = int(crop_best[1] + settings.crop_height/2)
-
- return [PointOfInterest(x_mid, y_mid, size=25)]
-
-
-def image_entropy(im):
- # greyscale image entropy
- # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
-
-def poi_average(pois, settings):
- weight = 0.0
- x = 0.0
- y = 0.0
- for poi in pois:
- weight += poi.weight
- x += poi.x * poi.weight
- y += poi.y * poi.weight
- avg_x = round(x / weight)
- avg_y = round(y / weight)
-
- return PointOfInterest(avg_x, avg_y)
-
-
-class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
-
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
-
-
-class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = entropy_points_weight
- self.annotate_image = annotate_image
+import cv2
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+
+ if im.width == settings.crop_width and im.height == settings.crop_height:
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = [0, 0, im.width, im.height]
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+ return im
+
+ focus = focal_point(im, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+
+ return im.crop(tuple(crop))
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings)
+ entropy_points = image_entropy_points(im, settings)
+ face_points = image_face_points(im, settings)
+
+ total_points = len(corner_points) + len(entropy_points) + len(face_points)
+
+ corner_weight = settings.corner_points_weight
+ entropy_weight = settings.entropy_points_weight
+ face_weight = settings.face_points_weight
+
+ weight_pref_total = corner_weight + entropy_weight + face_weight
+
+ # weight things
+ pois = []
+ if weight_pref_total == 0 or total_points == 0:
+ return pois
+
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
+ )
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ for f in face_points:
+ d.rectangle(f.bounding(f.size), outline=RED)
+ for f in entropy_points:
+ d.rectangle(f.bounding(30), outline=BLUE)
+ for poi in pois:
+ w = max(4, 4 * 0.5 * sqrt(poi.weight))
+ d.ellipse(poi.bounding(w), fill=BLUE)
+ d.ellipse(average_point.bounding(25), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+
+ for t in tries:
+ # print(t[0])
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+
+def poi_average(pois, settings):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = entropy_points_weight
+ self.annotate_image = annotate_image
self.destop_view_image = False
\ No newline at end of file
--
cgit v1.2.3
From 3e6c2420c1177e9e79f2b566a5a7795b7416e34a Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 13:10:58 -0700
Subject: improve debug markers, fix algo weighting
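The weighting no longer spreads the preference weights across every individual
point; each detector's points are collapsed into one centroid, and the
configured weights are normalized only over the detectors that actually
returned something. It also fixes Settings so face_points_weight is no longer
overwritten by entropy_points_weight, and the debug overlay now draws one
labeled ellipse per detector centroid plus the final average. A self-contained
sketch of the weighting idea (names and coordinates are illustrative):

    def centroid(points):
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        return sum(xs) / len(points), sum(ys) / len(points)

    # example: faces and entropy returned points, corner detection found nothing
    detections = {
        'face':    [(250, 120)],
        'entropy': [(300, 260)],
        'corner':  [],
    }
    prefs = {'face': 0.9, 'entropy': 0.7, 'corner': 0.5}
    active = {name: pts for name, pts in detections.items() if pts}
    total_pref = sum(prefs[name] for name in active)           # 1.6
    centroids = [(centroid(pts), prefs[name] / total_pref)     # face 0.5625, entropy 0.4375
                 for name, pts in active.items()]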
---
modules/textual_inversion/autocrop.py | 207 +++++++++++++++++++++-------------
1 file changed, 129 insertions(+), 78 deletions(-)
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index b2f9241c..caaf18c8 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,4 +1,5 @@
import cv2
+import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
@@ -26,19 +27,9 @@ def crop_image(im, settings):
scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
- if im.width == settings.crop_width and im.height == settings.crop_height:
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- rect = [0, 0, im.width, im.height]
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- if settings.destop_view_image:
- im.show()
- return im
-
- focus = focal_point(im, settings)
+ focus = focal_point(im_debug, settings)
# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
@@ -62,89 +53,143 @@ def crop_image(im, settings):
crop = [x1, y1, x2, y2]
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
if settings.annotate_image:
- d = ImageDraw.Draw(im)
+ d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
if settings.destop_view_image:
- im.show()
+ im_debug.show()
- return im.crop(tuple(crop))
+ return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings)
entropy_points = image_entropy_points(im, settings)
face_points = image_face_points(im, settings)
- total_points = len(corner_points) + len(entropy_points) + len(face_points)
-
- corner_weight = settings.corner_points_weight
- entropy_weight = settings.entropy_points_weight
- face_weight = settings.face_points_weight
-
- weight_pref_total = corner_weight + entropy_weight + face_weight
-
- # weight things
pois = []
- if weight_pref_total == 0 or total_points == 0:
- return pois
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
- )
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
average_point = poi_average(pois, settings)
if settings.annotate_image:
d = ImageDraw.Draw(im)
- for f in face_points:
- d.rectangle(f.bounding(f.size), outline=RED)
- for f in entropy_points:
- d.rectangle(f.bounding(30), outline=BLUE)
- for poi in pois:
- w = max(4, 4 * 0.5 * sqrt(poi.weight))
- d.ellipse(poi.bounding(w), fill=BLUE)
- d.ellipse(average_point.bounding(25), outline=GREEN)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point
def image_face_points(im, settings):
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
-
- for t in tries:
- # print(t[0])
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.8, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
@@ -161,7 +206,7 @@ def image_corner_points(im, settings):
np_im,
maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
+ minDistance=min(grayscale.width, grayscale.height)*0.03,
useHarrisDetector=False,
)
@@ -171,7 +216,7 @@ def image_corner_points(im, settings):
focal_points = []
for point in points:
x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4))
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
return focal_points
@@ -205,17 +250,22 @@ def image_entropy_points(im, settings):
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
- return [PointOfInterest(x_mid, y_mid, size=25)]
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
def image_entropy(im):
# greyscale image entropy
- # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
+ band = np.asarray(im.convert("L"))
+ # band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
def poi_average(pois, settings):
weight = 0.0
@@ -260,11 +310,12 @@ class PointOfInterest:
class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
- self.destop_view_image = False
\ No newline at end of file
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
\ No newline at end of file
--
cgit v1.2.3
From db8ed5fe5cd6e967d12d43d96b7f83083e58626c Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 15:22:29 -0700
Subject: Focal crop UI elements
---
modules/textual_inversion/preprocess.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index a8c17c6f..1e4d4de8 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -13,7 +13,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
@@ -23,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_entropy_focus)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally:
@@ -35,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -139,27 +139,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
- processing_option_ran = False
+ process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
- processing_option_ran = True
+ process_default_resize = False
if process_entropy_focus and img.height != img.width:
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
- face_points_weight = 0.9,
- entropy_points_weight = 0.7,
- corner_points_weight = 0.5,
- annotate_image = False
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug
)
- focal = autocrop.crop_image(img, autocrop_settings)
- save_pic(focal, index, existing_caption=existing_caption)
- processing_option_ran = True
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, existing_caption=existing_caption)
+ process_default_resize = False
- if not processing_option_ran:
+ if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption)
--
cgit v1.2.3
From 54f0c1482427a5b3f2248b97be55878e742cbcb1 Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 16:14:13 -0700
Subject: download better face detection model dynamically
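The YuNet face detection ONNX model is downloaded on first use and cached
under models/opencv; if the download fails, preprocessing falls back to the
Haar cascades. A usage sketch matching what this patch wires into
preprocess.py (the 512x512 crop size is just an example):

    import os

    from modules.paths import models_path
    from modules.textual_inversion import autocrop

    try:
        dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
    except Exception:
        dnn_model_path = None   # crop_image falls back to the Haar cascades

    settings = autocrop.Settings(crop_width=512, crop_height=512,
                                 dnn_model_path=dnn_model_path)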
---
modules/textual_inversion/autocrop.py | 20 ++++++++++++++++++++
modules/textual_inversion/preprocess.py | 13 +++++++++++--
2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index caaf18c8..01a92b12 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,4 +1,5 @@
import cv2
+import requests
import os
from collections import defaultdict
from math import log, sqrt
@@ -293,6 +294,25 @@ def is_square(w, h):
return w == h
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 1e4d4de8..e13b1894 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,6 +7,7 @@ import tqdm
import time
from modules import shared, images
+from modules.paths import models_path
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
@@ -146,14 +147,22 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
save_pic(splitted, index, existing_caption=existing_caption)
process_default_resize = False
- if process_entropy_focus and img.height != img.width:
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
face_points_weight = process_focal_crop_face_weight,
entropy_points_weight = process_focal_crop_entropy_weight,
corner_points_weight = process_focal_crop_edges_weight,
- annotate_image = process_focal_crop_debug
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
--
cgit v1.2.3
From df0c5ea29d7f0c682ac81f184f3e482a6450d018 Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 17:06:59 -0700
Subject: update default weights
---
modules/textual_inversion/autocrop.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 01a92b12..9859974a 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -71,9 +71,9 @@ def crop_image(im, settings):
return results
def focal_point(im, settings):
- corner_points = image_corner_points(im, settings)
- entropy_points = image_entropy_points(im, settings)
- face_points = image_face_points(im, settings)
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
pois = []
@@ -144,7 +144,7 @@ def image_face_points(im, settings):
settings.dnn_model_path,
"",
(im.width, im.height),
- 0.8, # score threshold
+ 0.9, # score threshold
0.3, # nms threshold
5000 # keep top k before nms
)
@@ -159,7 +159,7 @@ def image_face_points(im, settings):
results.append(
PointOfInterest(
int(x + (w * 0.5)), # face focus left/right is center
- int(y + (h * 0)), # face focus up/down is close to the top of the head
+ int(y + (h * 0.33)), # face focus up/down is about a third of the way down the detected face
size = w,
weight = 1/len(faces[1])
)
@@ -207,7 +207,7 @@ def image_corner_points(im, settings):
np_im,
maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.03,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
useHarrisDetector=False,
)
@@ -256,8 +256,8 @@ def image_entropy_points(im, settings):
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("L"))
- # band = np.asarray(im.convert("1"), dtype=np.uint8)
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
--
cgit v1.2.3
From cbb857b675cf0f169b21515c29da492b513cc8c4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 26 Oct 2022 09:44:02 +0300
Subject: enable creating embedding with --medvram
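With --medvram/--lowvram the cond stage model is kept off the GPU until it is actually called, so reading its embedding layer directly could fail with a device mismatch. Running one throwaway forward pass inside devices.autocast() first moves the model where it needs to be. Minimal sketch of the relevant lines (shared, devices, init_text and num_vectors_per_token as used in textual_inversion.py):

    cond_model = shared.sd_model.cond_stage_model
    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token,
                               return_tensors="pt", add_special_tokens=False)["input_ids"]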
---
modules/textual_inversion/textual_inversion.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 529ed3e2..647ffe3e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
--
cgit v1.2.3
From c2dc9bfa89070b8e1d857f8773a790b752f1b709 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 24 Oct 2022 23:22:58 -0700
Subject: Implement PR #3189 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 647ffe3e..22c7b54b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_file = ""
last_saved_image = ""
+ forced_filename = ""
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
@@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
+ forced_filename = f'{embedding_name}-{embedding.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
- image.save(last_saved_image)
-
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
--
cgit v1.2.3
From 4875a6c217df5cc06ee2bf11fb645b172c7156a8 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 24 Oct 2022 23:38:07 -0700
Subject: Implement PR #3309 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 22c7b54b..4921bd01 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{embedding.step}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
+ # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
+ embedding.name = embedding_name
+ filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
--
cgit v1.2.3
From f4e14642173a04723200b131deb417c6c79cab17 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Tue, 25 Oct 2022 00:04:25 -0700
Subject: Implement PR #3625 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4921bd01..4fcebe74 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -358,7 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
--
cgit v1.2.3
From 737eb28faca8be2bb996ee0930ec77d1f7ebd939 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 26 Oct 2022 14:45:33 +0100
Subject: typo: cmd_opts.embedding_dir to cmd_opts.embeddings_dir
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4fcebe74..ff002d3e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
embedding.name = embedding_name
- filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
--
cgit v1.2.3
From a0a7024c679056dd66beb1832e52041b10143130 Mon Sep 17 00:00:00 2001
From: FlameLaw <116745066+FlameLaw@users.noreply.github.com>
Date: Fri, 28 Oct 2022 02:13:48 +0900
Subject: Fix random dataset shuffle on TI
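The old shuffle permuted a precomputed initial_indexes array with torch.randperm, so the ordering depended on torch's global RNG state; the new version simply asks NumPy for a fresh permutation of the dataset length on every call. Tiny sketch of the new behaviour (dataset_length stands in for len(self.dataset)):

    import numpy as np

    dataset_length = 8  # stand-in for len(self.dataset)

    def shuffle():
        # new order on every call, independent of torch's global RNG
        return np.random.permutation(dataset_length)

    print(shuffle())  # e.g. [3 0 7 5 1 6 2 4]
    print(shuffle())  # a different order on the next pass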
---
modules/textual_inversion/dataset.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 5b1c5002..8bb00d27 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(len(self.dataset))
+ self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+ self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
--
cgit v1.2.3
From 9ceef81f77ecce89f0c8f412c4d849210d852e82 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 20:48:08 +0700
Subject: Fix log off by 1
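embedding.step is zero-based, so on the iteration where it equals N-1 the model has actually finished N steps; counting steps_done explicitly keeps the save/preview intervals and the epoch display consistent. A small worked example with illustrative numbers (len(ds)=100, save_embedding_every=500):

    len_ds = 100
    save_embedding_every = 500

    for step in (498, 499, 500):       # zero-based embedding.step
        steps_done = step + 1          # steps actually completed
        epoch_num = step // len_ds
        epoch_step = step % len_ds     # zero-based position inside the epoch
        save_now = steps_done % save_embedding_every == 0
        print(step, steps_done, epoch_num, epoch_step + 1, save_now)

    # the 'name-500' snapshot is now written on the iteration with step == 499
    # (500 steps done), instead of one iteration later as before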
---
modules/textual_inversion/learn_schedule.py | 2 +-
modules/textual_inversion/textual_inversion.py | 24 ++++++++++++------------
2 files changed, 13 insertions(+), 13 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..3a736065 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return
try:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff002d3e..17dfb223 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
- if step % shared.opts.training_write_csv_every != 0:
+ if (step + 1) % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
- "epoch": epoch + 1,
+ "epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
@@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
+ steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+ epoch_step = embedding.step % len(ds)
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{embedding.step}'
+ embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- forced_filename = f'{embedding_name}-{embedding.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
@@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
--
cgit v1.2.3
From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:37:24 +0700
Subject: Allow trailing comma in learning rate
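The schedule string is split on ',', so a trailing comma used to leave an empty pair that failed inside float(); blank entries are now skipped and any malformed pair is reported as an invalid schedule. Hedged examples of what LearnScheduleIterator accepts after this change:

    from modules.textual_inversion.learn_schedule import LearnScheduleIterator

    LearnScheduleIterator("0.005:100, 1e-3:1000,", max_steps=10000)  # trailing comma is now fine
    LearnScheduleIterator("0.005", max_steps=10000)                  # single rate for the whole run
    LearnScheduleIterator("0.005:oops", max_steps=10000)             # raises Exception("Invalid learning rate schedule")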
---
modules/textual_inversion/learn_schedule.py | 33 +++++++++++++++++------------
1 file changed, 20 insertions(+), 13 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 3a736065..76e611b6 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception("Invalid learning rate schedule")
+
def __iter__(self):
return self
--
cgit v1.2.3
From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:42:51 +0700
Subject: Improve lr schedule error message
---
modules/textual_inversion/learn_schedule.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 76e611b6..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,7 +4,7 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
@@ -33,7 +33,7 @@ class LearnScheduleIterator:
return
assert self.rates
except (ValueError, AssertionError):
- raise Exception("Invalid learning rate schedule")
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
def __iter__(self):
--
cgit v1.2.3
From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 18:09:17 +0700
Subject: Add input validations before loading dataset for training
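Loading the dataset is the slow part of starting a run, so the cheap sanity checks (directories, template file, step counts, save intervals) are pulled into validate_train_inputs and run first; the name argument presumably lets hypernetwork training reuse the same helper. Illustrative call, mirroring how train_embedding now starts:

    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file,
                          steps, save_embedding_every, create_image_every, log_directory,
                          name="embedding")
    # any failed assert aborts before the dataset is prepared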
---
modules/textual_inversion/textual_inversion.py | 48 +++++++++++++++++++-------
1 file changed, 36 insertions(+), 12 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), f"Save {name} must be integer"
+ assert save_model_every >= 0 , f"Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
+
cond_model = shared.sd_model.cond_stage_model
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return embedding, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
- hijack = sd_hijack.model_hijack
-
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,))
@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
- ititial_step = embedding.step or 0
- if ititial_step > steps:
- return embedding, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
--
cgit v1.2.3
From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:43:21 +0700
Subject: Add cleanup after training
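The training loop is wrapped in try/finally so that embedding.vec.requires_grad is switched back off even if the loop is interrupted or raises. Minimal sketch of the pattern (loop body elided):

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)

    try:
        for i, entries in pbar:
            ...  # forward pass, loss.backward(), optimizer.step(), periodic saves
    finally:
        if embedding and embedding.vec is not None:
            embedding.vec.requires_grad = False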
---
modules/textual_inversion/textual_inversion.py | 185 +++++++++++++------------
1 file changed, 95 insertions(+), 90 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..fd7f0897 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entries in pbar:
- embedding.step = i + ititial_step
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ try:
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
+
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
+ finally:
+ if embedding and embedding.vec is not None:
+ embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.3
From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:44:05 +0700
Subject: Additional assert on dataset
---
modules/textual_inversion/dataset.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 8bb00d27..ad726577 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
--
cgit v1.2.3
From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:32:02 +0700
Subject: Revert "Add cleanup after training"
This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1.
---
modules/textual_inversion/textual_inversion.py | 185 ++++++++++++-------------
1 file changed, 90 insertions(+), 95 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd7f0897..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
- try:
- for i, entries in pbar:
- embedding.step = i + ititial_step
-
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
- finally:
- if embedding and embedding.vec is not None:
- embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.3
From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:49:29 +0700
Subject: Add missing info on hypernetwork/embedding model log
Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513
Also group the saving logic into a single save_embedding helper.
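Both the periodic '-<step>' snapshots and the final save now go through that helper, which stamps the checkpoint hash and model name onto the embedding before writing, and restores the previous fields if the save throws. Usage as introduced in this patch (remove_cached_checksum clears any stale cached checksum first):

    # periodic snapshot, named after the number of completed steps
    save_embedding(embedding, checkpoint, f'{embedding_name}-{steps_done}',
                   last_saved_file, remove_cached_checksum=True)

    # final save under the base embedding name
    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)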
---
modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++---------
1 file changed, 26 insertions(+), 13 deletions(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..ee9917ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
@@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step > steps:
@@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
@@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}
"""
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
- embedding.name = embedding_name
- filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
return embedding, filename
+
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
--
cgit v1.2.3
From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:54:59 +0700
Subject: Fix dataset still being loaded even when training will be skipped
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/textual_inversion')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ee9917ce..e0babb46 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
- if ititial_step > steps:
+ if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
--
cgit v1.2.3