aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-10-12 07:35:42 +0000
committerGitHub <noreply@github.com>2022-10-12 07:35:42 +0000
commitdc1432e0dd2b826bbf5aee3e87d8270c151e4912 (patch)
tree276c27910e8ec38a1a054df02c67f73a84cb35df /modules/textual_inversion/dataset.py
parent1d64976dbc5a0f3124567b91fadd5014a9d93c5f (diff)
parentca5efc316b9431746ff886d259275310f63f95fb (diff)
downloadstable-diffusion-webui-gfx803-dc1432e0dd2b826bbf5aee3e87d8270c151e4912.tar.gz
stable-diffusion-webui-gfx803-dc1432e0dd2b826bbf5aee3e87d8270c151e4912.tar.bz2
stable-diffusion-webui-gfx803-dc1432e0dd2b826bbf5aee3e87d8270c151e4912.zip
Merge branch 'master' into feature/scale_to
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..f61f40d3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,18 +8,17 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -33,12 +32,15 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
@@ -53,7 +55,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ if include_cond:
+ text = self.create_text(filename_tokens)
+ cond = cond_model([text]).to(devices.cpu)
+ else:
+ cond = None
+
+ self.dataset.append((init_latent, filename_tokens, cond))
self.length = len(self.dataset) * repeats
@@ -64,6 +72,12 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ def create_text(self, filename_tokens):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+ return text
+
def __len__(self):
return self.length
@@ -72,10 +86,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
-
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ x, filename_tokens, cond = self.dataset[index]
- return x, text
+ text = self.create_text(filename_tokens)
+ return x, text, cond