Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py      |  8
-rw-r--r--  modules/artists.py      | 25
-rw-r--r--  modules/interrogate.py  | 55
-rw-r--r--  modules/shared.py       |  5
-rw-r--r--  modules/ui.py           | 11
5 files changed, 46 insertions, 58 deletions
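
A note before the diff for API consumers: the commit drops the two artist lookup routes from modules/api/api.py, so GET /sdapi/v1/artist-categories and GET /sdapi/v1/artists now return 404. A minimal client-side sketch of the migration check (the base URL and the requests usage are assumptions for illustration, not part of this commit):

import requests

base_url = "http://127.0.0.1:7860"  # assumed default local webui address

resp = requests.get(f"{base_url}/sdapi/v1/artists")
if resp.status_code == 404:
    # Removed by this commit: artist data is no longer served over the API.
    print("artists endpoint is gone; read category files from disk instead")
else:
    print(resp.json())
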
diff --git a/modules/api/api.py b/modules/api/api.py
index 2c371e6e..f2e9e884 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -126,8 +126,6 @@ class Api:
         self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
         self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
-        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@@ -390,12 +388,6 @@ class Api:
 
         return styleList
 
-    def get_artists_categories(self):
-        return shared.artist_db.cats
-
-    def get_artists(self):
-        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
-
     def get_embeddings(self):
         db = sd_hijack.model_hijack.embedding_db
diff --git a/modules/artists.py b/modules/artists.py
deleted file mode 100644
index 3612758b..00000000
--- a/modules/artists.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os.path
-import csv
-from collections import namedtuple
-
-Artist = namedtuple("Artist", ['name', 'weight', 'category'])
-
-
-class ArtistsDatabase:
-    def __init__(self, filename):
-        self.cats = set()
-        self.artists = []
-
-        if not os.path.exists(filename):
-            return
-
-        with open(filename, "r", newline='', encoding="utf8") as file:
-            reader = csv.DictReader(file)
-
-            for row in reader:
-                artist = Artist(row["artist"], float(row["score"]), row["category"])
-                self.artists.append(artist)
-                self.cats.add(artist.category)
-
-    def categories(self):
-        return sorted(self.cats)
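
For reference, the deleted ArtistsDatabase read a three-column CSV and tolerated a missing file by returning early with empty collections. A sketch of the file shape it parsed and how it was queried (the rows shown are invented examples, not entries from the real artists.csv):

from modules.artists import ArtistsDatabase  # pre-removal module

# artists.csv layout the removed loader expected:
#
#     artist,score,category
#     Example Painter,0.9,fictional-category
#     Another Name,0.75,fictional-category-2
#
db = ArtistsDatabase("artists.csv")   # .artists and .cats stay empty if the file is absent
print(db.categories())                # sorted, de-duplicated category names
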
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 738d8ff7..19938cbb 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -5,12 +5,13 @@ from collections import namedtuple
 import re
 
 import torch
+import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
+def download_default_clip_interrogate_categories(content_dir):
+ print("Downloading CLIP categories...")
+
+ tmpdir = content_dir + "_tmp"
+ try:
+ os.makedirs(tmpdir)
+
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
+
+ os.rename(tmpdir, content_dir)
+
+ except Exception as e:
+ errors.display(e, "downloading default CLIP interrogate categories")
+ finally:
+ if os.path.exists(tmpdir):
+ os.remove(tmpdir)
+
+
 class InterrogateModels:
     blip_model = None
     clip_model = None
     clip_preprocess = None
-    categories = None
     dtype = None
     running_on_cpu = None
 
     def __init__(self, content_dir):
-        self.categories = []
+        self.loaded_categories = None
+        self.content_dir = content_dir
         self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
-        if os.path.exists(content_dir):
-            for filename in os.listdir(content_dir):
+    def categories(self):
+        if self.loaded_categories is not None:
+            return self.loaded_categories
+
+        self.loaded_categories = []
+
+        if not os.path.exists(self.content_dir):
+            download_default_clip_interrogate_categories(self.content_dir)
+
+        if os.path.exists(self.content_dir):
+            for filename in os.listdir(self.content_dir):
                 m = re_topn.search(filename)
                 topn = 1 if m is None else int(m.group(1))
 
-                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
                     lines = [x.strip() for x in file.readlines()]
 
-                self.categories.append(Category(name=filename, topn=topn, items=lines))
+                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+        return self.loaded_categories
 
     def load_blip_model(self):
         import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels:
         shared.state.begin()
         shared.state.job = 'interrogate'
         try:
-
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()
                 devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels:
                 image_features /= image_features.norm(dim=-1, keepdim=True)
 
-                if shared.opts.interrogate_use_builtin_artists:
-                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
-                    res += ", " + artist[0]
-
-                for name, topn, items in self.categories:
+                for name, topn, items in self.categories():
                     matches = self.rank(image_features, items, top_count=topn)
                     for match, score in matches:
                         if shared.opts.interrogate_return_ranks:
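
The net effect in modules/interrogate.py: category files are no longer scanned in __init__; the first categories() call downloads the four default lists from the pharmapsychotic/clip-interrogator repository when the directory is missing, and flavors.txt is deliberately saved as flavors.top3.txt so that re_topn yields topn=3 for that category. A minimal usage sketch (the directory name is an assumption):

from modules.interrogate import InterrogateModels

interrogator = InterrogateModels("interrogate")   # nothing is read or downloaded yet

# First call fetches the defaults if the directory is absent, parses each file
# into a Category, and caches the result in self.loaded_categories.
for name, topn, items in interrogator.categories():
    print(f"{name}: top {topn} of {len(items)} entries")

# Caveat visible in the diff: os.remove() cannot delete a directory, so if a
# download fails, the finally block raises OSError on the leftover *_tmp dir;
# shutil.rmtree(tmpdir) would be the recursive cleanup.
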
diff --git a/modules/shared.py b/modules/shared.py
index c0e11f18..72fb1934 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -9,7 +9,6 @@ from PIL import Image
 import gradio as gr
 import tqdm
 
-import modules.artists
 import modules.interrogate
 import modules.memmon
 import modules.styles
@@ -254,8 +253,6 @@ class State:
 
 state = State()
 state.server_start = time.time()
 
-artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
-
 styles_filename = cmd_opts.styles_file
 prompt_styles = modules.styles.StyleDatabase(styles_filename)
@@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
     'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
-    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
 
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
-    "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
     "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
     "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
     "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
diff --git a/modules/ui.py b/modules/ui.py
index d23b2b8e..164e0e93 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
             left, _ = os.path.splitext(filename)
             print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
 
-    return [gr_show(True), None]
+    return [gr.update(), None]
 
 
 def interrogate(image):
     prompt = shared.interrogator.interrogate(image.convert("RGB"))
-    return gr_show(True) if prompt is None else prompt
+    return gr.update() if prompt is None else prompt
 
 
 def interrogate_deepbooru(image):
     prompt = deepbooru.model.tag(image)
-    return gr_show(True) if prompt is None else prompt
+    return gr.update() if prompt is None else prompt
 
 
 def create_seed_inputs(target_interface):
@@ -1039,19 +1039,18 @@ def create_ui():
                     init_img_inpaint,
                 ],
                 outputs=[img2img_prompt, dummy_component],
-                show_progress=False,
             )
 
             img2img_prompt.submit(**img2img_args)
             submit.click(**img2img_args)
 
             img2img_interrogate.click(
-                fn=lambda *args : process_interrogate(interrogate, *args),
+                fn=lambda *args: process_interrogate(interrogate, *args),
                 **interrogate_args,
             )
 
             img2img_deepbooru.click(
-                fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
+                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 **interrogate_args,
             )
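
The ui.py changes swap gr_show(True) for a bare gr.update() in the failure paths. gr_show forces visible=True, which Gradio treats as an update to the output component; gr.update() with no arguments is a no-op, so a failed interrogation now leaves the prompt box as it was. A small illustration (Gradio 3.x dict-style updates; gr_show reproduced here as a close approximation of the helper ui.py defines elsewhere):

import gradio as gr

def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}

print(gr_show(True))  # {'visible': True, '__type__': 'update'} -- touches the component
print(gr.update())    # {'__type__': 'update'} -- changes nothing on the component
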