diff options
Diffstat (limited to 'modules/shared.py')
-rw-r--r-- | modules/shared.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/modules/shared.py b/modules/shared.py index c2775603..63fb4cec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,6 +3,7 @@ import datetime import json
import os
import sys
+from collections import OrderedDict
import gradio as gr
import tqdm
@@ -30,6 +31,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'),
+ help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
@@ -103,6 +106,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+aesthetic_embeddings = {}
+
+def update_aesthetic_embeddings():
+ global aesthetic_embeddings
+ aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+ os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+ aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
+
+update_aesthetic_embeddings()
def reload_hypernetworks():
global hypernetworks
@@ -381,6 +393,11 @@ sd_upscalers = [] sd_model = None
+clip_model = None
+
+from modules.aesthetic_clip import AestheticCLIP
+aesthetic_clip = AestheticCLIP()
+
progress_print_out = sys.stdout
|