diff options
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/autocrop.py | 202 | ||||
-rw-r--r-- | modules/textual_inversion/dataset.py | 4 | ||||
-rw-r--r-- | modules/textual_inversion/preprocess.py | 2 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 16 |
4 files changed, 113 insertions, 111 deletions
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 7770d22f..8e667a4d 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -10,63 +10,64 @@ RED = "#F00" def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
-
- scale_by = 1
- if is_landscape(im.width, im.height):
- scale_by = settings.crop_height / im.height
- elif is_portrait(im.width, im.height):
- scale_by = settings.crop_width / im.width
- elif is_square(im.width, im.height):
- if is_square(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_landscape(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_portrait(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_height / im.height
-
- im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
- im_debug = im.copy()
-
- focus = focal_point(im_debug, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- results = []
-
- results.append(im.crop(tuple(crop)))
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im_debug)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- results.append(im_debug)
- if settings.destop_view_image:
- im_debug.show()
-
- return results
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@@ -86,7 +87,7 @@ def focal_point(im, settings): corner_centroid = None
if len(corner_points) > 0:
corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
@@ -98,7 +99,7 @@ def focal_point(im, settings): face_centroid = None
if len(face_points) > 0:
face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
average_point = poi_average(pois, settings)
@@ -132,7 +133,7 @@ def focal_point(im, settings): d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN)
-
+
return average_point
@@ -260,10 +261,11 @@ def image_entropy(im): hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
+
def centroid(pois):
- x = [poi.x for poi in pois]
- y = [poi.y for poi in pois]
- return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
def poi_average(pois, settings):
@@ -281,59 +283,59 @@ def poi_average(pois, settings): def is_landscape(w, h):
- return w > h
+ return w > h
def is_portrait(w, h):
- return h > w
+ return h > w
def is_square(w, h):
- return w == h
+ return w == h
def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
- f.write(response.content)
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
- if os.path.exists(cache_file):
- return cache_file
- return None
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
+ def bounding(self, size):
+ return [
+ self.x - size // 2,
+ self.y - size // 2,
+ self.x + size // 2,
+ self.y + size // 2
+ ]
class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = face_points_weight
- self.annotate_image = annotate_image
- self.destop_view_image = False
- self.dnn_model_path = dnn_model_path
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 41610e03..b9621fc9 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,7 +118,7 @@ class PersonalizedBase(Dataset): weight = torch.ones(latent_sample.shape)
else:
weight = None
-
+
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader): return self
def collate_wrapper_random(batch):
- return BatchLoaderRandom(batch)
\ No newline at end of file + return BatchLoaderRandom(batch)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d0cad09e..a009d8e8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr default=None
)
return wh and center_crop(image, *wh)
-
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9e1b2b9a..d489ed1e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag,
+ tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))
-
+
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
-
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
@@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-
+
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if clip_grad:
clip_grad_sched.step(embedding.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+
if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|