aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py2
-rw-r--r--javascript/hints.js11
-rw-r--r--javascript/localization.js10
-rw-r--r--javascript/ui.js14
-rw-r--r--modules/call_queue.py4
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/devices.py83
-rw-r--r--modules/errors.py53
-rw-r--r--modules/extensions.py10
-rw-r--r--modules/extra_networks.py19
-rw-r--r--modules/extras.py39
-rw-r--r--modules/gradio_extensons.py60
-rw-r--r--modules/hypernetworks/hypernetwork.py5
-rw-r--r--modules/images.py2
-rw-r--r--modules/img2img.py36
-rw-r--r--modules/processing.py14
-rw-r--r--modules/prompt_parser.py25
-rw-r--r--modules/rng_philox.py102
-rw-r--r--modules/scripts.py46
-rw-r--r--modules/sd_hijack.py6
-rw-r--r--modules/sd_hijack_clip.py2
-rw-r--r--modules/sd_hijack_optimizations.py4
-rw-r--r--modules/sd_models.py35
-rw-r--r--modules/sd_samplers_common.py20
-rw-r--r--modules/sd_samplers_kdiffusion.py9
-rw-r--r--modules/sd_vae.py16
-rw-r--r--modules/shared.py34
-rw-r--r--modules/styles.py5
-rw-r--r--modules/sysinfo.py6
-rw-r--r--modules/textual_inversion/textual_inversion.py4
-rw-r--r--modules/ui.py439
-rw-r--r--modules/ui_checkpoint_merger.py124
-rw-r--r--modules/ui_common.py34
-rw-r--r--modules/ui_components.py2
-rw-r--r--modules/ui_extensions.py26
-rw-r--r--modules/ui_extra_networks.py13
-rw-r--r--modules/ui_extra_networks_checkpoints.py6
-rw-r--r--modules/ui_extra_networks_checkpoints_user_metadata.py60
-rw-r--r--modules/ui_extra_networks_hypernets.py2
-rw-r--r--modules/ui_extra_networks_textual_inversion.py2
-rw-r--r--modules/ui_extra_networks_user_metadata.py1
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_prompt_styles.py110
-rw-r--r--modules/ui_settings.py2
-rw-r--r--requirements.txt4
-rw-r--r--requirements_versions.txt14
-rw-r--r--scripts/xyz_grid.py27
-rw-r--r--style.css29
-rw-r--r--webui.py43
49 files changed, 1057 insertions, 561 deletions
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index 2ca997f7..390d9dde 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
with gr.Column(scale=1, min_width=120):
- generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+ generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
self.edit_notes = gr.TextArea(label='Notes', lines=4)
diff --git a/javascript/hints.js b/javascript/hints.js
index 4167cb28..6de9372e 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
}
});
+
+onUiLoaded(function() {
+ for (var comp of window.gradio_config.components) {
+ if (comp.props.webui_tooltip && comp.props.elem_id) {
+ var elem = gradioApp().getElementById(comp.props.elem_id);
+ if (elem) {
+ elem.title = comp.props.webui_tooltip;
+ }
+ }
+ }
+});
diff --git a/javascript/localization.js b/javascript/localization.js
index eb22b8a7..0c9032f9 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
img2img_styles: 'OPTION',
- setting_random_artist_categories: 'SPAN',
- setting_face_restoration_model: 'SPAN',
- setting_realesrgan_enabled_models: 'SPAN',
- extras_upscaler_1: 'SPAN',
- extras_upscaler_2: 'SPAN',
+ setting_random_artist_categories: 'OPTION',
+ setting_face_restoration_model: 'OPTION',
+ setting_realesrgan_enabled_models: 'OPTION',
+ extras_upscaler_1: 'OPTION',
+ extras_upscaler_2: 'OPTION',
};
var re_num = /^[.\d]+$/;
diff --git a/javascript/ui.js b/javascript/ui.js
index d70a681b..abf23a78 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -152,7 +152,11 @@ function submit() {
showSubmitButtons('txt2img', false);
var id = randomId();
- localStorage.setItem("txt2img_task_id", id);
+ try {
+ localStorage.setItem("txt2img_task_id", id);
+ } catch (e) {
+ console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
+ }
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
showSubmitButtons('img2img', false);
var id = randomId();
- localStorage.setItem("img2img_task_id", id);
+ try {
+ localStorage.setItem("img2img_task_id", id);
+ } catch (e) {
+ console.warn(`Failed to save img2img task id to localStorage: ${e}`);
+ }
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
showRestoreProgressButton("txt2img", false);
var id = localStorage.getItem("txt2img_task_id");
- id = localStorage.getItem("txt2img_task_id");
-
if (id) {
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
showSubmitButtons('txt2img', true);
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 61aa240f..f2eb17d6 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -3,7 +3,7 @@ import html
import threading
import time
-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices
queue_lock = threading.Lock()
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
+ devices.torch_gc()
+
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index cb4ec5f7..64f21e01 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared
diff --git a/modules/errors.py b/modules/errors.py
index 5271a9fe..192cd8ff 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -14,7 +14,8 @@ def record_exception():
if exception_records and exception_records[-1] == e:
return
- exception_records.append((e, tb))
+ from modules import sysinfo
+ exception_records.append(sysinfo.format_exception(e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
@@ -83,3 +84,53 @@ def run(code, task):
code()
except Exception as e:
display(task, e)
+
+
+def check_versions():
+ from packaging import version
+ from modules import shared
+
+ import torch
+ import gradio
+
+ expected_torch_version = "2.0.0"
+ expected_xformers_version = "0.0.20"
+ expected_gradio_version = "3.39.0"
+
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
+ print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if shared.xformers_available:
+ import xformers
+
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+ print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if gradio.__version__ != expected_gradio_version:
+ print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+ - you use --skip-install flag.
+ - you use webui.py to start the program instead of launch.py.
+ - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
diff --git a/modules/extensions.py b/modules/extensions.py
index 3ad5ed53..e4633af4 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
def active():
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
return []
- elif shared.opts.disable_all_extensions == "extra":
+ elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
@@ -141,8 +141,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions:
+ print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+ elif shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+ elif shared.cmd_opts.disable_extra_extensions:
+ print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 6ae07e91..fa28ac75 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,3 +1,5 @@
+import json
+import os
import re
from collections import defaultdict
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
return res, extra_data
+
+def get_user_metadata(filename):
+ if filename is None:
+ return {}
+
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ return metadata
diff --git a/modules/extras.py b/modules/extras.py
index e9c0263e..2a310ae3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ import json
import torch
import tqdm
-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+ metadata = {}
+
+ for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+ checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+ if checkpoint_info is None:
+ continue
+
+ metadata.update(checkpoint_info.metadata)
+
+ return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge")
def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = None
+ metadata = {}
+
+ if save_metadata and copy_metadata_fields:
+ if primary_model_info:
+ metadata.update(primary_model_info.metadata)
+ if secondary_model_info:
+ metadata.update(secondary_model_info.metadata)
+ if tertiary_model_info:
+ metadata.update(tertiary_model_info.metadata)
if save_metadata:
- metadata = {"format": "pt"}
+ try:
+ metadata.update(json.loads(metadata_json))
+ except Exception as e:
+ errors.display(e, "readin metadata from json")
+
+ metadata["format"] = "pt"
+ if save_metadata and add_merge_recipe:
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting": result_is_inpainting_model,
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
}
- metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
sd_merge_models = {}
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info:
add_model_metadata(tertiary_model_info)
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
else:
torch.save(theta_0, output_modelname)
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
new file mode 100644
index 00000000..5af7fd8e
--- /dev/null
+++ b/modules/gradio_extensons.py
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+def add_classes_to_gradio_component(comp):
+ """
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+ """
+
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+ self.webui_tooltip = kwargs.pop('tooltip', None)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.before_component(self, **kwargs)
+
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+ res = original_IOComponent_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.after_component(self, **kwargs)
+
+ return res
+
+
+def Block_get_config(self):
+ config = original_Block_get_config(self)
+
+ webui_tooltip = getattr(self, 'webui_tooltip', None)
+ if webui_tooltip:
+ config["webui_tooltip"] = webui_tooltip
+
+ return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+ res = original_BlockContext_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c4821d21..70f1cbd2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -10,7 +10,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images
+ from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
diff --git a/modules/images.py b/modules/images.py
index 38aa933d..ba3c43a4 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '