Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py                                  |  49
-rw-r--r--  modules/api/models.py                               |   8
-rw-r--r--  modules/call_queue.py                               |   4
-rw-r--r--  modules/cmd_args.py                                 |   5
-rw-r--r--  modules/devices.py                                  |  83
-rw-r--r--  modules/errors.py                                   |  53
-rw-r--r--  modules/extensions.py                               |  20
-rw-r--r--  modules/extra_networks.py                           |  16
-rw-r--r--  modules/extras.py                                   |  39
-rw-r--r--  modules/gradio_extensons.py                         |  60
-rw-r--r--  modules/images.py                                   |   2
-rw-r--r--  modules/img2img.py                                  |  32
-rw-r--r--  modules/launch_utils.py                             |  75
-rw-r--r--  modules/lowvram.py                                  |   7
-rw-r--r--  modules/processing.py                               |  93
-rw-r--r--  modules/prompt_parser.py                            |  18
-rw-r--r--  modules/rng_philox.py                               | 102
-rw-r--r--  modules/script_loading.py                           |   5
-rw-r--r--  modules/scripts.py                                  |  80
-rw-r--r--  modules/sd_disable_initialization.py                | 106
-rw-r--r--  modules/sd_hijack.py                                |  10
-rw-r--r--  modules/sd_hijack_clip.py                           |  13
-rw-r--r--  modules/sd_hijack_open_clip.py                      |   2
-rw-r--r--  modules/sd_hijack_optimizations.py                  |   4
-rw-r--r--  modules/sd_hijack_unet.py                           |   8
-rw-r--r--  modules/sd_models.py                                |  53
-rw-r--r--  modules/sd_models_xl.py                             |  13
-rw-r--r--  modules/sd_samplers_common.py                       |  12
-rw-r--r--  modules/sd_samplers_extra.py                        |  74
-rw-r--r--  modules/sd_samplers_kdiffusion.py                   |  13
-rw-r--r--  modules/shared.py                                   |   9
-rw-r--r--  modules/styles.py                                   |   5
-rw-r--r--  modules/sysinfo.py                                  |   6
-rw-r--r--  modules/textual_inversion/textual_inversion.py      |  19
-rw-r--r--  modules/timer.py                                    |  23
-rw-r--r--  modules/ui.py                                       | 433
-rw-r--r--  modules/ui_checkpoint_merger.py                     | 124
-rw-r--r--  modules/ui_common.py                                |  34
-rw-r--r--  modules/ui_components.py                            |   2
-rw-r--r--  modules/ui_extensions.py                            |  26
-rw-r--r--  modules/ui_extra_networks.py                        |   5
-rw-r--r--  modules/ui_extra_networks_checkpoints.py            |   3
-rw-r--r--  modules/ui_extra_networks_hypernets.py              |   2
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py      |   2
-rw-r--r--  modules/ui_extra_networks_user_metadata.py          |   8
-rw-r--r--  modules/ui_postprocessing.py                        |   2
-rw-r--r--  modules/ui_prompt_styles.py                         | 110
47 files changed, 1323 insertions(+), 549 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 2a4cd8a2..908c4514 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -197,6 +197,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+ self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
@@ -333,14 +334,17 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
- shared.state.begin(job="scripts_txt2img")
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ try:
+ shared.state.begin(job="scripts_txt2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ finally:
+ shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -390,14 +394,17 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
- shared.state.begin(job="scripts_img2img")
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ try:
+ shared.state.begin(job="scripts_img2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ finally:
+ shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -604,6 +611,10 @@ class Api:
with self.queue_lock:
shared.refresh_checkpoints()
+ def refresh_vae(self):
+ with self.queue_lock:
+ shared_items.refresh_vae_list()
+
def create_embedding(self, args: dict):
try:
shared.state.begin(job="create_embedding")
@@ -720,9 +731,9 @@ class Api:
cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda)
- def launch(self, server_name, port):
+ def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
def kill_webui(self):
restart.stop_program()
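A hedged sketch of exercising the new /sdapi/v1/refresh-vae route from a client, assuming the server runs locally with --api enabled (the URL and the requests dependency are illustrative, not part of this change):

    import requests

    # POST with an empty body; the route acquires the queue lock server-side
    # and re-scans the VAE list. Expect HTTP 200 with a null body on success.
    resp = requests.post("http://127.0.0.1:7860/sdapi/v1/refresh-vae")
    resp.raise_for_status()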
diff --git a/modules/api/models.py b/modules/api/models.py
index b5683071..800c9b93 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -207,11 +208,10 @@ class PreprocessResponse(BaseModel):
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
- optType = opts.typemap.get(type(metadata.default), type(value))
+ optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
- if (metadata is not None):
- fields.update({key: (Optional[optType], Field(
- default=metadata.default ,description=metadata.label))})
+ if metadata is not None:
+ fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
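The options model above is built dynamically from opts.data_labels; a minimal standalone sketch of the same pydantic create_model pattern, with an illustrative field name:

    from typing import Optional
    from pydantic import Field, create_model

    # Each keyword maps a field name to a (type, Field) tuple, mirroring the
    # fields dict assembled in the loop above.
    OptionsModel = create_model(
        "OptionsModel",
        sd_model_checkpoint=(Optional[str], Field(default=None, description="Checkpoint name")),
    )
    print(OptionsModel().sd_model_checkpoint)  # None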
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 61aa240f..f2eb17d6 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -3,7 +3,7 @@ import html
import threading
import time
-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices
queue_lock = threading.Lock()
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
+ devices.torch_gc()
+
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
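For context, a minimal sketch of what a torch_gc-style cleanup does after a failed call; the real modules.devices.torch_gc also covers MPS, while this version assumes a CUDA build of torch:

    import gc
    import torch

    def torch_gc_sketch():
        gc.collect()                      # drop unreachable Python objects first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()      # return cached VRAM to the driver
            torch.cuda.ipc_collect()      # release CUDA IPC handles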
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index ae78f469..64f21e01 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -13,8 +13,10 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -65,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -109,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared
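randn_local() above isolates its stream in a torch.Generator so the global RNG state is untouched; a standalone sketch of that pattern, independent of the webui settings:

    import torch

    # Same seed, fresh generator: the first tensor is reproducible without
    # calling torch.manual_seed() and disturbing global state.
    gen_a = torch.Generator("cpu").manual_seed(1234)
    gen_b = torch.Generator("cpu").manual_seed(1234)
    a = torch.randn((2, 3), generator=gen_a)
    b = torch.randn((2, 3), generator=gen_b)
    assert torch.equal(a, b)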
diff --git a/modules/errors.py b/modules/errors.py
index 5271a9fe..192cd8ff 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -14,7 +14,8 @@ def record_exception():
if exception_records and exception_records[-1] == e:
return
- exception_records.append((e, tb))
+ from modules import sysinfo
+ exception_records.append(sysinfo.format_exception(e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
@@ -83,3 +84,53 @@ def run(code, task):
code()
except Exception as e:
display(task, e)
+
+
+def check_versions():
+ from packaging import version
+ from modules import shared
+
+ import torch
+ import gradio
+
+ expected_torch_version = "2.0.0"
+ expected_xformers_version = "0.0.20"
+ expected_gradio_version = "3.39.0"
+
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
+ print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, and there
+are reports of issues with the training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if shared.xformers_available:
+ import xformers
+
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+ print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if gradio.__version__ != expected_gradio_version:
+ print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Possible reasons for the mismatched gradio version:
+ - you use --skip-install flag.
+ - you use webui.py to start the program instead of launch.py.
+ - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
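check_versions() relies on packaging's PEP 440 semantics rather than string comparison; a minimal illustration, assuming the packaging module is available:

    from packaging import version

    # Comparison is semantic, not lexicographic: "2.0.10" > "2.0.9".
    assert version.parse("2.0.1") >= version.parse("2.0.0")  # passes the torch check
    assert version.parse("1.13.1") < version.parse("2.0.0")  # would print the warning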
diff --git a/modules/extensions.py b/modules/extensions.py
index c561159a..e4633af4 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
def active():
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
return []
- elif shared.opts.disable_all_extensions == "extra":
+ elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
@@ -56,10 +56,12 @@ class Extension:
self.do_read_info_from_repo()
return self.to_dict()
-
- d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
- self.from_dict(d)
- self.status = 'unknown'
+ try:
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ except FileNotFoundError:
+ pass
+ self.status = 'unknown' if self.status == '' else self.status
def do_read_info_from_repo(self):
repo = None
@@ -139,8 +141,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions:
+ print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+ elif shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+ elif shared.cmd_opts.disable_extra_extensions:
+ print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 41799b0a..6ae07e91 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -4,16 +4,22 @@ from collections import defaultdict
from modules import errors
extra_network_registry = {}
+extra_network_aliases = {}
def initialize():
extra_network_registry.clear()
+ extra_network_aliases.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_extra_network_alias(extra_network, alias):
+ extra_network_aliases[alias] = extra_network
+
+
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
@@ -82,20 +88,26 @@ def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
+ activated = []
+
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
+
+ if extra_network is None:
+ extra_network = extra_network_aliases.get(extra_network_name, None)
+
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
+ activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
+ if extra_network in activated:
continue
try:
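register_extra_network_alias() lets two prompt syntaxes resolve to the same handler; a hypothetical registration (MyNetwork and its names are illustrative, not part of this change):

    from modules import extra_networks

    class MyNetwork(extra_networks.ExtraNetwork):  # hypothetical example network
        def activate(self, p, params_list):
            pass

        def deactivate(self, p):
            pass

    network = MyNetwork("mynet")
    extra_networks.register_extra_network(network)
    extra_networks.register_extra_network_alias(network, "mn")
    # Both <mynet:...> and <mn:...> in a prompt now activate the same network.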
diff --git a/modules/extras.py b/modules/extras.py
index e9c0263e..2a310ae3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ import json
import torch
import tqdm
-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+ metadata = {}
+
+ for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+ checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+ if checkpoint_info is None:
+ continue
+
+ metadata.update(checkpoint_info.metadata)
+
+ return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge")
def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = None
+ metadata = {}
+
+ if save_metadata and copy_metadata_fields:
+ if primary_model_info:
+ metadata.update(primary_model_info.metadata)
+ if secondary_model_info:
+ metadata.update(secondary_model_info.metadata)
+ if tertiary_model_info:
+ metadata.update(tertiary_model_info.metadata)
if save_metadata:
- metadata = {"format": "pt"}
+ try:
+ metadata.update(json.loads(metadata_json))
+ except Exception as e:
+ errors.display(e, "readin metadata from json")
+
+ metadata["format"] = "pt"