diff options
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r-- | modules/textual_inversion/dataset.py | 47 |
1 files changed, 34 insertions, 13 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f61f40d3..67e90afe 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -11,11 +11,21 @@ import tqdm from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset): except Exception:
continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+
if include_cond:
- text = self.create_text(filename_tokens)
- cond = cond_model([text]).to(devices.cpu)
- else:
- cond = None
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens, cond))
+ self.dataset.append(entry)
self.length = len(self.dataset) * repeats
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset): def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
- def create_text(self, filename_tokens):
+ def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset): self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens, cond = self.dataset[index]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
- text = self.create_text(filename_tokens)
- return x, text, cond
+ return entry
|