Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--  modules/interrogate.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------
1 file changed, 64 insertions(+), 18 deletions(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 46935210..cbb80683 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -2,15 +2,17 @@ import os
 import sys
 import traceback
 from collections import namedtuple
+from pathlib import Path
 import re
 
 import torch
+import torch.hub
 
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, shared, lowvram, modelloader, errors
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -19,30 +21,76 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
 
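+# Names of the available interrogate categories: one per *.txt file in content_dir.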
+def category_types():
+    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
+
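+# Download the four default category lists from the clip-interrogator repo.
+# Files are fetched into a temporary directory and renamed into place only
+# after all downloads succeed, so content_dir is never left half-populated.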
+def download_default_clip_interrogate_categories(content_dir):
+    print("Downloading CLIP categories...")
+
+    tmpdir = content_dir + "_tmp"
+    category_types = ["artists", "flavors", "mediums", "movements"]
+
+    try:
+        os.makedirs(tmpdir)
+        for category_type in category_types:
+            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
+        os.rename(tmpdir, content_dir)
+
+    except Exception as e:
+        errors.display(e, "downloading default CLIP interrogate categories")
+    finally:
+        if os.path.exists(tmpdir):
+            os.removedirs(tmpdir)
+
 class InterrogateModels:
     blip_model = None
     clip_model = None
     clip_preprocess = None
-    categories = None
     dtype = None
     running_on_cpu = None
 
     def __init__(self, content_dir):
-        self.categories = []
+        self.loaded_categories = None
+        self.skip_categories = []
+        self.content_dir = content_dir
         self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
-        if os.path.exists(content_dir):
-            for filename in os.listdir(content_dir):
-                m = re_topn.search(filename)
-                topn = 1 if m is None else int(m.group(1))
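+    # Load Category tuples from the *.txt files in content_dir, downloading the
+    # default lists on first use; results are cached and reloaded whenever the
+    # skip-categories setting changes.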
+    def categories(self):
+        if not os.path.exists(self.content_dir):
+            download_default_clip_interrogate_categories(self.content_dir)
+
+        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
+            return self.loaded_categories
 
-                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+        self.loaded_categories = []
+
+        if os.path.exists(self.content_dir):
+            self.skip_categories = shared.opts.interrogate_clip_skip_categories
+            category_types = []
+            for filename in Path(self.content_dir).glob('*.txt'):
+                category_types.append(filename.stem)
+                if filename.stem in self.skip_categories:
+                    continue
+                m = re_topn.search(filename.stem)
+                topn = 1 if m is None else int(m.group(1))
+                with open(filename, "r", encoding="utf8") as file:
                     lines = [x.strip() for x in file.readlines()]
 
-                self.categories.append(Category(name=filename, topn=topn, items=lines))
+                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
+
+        return self.loaded_categories
+
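+    # BLIP imports checkpoint_wrapper from fairscale; registering this stub under
+    # fairscale's module path lets BLIP import without fairscale being installed.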
+    def create_fake_fairscale(self):
+        class FakeFairscale:
+            def checkpoint_wrapper(self):
+                pass
+
+        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
 
     def load_blip_model(self):
+        self.create_fake_fairscale()
         import models.blip
 
         files = modelloader.load_models(
@@ -106,6 +154,8 @@ class InterrogateModels:
     def rank(self, image_features, text_array, top_count=1):
         import clip
 
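+        # release cached GPU memory before scoring the text candidates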
+        devices.torch_gc()
+
         if shared.opts.interrogate_clip_dict_limit != 0:
             text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
@@ -135,10 +185,10 @@ class InterrogateModels:
         return caption[0]
 
     def interrogate(self, pil_image):
-        res = None
-
+        res = ""
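+        # register this run with the shared job state so the UI can track it;
+        # matched by shared.state.end() before returning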
+        shared.state.begin()
+        shared.state.job = 'interrogate'
         try:
-
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()
                 devices.torch_gc()
@@ -158,12 +208,7 @@ class InterrogateModels:
             image_features /= image_features.norm(dim=-1, keepdim=True)
 
-            if shared.opts.interrogate_use_builtin_artists:
-                artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
-                res += ", " + artist[0]
-
-            for name, topn, items in self.categories:
+            for name, topn, items in self.categories():
                 matches = self.rank(image_features, items, top_count=topn)
 
                 for match, score in matches:
                     if shared.opts.interrogate_return_ranks:
@@ -177,5 +222,6 @@ class InterrogateModels:
             res += "<error>"
 
         self.unload()
+        shared.state.end()
 
         return res