From 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 12:29:23 -0500 Subject: add option to skip interrogate categories --- modules/interrogate.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) (limited to 'modules/interrogate.py') diff --git a/modules/interrogate.py b/modules/interrogate.py index 1d1ac572..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") -category_types = ["artists", "flavors", "mediums", "movements"] +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) for category_type in category_types: @@ -48,33 +53,32 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None - self.selected_categories = [] + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: return self.loaded_categories self.loaded_categories = [] - if not os.path.exists(self.content_dir): - download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - self.selected_categories = shared.opts.interrogate_clip_categories - for category_type in category_types: - if 'all' not in self.selected_categories and category_type not in self.selected_categories: - continue - filename = os.path.join(self.content_dir, f"{category_type}.txt") - if not os.path.isfile(filename): + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in self.skip_categories: continue - m = re_topn.search(filename) + m = re_topn.search(filename.stem) topn = 1 if m is None else int(m.group(1)) with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories -- cgit v1.2.3