diff options
Diffstat (limited to 'modules')
28 files changed, 691 insertions, 179 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 11045292..2a4cd8a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,6 @@ import base64 import io +import os import time import datetime import uvicorn @@ -98,14 +99,16 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): - rich_available = True + rich_available = False try: - import anyio # importing just so it can be placed on silent list - import starlette # importing just so it can be placed on silent list - from rich.console import Console - console = Console() + if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + rich_available = True except Exception: - rich_available = False + pass @app.middleware("http") async def log_and_time(req: Request, call_next): @@ -116,14 +119,14 @@ def api_middleware(app: FastAPI): endpoint = req.scope.get('path', 'err') if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code = res.status_code, - ver = req.scope.get('http_version', '0.0'), - cli = req.scope.get('client', ('0:0.0.0', 0))[0], - prot = req.scope.get('scheme', 'err'), - method = req.scope.get('method', 'err'), - endpoint = endpoint, - duration = duration, + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get('http_version', '0.0'), + cli=req.scope.get('client', ('0:0.0.0', 0))[0], + prot=req.scope.get('scheme', 'err'), + method=req.scope.get('method', 'err'), + endpoint=endpoint, + duration=duration, )) return res @@ -134,7 +137,7 @@ def api_middleware(app: FastAPI): "body": vars(e).get('body', ''), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) diff --git a/modules/api/models.py b/modules/api/models.py index b5683071..bf97b1a3 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,4 +1,5 @@ import inspect + from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal @@ -207,11 +208,12 @@ class PreprocessResponse(BaseModel): fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) - optType = opts.typemap.get(type(metadata.default), type(value)) + optType = opts.typemap.get(type(metadata.default), type(metadata.default)) - if (metadata is not None): - fields.update({key: (Optional[optType], Field( - default=metadata.default ,description=metadata.label))}) + if metadata.default is None: + pass + elif metadata is not None: + fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))}) else: fields.update({key: (Optional[optType], Field())}) diff --git a/modules/cache.py b/modules/cache.py new file mode 100644 index 00000000..71fe6302 --- /dev/null +++ b/modules/cache.py @@ -0,0 +1,120 @@ +import json
+import os.path
+import threading
+import time
+
+from modules.paths import data_path, script_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+cache_lock = threading.Lock()
+
+dump_cache_after = None
+dump_cache_thread = None
+
+
+def dump_cache():
+ """
+ Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
+ """
+
+ global dump_cache_after
+ global dump_cache_thread
+
+ def thread_func():
+ global dump_cache_after
+ global dump_cache_thread
+
+ while dump_cache_after is not None and time.time() < dump_cache_after:
+ time.sleep(1)
+
+ with cache_lock:
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+ dump_cache_after = None
+ dump_cache_thread = None
+
+ with cache_lock:
+ dump_cache_after = time.time() + 5
+ if dump_cache_thread is None:
+ dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
+ dump_cache_thread.start()
+
+
+def cache(subsection):
+ """
+ Retrieves or initializes a cache for a specific subsection.
+
+ Parameters:
+ subsection (str): The subsection identifier for the cache.
+
+ Returns:
+ dict: The cache data for the specified subsection.
+ """
+
+ global cache_data
+
+ if cache_data is None:
+ with cache_lock:
+ if cache_data is None:
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+ cache_data = {}
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def cached_data_for_file(subsection, title, filename, func):
+ """
+ Retrieves or generates data for a specific file, using a caching mechanism.
+
+ Parameters:
+ subsection (str): The subsection of the cache to use.
+ title (str): The title of the data entry in the subsection of the cache.
+ filename (str): The path to the file to be checked for modifications.
+ func (callable): A function that generates the data if it is not available in the cache.
+
+ Returns:
+ dict or None: The cached or generated data, or None if data generation fails.
+
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
+ It checks if the data associated with the given `title` is present in the cache and compares the
+ modification time of the file with the cached modification time. If the file has been modified,
+ the cache is considered invalid and the data is regenerated using the provided `func`.
+ Otherwise, the cached data is returned.
+
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
+ or cached data is returned as a dictionary.
+ """
+
+ existing_cache = cache(subsection)
+ ondisk_mtime = os.path.getmtime(filename)
+
+ entry = existing_cache.get(title)
+ if entry:
+ cached_mtime = entry.get("mtime", 0)
+ if ondisk_mtime > cached_mtime:
+ entry = None
+
+ if not entry or 'value' not in entry:
+ value = func()
+ if value is None:
+ return None
+
+ entry = {'mtime': ondisk_mtime, 'value': value}
+ existing_cache[title] = entry
+
+ dump_cache()
+
+ return entry['value']
diff --git a/modules/call_queue.py b/modules/call_queue.py index 3b94f8a4..61aa240f 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
+ elapsed_text = f"{elapsed_s:.1f} sec."
if elapsed_m > 0:
- elapsed_text = f"{elapsed_m}m "+elapsed_text
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -95,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+ sys_pct = sys_peak/max(sys_total, 1) * 100
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+ toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
+ toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
+ toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
+
+ text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
+ text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
+ text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
+
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
else:
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
return tuple(res)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ae78f469..e401f641 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -15,6 +15,7 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
diff --git a/modules/extensions.py b/modules/extensions.py index abc6e2b1..c561159a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os
import threading
-from modules import shared, errors
+from modules import shared, errors, cache
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -21,6 +21,7 @@ def active(): class Extension:
lock = threading.Lock()
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
@@ -36,15 +37,29 @@ class Extension: self.remote = None
self.have_info_from_repo = False
+ def to_dict(self):
+ return {x: getattr(self, x) for x in self.cached_fields}
+
+ def from_dict(self, d):
+ for field in self.cached_fields:
+ setattr(self, field, d[field])
+
def read_info_from_repo(self):
if self.is_builtin or self.have_info_from_repo:
return
- with self.lock:
- if self.have_info_from_repo:
- return
+ def read_from_repo():
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+
+ self.do_read_info_from_repo()
+
+ return self.to_dict()
- self.do_read_info_from_repo()
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ self.status = 'unknown'
def do_read_info_from_repo(self):
repo = None
@@ -58,7 +73,6 @@ class Extension: self.remote = None
else:
try:
- self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
commit = repo.head.commit
self.commit_date = commit.committed_date
diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a..6ae07e91 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -4,16 +4,22 @@ from collections import defaultdict from modules import errors
extra_network_registry = {}
+extra_network_aliases = {}
def initialize():
extra_network_registry.clear()
+ extra_network_aliases.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_extra_network_alias(extra_network, alias):
+ extra_network_aliases[alias] = extra_network
+
+
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
@@ -82,20 +88,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
+ activated = []
+
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
+
+ if extra_network is None:
+ extra_network = extra_network_aliases.get(extra_network_name, None)
+
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
+ activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
+ if extra_network in activated:
continue
try:
diff --git a/modules/hashes.py b/modules/hashes.py index ec1187fe..b7a33b42 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -1,43 +1,11 @@ import hashlib
-import json
import os.path
-import filelock
-
from modules import shared
-from modules.paths import data_path, script_path
-
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-
-
-def dump_cache():
- with filelock.FileLock(f"{cache_filename}.lock"):
- with open(cache_filename, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
-
-
-def cache(subsection):
- global cache_data
-
- if cache_data is None:
- with filelock.FileLock(f"{cache_filename}.lock"):
- if not os.path.isfile(cache_filename):
- cache_data = {}
- else:
- try:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
- except Exception:
- os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
- print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
- cache_data = {}
-
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
+import modules.cache
- return s
+dump_cache = modules.cache.dump_cache
+cache = modules.cache.cache
def calculate_sha256(filename):
diff --git a/modules/images.py b/modules/images.py index 4bdedb7f..38aa933d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -363,7 +363,7 @@ class FilenameGenerator: 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
- 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@@ -380,6 +380,7 @@ class FilenameGenerator: 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
+ 'none': lambda self: '', # Overrides the default so you can get just the sequence number
}
default_time_format = '%Y%m%d%H%M%S'
@@ -601,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else:
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+ file_decoration = namegen.apply(file_decoration) + suffix
+
add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number:
file_decoration = f"-{file_decoration}"
- file_decoration = namegen.apply(file_decoration) + suffix
-
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
diff --git a/modules/img2img.py b/modules/img2img.py index 664e2688..a811e7a4 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -240,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 01ea7c91..03552bc2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,4 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py
+import re
import subprocess
import os
import sys
@@ -9,6 +10,9 @@ from functools import lru_cache from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
+from modules import timer
+
+timer.startup_timer.record("start")
args, _ = cmd_args.parser.parse_known_args()
@@ -69,10 +73,12 @@ def git_tag(): return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
- from pathlib import Path
- changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
- with changelog_md.open(encoding="utf-8") as file:
- return next((line.strip() for line in file if line.strip()), "<none>")
+
+ changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
+ with open(changelog_md, "r", encoding="utf-8") as file:
+ line = next((line.strip() for line in file if line.strip()), "<none>")
+ line = line.replace("## ", "")
+ return line
except Exception:
return "<none>"
@@ -224,6 +230,44 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
+
+
+def requrements_met(requirements_file):
+ """
+ Does a simple parse of a requirements.txt file to determine if all rerqirements in it
+ are already installed. Returns True if so, False if not installed or parsing fails.
+ """
+
+ import importlib.metadata
+ import packaging.version
+
+ with open(requirements_file, "r", encoding="utf8") as file:
+ for line in file:
+ if line.strip() == "":
+ continue
+
+ m = re.match(re_requirement, line)
+ if m is None:
+ return False
+
+ package = m.group(1).strip()
+ version_required = (m.group(2) or "").strip()
+
+ if version_required == "":
+ continue
+
+ try:
+ version_installed = importlib.metadata.version(package)
+ except Exception:
+ return False
+
+ if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
+ return False
+
+ return True
+
+
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
@@ -309,7 +353,9 @@ def prepare_environment(): if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
- run_pip(f"install -r \"{requirements_file}\"", "requirements")
+
+ if not requrements_met(requirements_file):
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
diff --git a/modules/processing.py b/modules/processing.py index eb4a60eb..a74a5302 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -14,7 +14,7 @@ from skimage import exposure from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+ samples = []
+
+ for i in range(batch.shape[0]):
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if check_for_nans:
+ try:
+ devices.test_for_nans(sample, "vae")
+ except devices.NansException as e:
+ if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ raise e
+
+ errors.print_error_explanation(
+ "A tensor with all NaNs was produced in VAE.\n"
+ "Web UI will now convert VAE into 32-bit float and retry.\n"
+ "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ )
+
+ devices.dtype_vae = torch.float32
+ model.first_stage_model.to(devices.dtype_vae)
+ batch = batch.to(devices.dtype_vae)
+
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if target_device is not None:
+ sample = sample.to(target_device)
+
+ samples.append(sample)
+
+ return samples
+
+
def decode_first_stage(model, x):
x = model.decode_first_stage(x.to(devices.dtype_vae))
@@ -587,7 +621,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -747,9 +781,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds()
- if len(model_hijack.comments) > 0:
- for comment in model_hijack.comments:
- comments[comment] = 1
+ for comment in model_hijack.comments:
+ comments[comment] = 1
+
+ p.extra_generation_params.update(model_hijack.extra_generation_params)
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
@@ -757,10 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
-
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1028,7 +1060,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): image = sd_samplers.sample_to_image(image, index, approximation=0)
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
@@ -1302,7 +1334,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images)
image = 2. * image - 1.
|