aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py18
-rw-r--r--modules/call_queue.py1
-rw-r--r--modules/cmd_args.py1
-rw-r--r--modules/config_states.py200
-rw-r--r--modules/devices.py8
-rw-r--r--modules/extensions.py39
-rw-r--r--modules/extra_networks_hypernet.py2
-rw-r--r--modules/extras.py46
-rw-r--r--modules/generation_parameters_copypaste.py6
-rw-r--r--modules/images.py31
-rw-r--r--modules/img2img.py17
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/ngrok.py12
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/postprocessing.py9
-rw-r--r--modules/processing.py44
-rw-r--r--modules/progress.py18
-rw-r--r--modules/realesrgan_model.py14
-rw-r--r--modules/safe.py5
-rw-r--r--modules/script_callbacks.py14
-rw-r--r--modules/sd_models.py11
-rw-r--r--modules/sd_samplers_common.py10
-rw-r--r--modules/sd_samplers_kdiffusion.py46
-rw-r--r--modules/shared.py48
-rw-r--r--modules/styles.py12
-rw-r--r--modules/textual_inversion/preprocess.py14
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/ui.py152
-rw-r--r--modules/ui_common.py2
-rw-r--r--modules/ui_components.py10
-rw-r--r--modules/ui_extensions.py245
-rw-r--r--modules/ui_extra_networks.py2
-rw-r--r--modules/ui_postprocessing.py8
33 files changed, 936 insertions, 120 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 518b2a61..cdbdce32 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -6,7 +6,6 @@ import uvicorn
import gradio as gr
from threading import Lock
from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
@@ -131,8 +130,8 @@ def api_middleware(app: FastAPI):
"body": vars(e).get('body', ''),
"errors": str(e),
}
- print(f"API error: {request.method}: {request.url} {err}")
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ print(f"API error: {request.method}: {request.url} {err}")
if rich_available:
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
@@ -272,7 +271,9 @@ class Api:
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
- script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ # min between arg length in scriptrunner and arg length in the request
+ for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
+ script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
@@ -395,16 +396,11 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
- def prepareFiles(file):
- file = decode_base64_to_file(file.data, file_path=file.name)
- file.orig_name = file.name
- return file
-
- reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
- reqDict.pop('imageList')
+ image_list = reqDict.pop('imageList', [])
+ image_folder = [decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
- result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 92097c15..1829f3a6 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
try:
res = func(*args, **kwargs)
+ progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index bdf106bf..d906a571 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -95,6 +95,7 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
diff --git a/modules/config_states.py b/modules/config_states.py
new file mode 100644
index 00000000..2ea00929
--- /dev/null
+++ b/modules/config_states.py
@@ -0,0 +1,200 @@
+"""
+Supports saving and restoring webui and extensions from a known working set of commits
+"""
+
+import os
+import sys
+import traceback
+import json
+import time
+import tqdm
+
+from datetime import datetime
+from collections import OrderedDict
+import git
+
+from modules import shared, extensions
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+
+
+all_config_states = OrderedDict()
+
+
+def list_config_states():
+ global all_config_states
+
+ all_config_states.clear()
+ os.makedirs(config_states_dir, exist_ok=True)
+
+ config_states = []
+ for filename in os.listdir(config_states_dir):
+ if filename.endswith(".json"):
+ path = os.path.join(config_states_dir, filename)
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ j["filepath"] = path
+ config_states.append(j)
+
+ config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+
+ for cs in config_states:
+ timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ name = cs.get("name", "Config")
+ full_name = f"{name}: {timestamp}"
+ all_config_states[full_name] = cs
+
+ return all_config_states
+
+
+def get_webui_config():
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ webui_remote = None
+ webui_commit_hash = None
+ webui_commit_date = None
+ webui_branch = None
+ if webui_repo and not webui_repo.bare:
+ try:
+ webui_remote = next(webui_repo.remote().urls, None)
+ head = webui_repo.head.commit
+ webui_commit_date = webui_repo.head.commit.committed_date
+ webui_commit_hash = head.hexsha
+ webui_branch = webui_repo.active_branch.name
+
+ except Exception:
+ webui_remote = None
+
+ return {
+ "remote": webui_remote,
+ "commit_hash": webui_commit_hash,
+ "commit_date": webui_commit_date,
+ "branch": webui_branch,
+ }
+
+
+def get_extension_config():
+ ext_config = {}
+
+ for ext in extensions.extensions:
+ entry = {
+ "name": ext.name,
+ "path": ext.path,
+ "enabled": ext.enabled,
+ "is_builtin": ext.is_builtin,
+ "remote": ext.remote,
+ "commit_hash": ext.commit_hash,
+ "commit_date": ext.commit_date,
+ "branch": ext.branch,
+ "have_info_from_repo": ext.have_info_from_repo
+ }
+
+ ext_config[ext.name] = entry
+
+ return ext_config
+
+
+def get_config():
+ creation_time = datetime.now().timestamp()
+ webui_config = get_webui_config()
+ ext_config = get_extension_config()
+
+ return {
+ "created_at": creation_time,
+ "webui": webui_config,
+ "extensions": ext_config
+ }
+
+
+def restore_webui_config(config):
+ print("* Restoring webui state...")
+
+ if "webui" not in config:
+ print("Error: No webui data saved to config")
+ return
+
+ webui_config = config["webui"]
+
+ if "commit_hash" not in webui_config:
+ print("Error: No commit saved to webui config")
+ return
+
+ webui_commit_hash = webui_config.get("commit_hash", None)
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return
+
+ try:
+ webui_repo.git.fetch(all=True)
+ webui_repo.git.reset(webui_commit_hash, hard=True)
+ print(f"* Restored webui to commit {webui_commit_hash}.")
+ except Exception:
+ print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+def restore_extension_config(config):
+ print("* Restoring extension state...")
+
+ if "extensions" not in config:
+ print("Error: No extension data saved to config")
+ return
+
+ ext_config = config["extensions"]
+
+ results = []
+ disabled = []
+
+ for ext in tqdm.tqdm(extensions.extensions):
+ if ext.is_builtin:
+ continue
+
+ ext.read_info_from_repo()
+ current_commit = ext.commit_hash
+
+ if ext.name not in ext_config:
+ ext.disabled = True
+ disabled.append(ext.name)
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
+ continue
+
+ entry = ext_config[ext.name]
+
+ if "commit_hash" in entry and entry["commit_hash"]:
+ try:
+ ext.fetch_and_reset_hard(entry["commit_hash"])
+ ext.read_info_from_repo()
+ if current_commit != entry["commit_hash"]:
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
+ except Exception as ex:
+ results.append((ext, current_commit[:8], False, ex))
+ else:
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
+
+ if not entry.get("enabled", False):
+ ext.disabled = True
+ disabled.append(ext.name)
+ else:
+ ext.disabled = False
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ print("* Finished restoring extensions. Results:")
+ for ext, prev_commit, success, result in results:
+ if success:
+ print(f" + {ext.name}: {prev_commit} -> {result}")
+ else:
+ print(f" ! {ext.name}: FAILURE ({result})")
diff --git a/modules/devices.py b/modules/devices.py
index 52c3e7cd..c705a3cb 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
+ from modules.shared import opts
+
torch.manual_seed(seed)
- if device.type == 'mps':
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
- if device.type == 'mps':
+ from modules.shared import opts
+
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
diff --git a/modules/extensions.py b/modules/extensions.py
index 3a7a0372..34d9d654 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -3,10 +3,11 @@ import sys
import traceback
import time
+from datetime import datetime
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = []
@@ -31,12 +32,15 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
self.version = ''
+ self.branch = None
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
- if self.have_info_from_repo:
+ if self.is_builtin or self.have_info_from_repo:
return
self.have_info_from_repo = True
@@ -56,10 +60,15 @@ class Extension:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
-
- except Exception:
+ self.commit_date = repo.head.commit.committed_date
+ ts = time.asctime(time.gmtime(self.commit_date))
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ self.commit_hash = head.hexsha
+ self.version = f'{self.commit_hash[:8]} ({ts})'
+
+ except Exception as ex:
+ print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
def list_files(self, subdir, extension):
@@ -82,18 +91,30 @@ class Extension:
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
- self.status = "behind"
+ self.status = "new commits"
+ return
+
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
self.can_update = False
self.status = "latest"
- def fetch_and_reset_hard(self):
+ def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
def list_extensions():
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index d3a4d7ad..33d100dd 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
diff --git a/modules/extras.py b/modules/extras.py
index d8ece955..ff4e9c4e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,6 +1,7 @@
import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "type": "webui", # indicate this model was merged with webui's built-in merger
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ checkpoint_info.calculate_shorthash()
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+ }
+
+ metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 6df76858..99f1a0d3 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
+ # Missing RNG means the default was set, which is GPU RNG
+ if "RNG" not in res:
+ res["RNG"] = "GPU"
+
return res
@@ -304,6 +308,8 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('RNG', 'randn_source'),
+ ('NGMS', 's_min_uncond'),
]
diff --git a/modules/images.py b/modules/images.py
index b3535070..fd173829 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,6 +318,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
+NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
def sanitize_filename_part(text, replace_spaces=True):
@@ -352,6 +353,10 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
}
default_time_format = '%Y%m%d%H%M%S'
@@ -360,6 +365,22 @@ class FilenameGenerator:
self.seed = seed
self.prompt = prompt
self.image = image
+
+ def hasprompt(self, *args):
+ lower = self.prompt.lower()
+ if self.p is None or self.prompt is None:
+ return None
+ outres = ""
+ for arg in args:
+ if arg != "":
+ division = arg.split("|")
+ expected = division[0].lower()
+ default = division[1] if len(division) > 1 else ""
+ if lower.find(expected) >= 0:
+ outres = f'{outres}{expected}'
+ else:
+ outres = outres if default == "" else f'{outres}{default}'
+ return sanitize_filename_part(outres)
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -403,9 +424,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
- res += text
if pattern is None:
+ res += text
continue
pattern_args = []
@@ -426,11 +447,13 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is not None:
- res += str(replacement)
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
+ continue
+ elif replacement is not None:
+ res += text + str(replacement)
continue
- res += f'[{pattern}]'
+ res += f'{text}[{pattern}]'
return res
diff --git a/modules/img2img.py b/modules/img2img.py
index 953ac5d2..56c846d6 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,7 +4,7 @@ import sys
import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -46,7 +46,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
if state.interrupted:
break
- img = Image.open(image)
+ try:
+ img = Image.open(image)
+ except UnidentifiedImageError:
+ continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
@@ -78,7 +81,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -114,6 +117,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)
+ if selected_scale_tab == 1:
+ assert image, "Can't scale by because no image is selected"
+
+ width = int(image.width * scale_by)
+ height = int(image.height * scale_by)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
@@ -151,7 +160,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings,
)
- p.scripts = modules.scripts.scripts_txt2img
+ p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index cbb80683..e1665708 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
category_types = ["artists", "flavors", "mediums", "movements"]
try:
- os.makedirs(tmpdir)
+ os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):