diff options
Diffstat (limited to 'modules')
33 files changed, 936 insertions, 120 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 518b2a61..cdbdce32 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,7 +6,6 @@ import uvicorn import gradio as gr from threading import Lock from io import BytesIO -from gradio.processing_utils import decode_base64_to_file from fastapi import APIRouter, Depends, FastAPI, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.exceptions import HTTPException @@ -131,8 +130,8 @@ def api_middleware(app: FastAPI): "body": vars(e).get('body', ''), "errors": str(e), } - print(f"API error: {request.method}: {request.url} {err}") if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + print(f"API error: {request.method}: {request.url} {err}") if rich_available: console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: @@ -272,7 +271,9 @@ class Api: raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") # always on script with no arg should always run so you don't really need to add them to the requests if "args" in request.alwayson_scripts[alwayson_script_name]: - script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"] + # min between arg length in scriptrunner and arg length in the request + for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))): + script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): @@ -395,16 +396,11 @@ class Api: def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) - def prepareFiles(file): - file = decode_base64_to_file(file.data, file_path=file.name) - file.orig_name = file.name - return file - - reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList'])) - reqDict.pop('imageList') + image_list = reqDict.pop('imageList', []) + image_folder = [decode_base64_to_image(x.data) for x in image_list] with self.queue_lock: - result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) diff --git a/modules/call_queue.py b/modules/call_queue.py index 92097c15..1829f3a6 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): try:
res = func(*args, **kwargs)
+ progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index bdf106bf..d906a571 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -95,6 +95,7 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin( parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
diff --git a/modules/config_states.py b/modules/config_states.py new file mode 100644 index 00000000..2ea00929 --- /dev/null +++ b/modules/config_states.py @@ -0,0 +1,200 @@ +""" +Supports saving and restoring webui and extensions from a known working set of commits +""" + +import os +import sys +import traceback +import json +import time +import tqdm + +from datetime import datetime +from collections import OrderedDict +import git + +from modules import shared, extensions +from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir + + +all_config_states = OrderedDict() + + +def list_config_states(): + global all_config_states + + all_config_states.clear() + os.makedirs(config_states_dir, exist_ok=True) + + config_states = [] + for filename in os.listdir(config_states_dir): + if filename.endswith(".json"): + path = os.path.join(config_states_dir, filename) + with open(path, "r", encoding="utf-8") as f: + j = json.load(f) + j["filepath"] = path + config_states.append(j) + + config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)) + + for cs in config_states: + timestamp = time.asctime(time.gmtime(cs["created_at"])) + name = cs.get("name", "Config") + full_name = f"{name}: {timestamp}" + all_config_states[full_name] = cs + + return all_config_states + + +def get_webui_config(): + webui_repo = None + + try: + if os.path.exists(os.path.join(script_path, ".git")): + webui_repo = git.Repo(script_path) + except Exception: + print(f"Error reading webui git info from {script_path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + webui_remote = None + webui_commit_hash = None + webui_commit_date = None + webui_branch = None + if webui_repo and not webui_repo.bare: + try: + webui_remote = next(webui_repo.remote().urls, None) + head = webui_repo.head.commit + webui_commit_date = webui_repo.head.commit.committed_date + webui_commit_hash = head.hexsha + webui_branch = webui_repo.active_branch.name + + except Exception: + webui_remote = None + + return { + "remote": webui_remote, + "commit_hash": webui_commit_hash, + "commit_date": webui_commit_date, + "branch": webui_branch, + } + + +def get_extension_config(): + ext_config = {} + + for ext in extensions.extensions: + entry = { + "name": ext.name, + "path": ext.path, + "enabled": ext.enabled, + "is_builtin": ext.is_builtin, + "remote": ext.remote, + "commit_hash": ext.commit_hash, + "commit_date": ext.commit_date, + "branch": ext.branch, + "have_info_from_repo": ext.have_info_from_repo + } + + ext_config[ext.name] = entry + + return ext_config + + +def get_config(): + creation_time = datetime.now().timestamp() + webui_config = get_webui_config() + ext_config = get_extension_config() + + return { + "created_at": creation_time, + "webui": webui_config, + "extensions": ext_config + } + + +def restore_webui_config(config): + print("* Restoring webui state...") + + if "webui" not in config: + print("Error: No webui data saved to config") + return + + webui_config = config["webui"] + + if "commit_hash" not in webui_config: + print("Error: No commit saved to webui config") + return + + webui_commit_hash = webui_config.get("commit_hash", None) + webui_repo = None + + try: + if os.path.exists(os.path.join(script_path, ".git")): + webui_repo = git.Repo(script_path) + except Exception: + print(f"Error reading webui git info from {script_path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return + + try: + webui_repo.git.fetch(all=True) + webui_repo.git.reset(webui_commit_hash, hard=True) + print(f"* Restored webui to commit {webui_commit_hash}.") + except Exception: + print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + +def restore_extension_config(config): + print("* Restoring extension state...") + + if "extensions" not in config: + print("Error: No extension data saved to config") + return + + ext_config = config["extensions"] + + results = [] + disabled = [] + + for ext in tqdm.tqdm(extensions.extensions): + if ext.is_builtin: + continue + + ext.read_info_from_repo() + current_commit = ext.commit_hash + + if ext.name not in ext_config: + ext.disabled = True + disabled.append(ext.name) + results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled")) + continue + + entry = ext_config[ext.name] + + if "commit_hash" in entry and entry["commit_hash"]: + try: + ext.fetch_and_reset_hard(entry["commit_hash"]) + ext.read_info_from_repo() + if current_commit != entry["commit_hash"]: + results.append((ext, current_commit[:8], True, entry["commit_hash"][:8])) + except Exception as ex: + results.append((ext, current_commit[:8], False, ex)) + else: + results.append((ext, current_commit[:8], False, "No commit hash found in config")) + + if not entry.get("enabled", False): + ext.disabled = True + disabled.append(ext.name) + else: + ext.disabled = False + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + print("* Finished restoring extensions. Results:") + for ext, prev_commit, success, result in results: + if success: + print(f" + {ext.name}: {prev_commit} -> {result}") + else: + print(f" ! {ext.name}: FAILURE ({result})") diff --git a/modules/devices.py b/modules/devices.py index 52c3e7cd..c705a3cb 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -92,14 +92,18 @@ def cond_cast_float(input): def randn(seed, shape): + from modules.shared import opts + torch.manual_seed(seed) - if device.type == 'mps': + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) def randn_without_seed(shape): - if device.type == 'mps': + from modules.shared import opts + + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) diff --git a/modules/extensions.py b/modules/extensions.py index 3a7a0372..34d9d654 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -3,10 +3,11 @@ import sys import traceback
import time
+from datetime import datetime
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = []
@@ -31,12 +32,15 @@ class Extension: self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
self.version = ''
+ self.branch = None
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
- if self.have_info_from_repo:
+ if self.is_builtin or self.have_info_from_repo:
return
self.have_info_from_repo = True
@@ -56,10 +60,15 @@ class Extension: self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
-
- except Exception:
+ self.commit_date = repo.head.commit.committed_date
+ ts = time.asctime(time.gmtime(self.commit_date))
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ self.commit_hash = head.hexsha
+ self.version = f'{self.commit_hash[:8]} ({ts})'
+
+ except Exception as ex:
+ print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
def list_files(self, subdir, extension):
@@ -82,18 +91,30 @@ class Extension: for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
- self.status = "behind"
+ self.status = "new commits"
+ return
+
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
self.can_update = False
self.status = "latest"
- def fetch_and_reset_hard(self):
+ def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
def list_extensions():
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index d3a4d7ad..33d100dd 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
diff --git a/modules/extras.py b/modules/extras.py index d8ece955..ff4e9c4e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,6 +1,7 @@ import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable): return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "type": "webui", # indicate this model was merged with webui's built-in merger
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ checkpoint_info.calculate_shorthash()
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+ }
+
+ metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 6df76858..99f1a0d3 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model restore_old_hires_fix_params(res)
+ # Missing RNG means the default was set, which is GPU RNG
+ if "RNG" not in res:
+ res["RNG"] = "GPU"
+
return res
@@ -304,6 +308,8 @@ infotext_to_setting_name_mapping = [ ('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('RNG', 'randn_source'),
+ ('NGMS', 's_min_uncond'),
]
diff --git a/modules/images.py b/modules/images.py index b3535070..fd173829 100644 --- a/modules/images.py +++ b/modules/images.py @@ -318,6 +318,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
+NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
def sanitize_filename_part(text, replace_spaces=True):
@@ -352,6 +353,10 @@ class FilenameGenerator: 'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
}
default_time_format = '%Y%m%d%H%M%S'
@@ -360,6 +365,22 @@ class FilenameGenerator: self.seed = seed
self.prompt = prompt
self.image = image
+
+ def hasprompt(self, *args):
+ lower = self.prompt.lower()
+ if self.p is None or self.prompt is None:
+ return None
+ outres = ""
+ for arg in args:
+ if arg != "":
+ division = arg.split("|")
+ expected = division[0].lower()
+ default = division[1] if len(division) > 1 else ""
+ if lower.find(expected) >= 0:
+ outres = f'{outres}{expected}'
+ else:
+ outres = outres if default == "" else f'{outres}{default}'
+ return sanitize_filename_part(outres)
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -403,9 +424,9 @@ class FilenameGenerator: for m in re_pattern.finditer(x):
text, pattern = m.groups()
- res += text
if pattern is None:
+ res += text
continue
pattern_args = []
@@ -426,11 +447,13 @@ class FilenameGenerator: print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is not None:
- res += str(replacement)
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
+ continue
+ elif replacement is not None:
+ res += text + str(replacement)
continue
- res += f'[{pattern}]'
+ res += f'{text}[{pattern}]'
return res
diff --git a/modules/img2img.py b/modules/img2img.py index 953ac5d2..56c846d6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -4,7 +4,7 @@ import sys import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -46,7 +46,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): if state.interrupted:
break
- img = Image.open(image)
+ try:
+ img = Image.open(image)
+ except UnidentifiedImageError:
+ continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
@@ -78,7 +81,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -114,6 +117,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if image is not None:
image = ImageOps.exif_transpose(image)
+ if selected_scale_tab == 1:
+ assert image, "Can't scale by because no image is selected"
+
+ width = int(image.width * scale_by)
+ height = int(image.height * scale_by)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
@@ -151,7 +160,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s override_settings=override_settings,
)
- p.scripts = modules.scripts.scripts_txt2img
+ p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
diff --git a/modules/interrogate.py b/modules/interrogate.py index cbb80683..e1665708 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir): category_types = ["artists", "flavors", "mediums", "movements"]
try:
- os.makedirs(tmpdir)
+ os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir): |