diff options
author | JC-Array <44535867+JC-Array@users.noreply.github.com> | 2022-10-10 23:11:02 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-10 23:11:02 +0000 |
commit | d66bc86159d415005f0745fdb5724bcd95576352 (patch) | |
tree | 2544c33a8f443f226c9cf4bea7df7e3a30369812 /modules/textual_inversion | |
parent | 76ef3d75f61253516c024553335d9083d9660a8a (diff) | |
parent | 47f5e216da2af4b1faf232a620572f8b357855d5 (diff) | |
download | stable-diffusion-webui-gfx803-d66bc86159d415005f0745fdb5724bcd95576352.tar.gz stable-diffusion-webui-gfx803-d66bc86159d415005f0745fdb5724bcd95576352.tar.bz2 stable-diffusion-webui-gfx803-d66bc86159d415005f0745fdb5724bcd95576352.zip |
Merge pull request #2 from JC-Array/master
resolve merge conflicts
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/dataset.py | 3 | ||||
-rw-r--r-- | modules/textual_inversion/preprocess.py | 19 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 14 |
3 files changed, 22 insertions, 14 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7c44ea5b..bcf772d2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,11 +15,10 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 9f63c9a4..4a2194da 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -10,8 +10,9 @@ from modules.shared import opts, cmd_opts if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption, process_caption_deepbooru=False):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -69,23 +70,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca is_wide = ratio < 1 / 1.35
if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ img = img.resize((width, height * img.height // img.width))
- top = img.crop((0, 0, size, size))
+ top = img.crop((0, 0, width, height))
save_pic(top, index)
- bot = img.crop((0, img.height - size, size, img.height))
+ bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ img = img.resize((width * img.width // img.height, height))
- left = img.crop((0, 0, size, size))
+ left = img.crop((0, 0, width, height))
save_pic(left, index)
- right = img.crop((img.width - size, 0, img.width, size))
+ right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
- img = images.resize_image(1, img, size, size)
+ img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..5965c5a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -200,6 +200,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if ititial_step > steps:
return embedding, filename
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -223,7 +226,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = embedding.step // epoch_len
+ epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -236,6 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, sd_model=shared.sd_model,
prompt=text,
steps=20,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
|