From 6d805b669e86233432f56ee1892d062103abe501 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 09:14:27 +0300 Subject: make CLIP interrogator download original text files if the directory does not exist remove random artist built-in extension (to re-added as a normal extension on demand) remove artists.csv (but what does it mean????????????????????) make interrogate buttons show Loading... when you click them --- modules/interrogate.py | 55 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 14 deletions(-) (limited to 'modules/interrogate.py') diff --git a/modules/interrogate.py b/modules/interrogate.py index 738d8ff7..19938cbb 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -5,12 +5,13 @@ from collections import namedtuple import re import torch +import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram, modelloader +from modules import devices, paths, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +def download_default_clip_interrogate_categories(content_dir): + print("Downloading CLIP categories...") + + tmpdir = content_dir + "_tmp" + try: + os.makedirs(tmpdir) + + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) + + os.rename(tmpdir, content_dir) + + except Exception as e: + errors.display(e, "downloading default CLIP interrogate categories") + finally: + if os.path.exists(tmpdir): + os.remove(tmpdir) + + class InterrogateModels: blip_model = None clip_model = None clip_preprocess = None - categories = None dtype = None running_on_cpu = None def __init__(self, content_dir): - self.categories = [] + self.loaded_categories = None + self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") - if os.path.exists(content_dir): - for filename in os.listdir(content_dir): + def categories(self): + if self.loaded_categories is not None: + return self.loaded_categories + + self.loaded_categories = [] + + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if os.path.exists(self.content_dir): + for filename in os.listdir(self.content_dir): m = re_topn.search(filename) topn = 1 if m is None else int(m.group(1)) - with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file: + with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + + return self.loaded_categories def load_blip_model(self): import models.blip @@ -139,7 +172,6 @@ class InterrogateModels: shared.state.begin() shared.state.job = 'interrogate' try: - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() devices.torch_gc() @@ -159,12 +191,7 @@ class InterrogateModels: image_features /= image_features.norm(dim=-1, keepdim=True) - if shared.opts.interrogate_use_builtin_artists: - artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] - - res += ", " + artist[0] - - for name, topn, items in self.categories: + for name, topn, items in self.categories(): matches = self.rank(image_features, items, top_count=topn) for match, score in matches: if shared.opts.interrogate_return_ranks: -- cgit v1.2.3