diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-23 21:09:14 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-23 21:09:14 +0000 |
commit | 7ba7f4ed6e980051c9c461f514d2ddee43001b7e (patch) | |
tree | f5b927716493ca8bff06078dd68efaf5873bb91d /modules/interrogate.py | |
parent | 7b1c7ba87b14da9960d0347269421233f4cb5838 (diff) | |
parent | 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e (diff) | |
download | stable-diffusion-webui-gfx803-7ba7f4ed6e980051c9c461f514d2ddee43001b7e.tar.gz stable-diffusion-webui-gfx803-7ba7f4ed6e980051c9c461f514d2ddee43001b7e.tar.bz2 stable-diffusion-webui-gfx803-7ba7f4ed6e980051c9c461f514d2ddee43001b7e.zip |
Merge pull request #7113 from vladmandic/interrogate
Add selector to interrogate categories
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r-- | modules/interrogate.py | 41 |
1 files changed, 25 insertions, 16 deletions
diff --git a/modules/interrogate.py b/modules/interrogate.py index 19938cbb..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys
import traceback
from collections import namedtuple
+from pathlib import Path
import re
import torch
@@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.")
+def category_types():
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
+ category_types = ["artists", "flavors", "mediums", "movements"]
+
try:
os.makedirs(tmpdir)
-
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
-
+ for category_type in category_types:
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
@@ -51,27 +53,32 @@ class InterrogateModels: def __init__(self, content_dir):
self.loaded_categories = None
+ self.skip_categories = []
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self):
- if self.loaded_categories is not None:
- return self.loaded_categories
-
- self.loaded_categories = []
-
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
if os.path.exists(self.content_dir):
- for filename in os.listdir(self.content_dir):
- m = re_topn.search(filename)
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
+ category_types = []
+ for filename in Path(self.content_dir).glob('*.txt'):
+ category_types.append(filename.stem)
+ if filename.stem in self.skip_categories:
+ continue
+ m = re_topn.search(filename.stem)
topn = 1 if m is None else int(m.group(1))
-
- with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
+ with open(filename, "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
return self.loaded_categories
@@ -139,6 +146,8 @@ class InterrogateModels: def rank(self, image_features, text_array, top_count=1):
import clip
+ devices.torch_gc()
+
if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|