From 762265eab58cdb8f2d6398769bab43d8b8db0075 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 07:52:45 +0300
Subject: autofixes from ruff
---
webui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 727ebd31..ec3d2aba 100644
--- a/webui.py
+++ b/webui.py
@@ -360,7 +360,7 @@ def webui():
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
- mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+ gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
wait_on_server(shared.demo)
print('Restarting UI...')
--
cgit v1.2.3
From f741a98baccae100fcfb40c017b5c35c5cba1b0c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 08:43:42 +0300
Subject: imports cleanup for ruff
---
extensions-builtin/Lora/lora.py | 1 -
extensions-builtin/ScuNET/scripts/scunet_model.py | 1 -
extensions-builtin/SwinIR/scripts/swinir_model.py | 3 +--
modules/codeformer/codeformer_arch.py | 4 +---
modules/codeformer/vqgan_arch.py | 2 --
modules/codeformer_model.py | 4 +---
modules/config_states.py | 2 +-
modules/esrgan_model.py | 2 +-
modules/esrgan_model_arch.py | 1 -
modules/extensions.py | 1 -
modules/generation_parameters_copypaste.py | 4 ----
modules/hypernetworks/hypernetwork.py | 3 +--
modules/hypernetworks/ui.py | 2 --
modules/images.py | 2 +-
modules/img2img.py | 5 +----
modules/mac_specific.py | 1 -
modules/modelloader.py | 1 -
modules/models/diffusion/uni_pc/uni_pc.py | 1 -
modules/processing.py | 5 ++---
modules/sd_hijack.py | 2 +-
modules/sd_hijack_inpainting.py | 6 ------
modules/sd_hijack_ip2p.py | 5 +----
modules/sd_hijack_xlmr.py | 2 --
modules/sd_models.py | 2 +-
modules/sd_models_config.py | 1 -
modules/sd_samplers_kdiffusion.py | 1 -
modules/sd_vae.py | 3 ---
modules/shared.py | 3 ---
modules/styles.py | 9 ---------
modules/textual_inversion/autocrop.py | 4 +---
modules/textual_inversion/image_embedding.py | 2 +-
modules/textual_inversion/preprocess.py | 4 ----
modules/textual_inversion/textual_inversion.py | 1 -
modules/txt2img.py | 9 +++------
modules/ui.py | 5 ++---
modules/ui_extra_networks.py | 1 -
modules/ui_postprocessing.py | 2 +-
modules/upscaler.py | 2 --
modules/xlmr.py | 2 +-
pyproject.toml | 11 +++++++----
scripts/custom_code.py | 2 +-
scripts/outpainting_mk_2.py | 4 ++--
scripts/poor_mans_outpainting.py | 4 ++--
scripts/prompt_matrix.py | 7 ++-----
scripts/prompts_from_file.py | 5 +----
scripts/sd_upscale.py | 4 ++--
scripts/xyz_grid.py | 6 ++----
webui.py | 2 +-
48 files changed, 42 insertions(+), 114 deletions(-)
(limited to 'webui.py')
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index ba1293df..0ab43229 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,4 +1,3 @@
-import glob
import os
import re
import torch
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index c7fd5739..aa2fdb3a 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -13,7 +13,6 @@ import modules.upscaler
from modules import devices, modelloader
from scunet_model_arch import SCUNet as net
from modules.shared import opts
-from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index d77c3a92..55dd94ab 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index f1a7cf09..00c407de 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -1,14 +1,12 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
-import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
-from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index e7293683..820e6b12 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,11 +5,9 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8d84bbc9..8e56cb89 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -33,11 +33,9 @@ def setup_model(dirname):
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils.download_util import load_file_from_url
- from basicsr.utils import imwrite, img2tensor, tensor2img
+ from basicsr.utils import img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
- from modules.shared import cmd_opts
net_class = CodeFormer
diff --git a/modules/config_states.py b/modules/config_states.py
index 2ea00929..8f1ff428 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -14,7 +14,7 @@ from collections import OrderedDict
import git
from modules import shared, extensions
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+from modules.paths_internal import script_path, config_states_dir
all_config_states = OrderedDict()
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index f4369257..85aa6934 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -6,7 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index 7f8bc7c0..4de9dd8d 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -2,7 +2,6 @@
from collections import OrderedDict
import math
-import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
diff --git a/modules/extensions.py b/modules/extensions.py
index 34d9d654..829f8cd9 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -3,7 +3,6 @@ import sys
import traceback
import time
-from datetime import datetime
import git
from modules import shared
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fe8b18b2..f1c59c46 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,15 +1,11 @@
import base64
-import html
import io
-import math
import os
import re
-from pathlib import Path
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
-import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 1fc49537..9fe749b7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,4 +1,3 @@
-import csv
import datetime
import glob
import html
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-from collections import defaultdict, deque
+from collections import deque
from statistics import stdev, mean
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 76599f5a..be168736 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,6 +1,4 @@
import html
-import os
-import re
import gradio as gr
import modules.hypernetworks.hypernetwork
diff --git a/modules/images.py b/modules/images.py
index 5eb6d855..7392cb8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -19,7 +19,7 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
diff --git a/modules/img2img.py b/modules/img2img.py
index 32b1ecd6..d704bf90 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,12 +1,9 @@
-import math
import os
-import sys
-import traceback
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
-from modules import devices, sd_samplers
+from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 40ce2101..5c2f92a1 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,5 @@
import torch
import platform
-from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
diff --git a/modules/modelloader.py b/modules/modelloader.py
index cf685000..92ada694 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,4 +1,3 @@
-import glob
import os
import shutil
import importlib
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index 11b330bc..a4c4ef4e 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,5 +1,4 @@
import torch
-import torch.nn.functional as F
import math
from tqdm.auto import trange
diff --git a/modules/processing.py b/modules/processing.py
index 6f5233c1..c3932d6b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -2,7 +2,6 @@ import json
import math
import os
import sys
-import warnings
import hashlib
import torch
@@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index d8135211..81573b78 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 55a2ce4d..344d75c8 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,15 +1,9 @@
-import os
import torch
-from einops import repeat
-from omegaconf import ListConfig
-
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
index 41ed54a2..6fe6b6ff 100644
--- a/modules/sd_hijack_ip2p.py
+++ b/modules/sd_hijack_ip2p.py
@@ -1,8 +1,5 @@
-import collections
import os.path
-import sys
-import gc
-import time
+
def should_hijack_ip2p(checkpoint_info):
from modules import sd_models_config
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
index 4ac51c38..28528329 100644
--- a/modules/sd_hijack_xlmr.py
+++ b/modules/sd_hijack_xlmr.py
@@ -1,8 +1,6 @@
-import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
-from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 11c1a344..1c09c709 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -565,7 +565,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 7a79925a..9bfe1237 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,4 +1,3 @@
-import re
import os
import torch
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..3b8e9622 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 521e485a..b7176125 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,5 @@
-import torch
-import safetensors.torch
import os
import collections
-from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
diff --git a/modules/shared.py b/modules/shared.py
index 4631965b..44cd2c0c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,12 +1,9 @@
-import argparse
import datetime
import json
import os
import sys
import time
-import requests
-from PIL import Image
import gradio as gr
import tqdm
diff --git a/modules/styles.py b/modules/styles.py
index 11642075..c22769cf 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,18 +1,9 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-
import csv
import os
import os.path
import typing
-import collections.abc as abc
-import tempfile
import shutil
-if typing.TYPE_CHECKING:
- # Only import this when code is being type-checked, it doesn't have any effect at runtime
- from .processing import StableDiffusionProcessing
-
class PromptStyle(typing.NamedTuple):
name: str
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index d7d8d2e3..7770d22f 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,10 +1,8 @@
import cv2
import requests
import os
-from collections import defaultdict
-from math import log, sqrt
import numpy as np
-from PIL import Image, ImageDraw
+from PIL import ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 5593f88c..ee0e850a 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -2,7 +2,7 @@ import base64
import json
import numpy as np
import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
from modules.shared import opts
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index da0bcb26..d0cad09e 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,13 +1,9 @@
import os
from PIL import Image, ImageOps
import math
-import platform
-import sys
import tqdm
-import time
from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f753b75f..9ed9ba45 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,7 +1,6 @@
import os
import sys
import traceback
-import inspect
from collections import namedtuple
import torch
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 16841d0f..f022381c 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,18 +1,15 @@
import modules.scripts
-from modules import sd_samplers
+from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
- StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
-import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
- p = StableDiffusionProcessingTxt2Img(
+ p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -53,7 +50,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = processing.process_images(p)
p.close()
diff --git a/modules/ui.py b/modules/ui.py
index 6beda76f..f7e57593 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -14,10 +14,10 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import opts, cmd_opts
import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
@@ -28,7 +28,6 @@ import modules.shared as shared
import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
-from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 49e06289..800e467a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,4 +1,3 @@
-import glob
import os.path
import urllib.parse
from pathlib import Path
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index f25639e5..c7dc1154 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,5 +1,5 @@
import gradio as gr
-from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 0ad4fe99..777593b0 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -2,8 +2,6 @@ import os
from abc import abstractmethod
import PIL
-import numpy as np
-import torch
from PIL import Image
import modules.shared
diff --git a/modules/xlmr.py b/modules/xlmr.py
index beab3fdf..e056c3f6 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
diff --git a/pyproject.toml b/pyproject.toml
index 1e164abc..9caa9ba2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,13 @@
[tool.ruff]
+exclude = ["extensions"]
+
ignore = [
"E501",
- "E731",
- "E402", # Module level import not at top of file
- "F401" # Module imported but unused
+
+ "F401", # Module imported but unused
]
-exclude = ["extensions"]
+
+[tool.ruff.per-file-ignores]
+"webui.py" = ["E402"] # Module level import not at top of file
\ No newline at end of file
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
index f36a3675..cc6f0d49 100644
--- a/scripts/custom_code.py
+++ b/scripts/custom_code.py
@@ -4,7 +4,7 @@ import ast
import copy
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import cmd_opts
def convertExpr2Expression(expr):
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index b10fed6c..665dbe89 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -7,9 +7,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index ddcbd2d3..c0bbecc1 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images, devices
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index e9b11517..fb06beab 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -1,14 +1,11 @@
import math
-from collections import namedtuple
-from copy import copy
-import random
import modules.scripts as scripts
import gradio as gr
from modules import images
-from modules.processing import process_images, Processed
-from modules.shared import opts, cmd_opts, state
+from modules.processing import process_images
+from modules.shared import opts, state
import modules.sd_samplers
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 76dc5778..149bc85f 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,6 +1,4 @@
import copy
-import math
-import os
import random
import sys
import traceback
@@ -11,8 +9,7 @@ import gradio as gr
from modules import sd_samplers
from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
+from modules.shared import state
def process_string_tag(tag):
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 332d76d9..d873a09c 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image
-from modules import processing, shared, sd_samplers, images, devices
+from modules import processing, shared, images, devices
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 2ff42ef8..332e0ecd 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,15 +10,13 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
-import glob
-import os
import re
from modules.ui_components import ToolButton
diff --git a/webui.py b/webui.py
index ec3d2aba..48277075 100644
--- a/webui.py
+++ b/webui.py
@@ -43,7 +43,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
--
cgit v1.2.3
From 4b854806d98cf5ccd48e5cd99c172613da7937f0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 09:02:23 +0300
Subject: F401 fixes for ruff
---
extensions-builtin/LDSR/scripts/ldsr_model.py | 4 ++--
modules/cmd_args.py | 2 +-
modules/deepbooru.py | 1 -
modules/extensions.py | 2 +-
modules/gfpgan_model.py | 2 +-
modules/models/diffusion/uni_pc/__init__.py | 2 +-
modules/paths.py | 4 ++--
modules/realesrgan_model.py | 6 +++---
modules/script_loading.py | 1 -
modules/sd_hijack_inpainting.py | 2 +-
modules/sd_models.py | 4 +---
modules/sd_samplers.py | 2 +-
modules/shared.py | 2 +-
modules/ui.py | 4 ++--
modules/upscaler.py | 2 +-
pyproject.toml | 9 +++++----
webui.py | 8 ++++----
17 files changed, 27 insertions(+), 30 deletions(-)
(limited to 'webui.py')
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index e8dc083c..fbbe9005 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -7,8 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
-import sd_hijack_autoencoder
-import sd_hijack_ddpm_v1
+import sd_hijack_autoencoder # noqa: F401
+import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index d906a571..e01ca655 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,6 +1,6 @@
import argparse
import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser()
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 122fce7f..1c4554a2 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,7 +2,6 @@ import os
import re
import torch
-from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
diff --git a/modules/extensions.py b/modules/extensions.py
index 829f8cd9..bc2c0450 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -6,7 +6,7 @@ import time
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index fbe6215a..0131dea4 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -78,7 +78,7 @@ def setup_model(dirname):
try:
from gfpgan import GFPGANer
- from facexlib import detection, parsing
+ from facexlib import detection, parsing # noqa: F401
global user_path
global have_gfpgan
global gfpgan_constructor
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
index e1265e3f..dbb35964 100644
--- a/modules/models/diffusion/uni_pc/__init__.py
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler # noqa: F401
diff --git a/modules/paths.py b/modules/paths.py
index acf1894b..5f6474c0 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,8 +1,8 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
-import modules.safe
+import modules.safe # noqa: F401
# data_path = cmd_opts_pre.data
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 9ec1adf2..c24d8dbb 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
self.user_path = path
super().__init__()
try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
+ from realesrgan import RealESRGANer # noqa: F401
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index a7d2203f..57b15862 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -2,7 +2,6 @@ import os
import sys
import traceback
import importlib.util
-from types import ModuleType
def load_module(path):
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 344d75c8..058575b7 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -4,7 +4,7 @@ import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1c09c709..d1e946a5 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
@@ -87,8 +86,7 @@ class CheckpointInfo:
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index ff361f22..4f1bf21d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,7 +1,7 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
diff --git a/modules/shared.py b/modules/shared.py
index 44cd2c0c..7d70f041 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,7 +12,7 @@ import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None
diff --git a/modules/ui.py b/modules/ui.py
index f7e57593..782b569d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -10,10 +10,10 @@ import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np
-from PIL import Image, PngImagePlugin
+from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 777593b0..e145be30 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -41,7 +41,7 @@ class Upscaler:
os.makedirs(self.model_path, exist_ok=True)
try:
- import cv2
+ import cv2 # noqa: F401
self.can_tile = True
except Exception:
pass
diff --git a/pyproject.toml b/pyproject.toml
index 9caa9ba2..0883c127 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,13 +1,14 @@
[tool.ruff]
+target-version = "py310"
+
exclude = ["extensions"]
ignore = [
- "E501",
-
- "F401", # Module imported but unused
+ "E501", # Line too long
+ "E731", # Do not assign a `lambda` expression, use a `def`
]
[tool.ruff.per-file-ignores]
-"webui.py" = ["E402"] # Module level import not at top of file
\ No newline at end of file
+"webui.py" = ["E402"] # Module level import not at top of file
diff --git a/webui.py b/webui.py
index 48277075..5d5e80b5 100644
--- a/webui.py
+++ b/webui.py
@@ -16,12 +16,12 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer = timer.Timer()
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
@@ -31,12 +31,12 @@ startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_queued_call, queue_lock
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
--
cgit v1.2.3
From 8aa87c564a79965013715d56a5f90d2a34d5d6ee Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 23:41:08 +0300
Subject: add UI to edit defaults allow setting defaults for elements in
extensions' tabs fix a problem with ESRGAN upscalers disappearing after UI
reload implicit change: HTML element id for train tab from tab_ti to
tab_train (will this break things?)
---
modules/modelloader.py | 27 +++----
modules/ui.py | 122 +++++------------------------
modules/ui_loadsave.py | 208 +++++++++++++++++++++++++++++++++++++++++++++++++
style.css | 4 +
webui.py | 6 +-
5 files changed, 242 insertions(+), 125 deletions(-)
create mode 100644 modules/ui_loadsave.py
(limited to 'webui.py')
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 25612bf8..2a479bcb 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -116,20 +116,6 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
- builtin_upscaler_classes.clear()
- builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-def forbid_loaded_nonbuiltin_upscalers():
- for cls in Upscaler.__subclasses__():
- if cls not in builtin_upscaler_classes:
- forbidden_upscaler_classes.add(cls)
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -145,10 +131,17 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
- for cls in Upscaler.__subclasses__():
- if cls in forbidden_upscaler_classes:
- continue
+ # some of upscaler classes will not go away after reloading their modules, and we'll end
+ # up with two copies of those classes. The newest copy will always be the last in the list,
+ # so we go from end to beginning and ignore duplicates
+ used_classes = {}
+ for cls in reversed(Upscaler.__subclasses__()):
+ classname = str(cls)
+ if classname not in used_classes:
+ used_classes[classname] = cls
+
+ for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
diff --git a/modules/ui.py b/modules/ui.py
index 7ee99473..1efb656a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -86,16 +86,6 @@ def send_gradio_gallery_to_image(x):
return None
return image_from_url_text(x[0])
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- if isinstance(x, gr.Tabs) and x.elem_id is not None:
- # Tabs element can't have a label, have to use elem_id instead
- func(f"{path}/Tabs@{x.elem_id}", x)
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(f"{path}/{x.label}", x)
-
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
@@ -1471,6 +1461,8 @@ def create_ui():
return res
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+
components = []
component_dict = {}
shared.settings_components = component_dict
@@ -1558,6 +1550,9 @@ def create_ui():
current_row.__exit__()
current_tab.__exit__()
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
@@ -1631,7 +1626,7 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
+ (train_interface, "Train", "train"),
]
interfaces += script_callbacks.ui_tabs_callback()
@@ -1659,6 +1654,16 @@ def create_ui():
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()
+ for interface, _label, ifid in interfaces:
+ if ifid in ["extensions", "settings"]:
+ continue
+
+ loadsave.add_block(interface, ifid)
+
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+ loadsave.setup_ui()
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
@@ -1747,97 +1752,8 @@ def create_ui():
]
)
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = f"{path}/{field}"
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = f"customscript/{obj.custom_script_source}/{key}"
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all(value in x.choices for value in val)
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- def check_tab_id(tab_id):
- tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
- if type(tab_id) == str:
- tab_ids = [t.id for t in tab_items]
- return tab_id in tab_ids
- elif type(tab_id) == int:
- return tab_id >= 0 and tab_id < len(tab_items)
- else:
- return False
-
- if type(x) == gr.Tabs:
- apply_field(x, 'selected', check_tab_id)
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
+ loadsave.dump_defaults()
+ demo.ui_loadsave = loadsave
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
new file mode 100644
index 00000000..728fec9e
--- /dev/null
+++ b/modules/ui_loadsave.py
@@ -0,0 +1,208 @@
+import json
+import os
+
+import gradio as gr
+
+from modules import errors
+from modules.ui_components import ToolButton
+
+
+class UiLoadsave:
+ """allows saving and restorig default values for gradio components"""
+
+ def __init__(self, filename):
+ self.filename = filename
+ self.ui_settings = {}
+ self.component_mapping = {}
+ self.error_loading = False
+ self.finalized_ui = False
+
+ self.ui_defaults_view = None
+ self.ui_defaults_apply = None
+ self.ui_defaults_review = None
+
+ try:
+ if os.path.exists(self.filename):
+ self.ui_settings = self.read_from_file()
+ except Exception as e:
+ self.error_loading = True
+ errors.display(e, "loading settings")
+
+ def add_component(self, path, x):
+ """adds component to the registry of tracked components"""
+
+ assert not self.finalized_ui
+
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = f"{path}/{field}"
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = f"customscript/{obj.custom_script_source}/{key}"
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = self.ui_settings.get(key, None)
+ if saved_value is None:
+ self.ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if field == 'value' and key not in self.component_mapping:
+ self.component_mapping[key] = x
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all(value in x.choices for value in val)
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ def check_tab_id(tab_id):
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+ if type(tab_id) == str:
+ tab_ids = [t.id for t in tab_items]
+ return tab_id in tab_ids
+ elif type(tab_id) == int:
+ return 0 <= tab_id < len(tab_items)
+ else:
+ return False
+
+ if type(x) == gr.Tabs:
+ apply_field(x, 'selected', check_tab_id)
+
+ def add_block(self, x, path=""):
+ """adds all components inside a gradio block x to the registry of tracked components"""
+
+ if hasattr(x, 'children'):
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
+ # Tabs element can't have a label, have to use elem_id instead
+ self.add_component(f"{path}/Tabs@{x.elem_id}", x)
+ for c in x.children:
+ self.add_block(c, path)
+ elif x.label is not None:
+ self.add_component(f"{path}/{x.label}", x)
+
+ def read_from_file(self):
+ with open(self.filename, "r", encoding="utf8") as file:
+ return json.load(file)
+
+ def write_to_file(self, current_ui_settings):
+ with open(self.filename, "w", encoding="utf8") as file:
+ json.dump(current_ui_settings, file, indent=4)
+
+ def dump_defaults(self):
+ """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+
+ if self.error_loading and os.path.exists(self.filename):
+ return
+
+ self.write_to_file(self.ui_settings)
+
+ def iter_changes(self, current_ui_settings, values):
+ """
+ given a dictionary with defaults from a file and current values from gradio elements, returns
+ an iterator over tuples of values that are not the same between the file and the current;
+ tuple contents are: path, old value, new value
+ """
+
+ for (path, component), new_value in zip(self.component_mapping.items(), values):
+ old_value = current_ui_settings.get(path)
+
+ choices = getattr(component, 'choices', None)
+ if isinstance(new_value, int) and choices:
+ if new_value >= len(choices):
+ continue
+
+ new_value = choices[new_value]
+
+ if new_value == old_value:
+ continue
+
+ if old_value is None and new_value == '' or new_value == []:
+ continue
+
+ yield path, old_value, new_value
+
+ def ui_view(self, *values):
+ text = ["
Path | Old value | New value |
"]
+
+ for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
+ if old_value is None:
+ old_value = "None"
+
+ text.append(f"{path} | {old_value} | {new_value} |
")
+
+ if len(text) == 1:
+ text.append("No changes |
")
+
+ text.append("")
+ return "".join(text)
+
+ def ui_apply(self, *values):
+ num_changed = 0
+
+ current_ui_settings = self.read_from_file()
+
+ for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
+ num_changed += 1
+ current_ui_settings[path] = new_value
+
+ if num_changed == 0:
+ return "No changes."
+
+ self.write_to_file(current_ui_settings)
+
+ return f"Wrote {num_changed} changes."
+
+ def create_ui(self):
+ """creates ui elements for editing defaults UI, without adding any logic to them"""
+
+ gr.HTML(
+ f"This page allows you to change default values in UI elements on other tabs.
"
+ f"Make your changes, press 'View changes' to review the changed default values,
"
+ f"then press 'Apply' to write them to {self.filename}.
"
+ f"New defaults will apply after you restart the UI.
"
+ )
+
+ with gr.Row():
+ self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
+ self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
+
+ self.ui_defaults_review = gr.HTML("")
+
+ def setup_ui(self):
+ """adds logic to elements created with create_ui; all add_block class must be made before this"""
+
+ assert not self.finalized_ui
+ self.finalized_ui = True
+
+ self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+ self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
diff --git a/style.css b/style.css
index b823c7dd..4ac919b5 100644
--- a/style.css
+++ b/style.css
@@ -414,6 +414,10 @@ table.settings-value-table td{
max-width: 36em;
}
+.ui-defaults-none{
+ color: #aaa !important;
+}
+
/* live preview */
.progressDiv{
position: relative;
diff --git a/webui.py b/webui.py
index 5d5e80b5..2eecfaa0 100644
--- a/webui.py
+++ b/webui.py
@@ -181,14 +181,11 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
- modelloader.list_builtin_upscalers()
- startup_timer.record("list builtin upscalers")
-
modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers()
- #startup_timer.record("load upscalers") #Is this necessary? I don't know.
+ startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
@@ -388,7 +385,6 @@ def webui():
localization.list_localizations(cmd_opts.localizations_dir)
- modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
startup_timer.record("load scripts")
--
cgit v1.2.3
From 87c3aa7389cea993710c4182f5314e5cea0ad4c6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 11 May 2023 10:09:29 +0300
Subject: return wrap_gradio_gpu_call to webui.py for extensions
---
webui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 2eecfaa0..293a16cc 100644
--- a/webui.py
+++ b/webui.py
@@ -36,7 +36,7 @@ startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
--
cgit v1.2.3
From 85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Thu, 11 May 2023 23:46:45 +0300
Subject: Replace state.need_restart with state.server_command + replace poll
loop with signal
---
modules/shared.py | 42 +++++++++++++++++++++++++++++++++++++++++-
modules/ui.py | 6 +-----
modules/ui_extensions.py | 7 ++-----
webui.py | 39 ++++++++++++++++++++++++---------------
4 files changed, 68 insertions(+), 26 deletions(-)
(limited to 'webui.py')
diff --git a/modules/shared.py b/modules/shared.py
index 3abf71c0..648a2a19 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,6 +2,7 @@ import datetime
import json
import os
import sys
+import threading
import time
import gradio as gr
@@ -110,8 +111,47 @@ class State:
id_live_preview = 0
textinfo = None
time_start = None
- need_restart = False
server_start = None
+ _server_command_signal = threading.Event()
+ _server_command: str | None = None
+
+ @property
+ def need_restart(self) -> bool:
+ # Compatibility getter for need_restart.
+ return self.server_command == "restart"
+
+ @need_restart.setter
+ def need_restart(self, value: bool) -> None:
+ # Compatibility setter for need_restart.
+ if value:
+ self.server_command = "restart"
+
+ @property
+ def server_command(self):
+ return self._server_command
+
+ @server_command.setter
+ def server_command(self, value: str | None) -> None:
+ """
+ Set the server command to `value` and signal that it's been set.
+ """
+ self._server_command = value
+ self._server_command_signal.set()
+
+ def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ """
+ Wait for server command to get set; return and clear the value and signal.
+ """
+ if self._server_command_signal.wait(timeout):
+ self._server_command_signal.clear()
+ req = self._server_command
+ self._server_command = None
+ return req
+ return None
+
+ def request_restart(self) -> None:
+ self.interrupt()
+ self.server_command = True
def skip(self):
self.skipped = True
diff --git a/modules/ui.py b/modules/ui.py
index 8e51e782..bed8464e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1609,12 +1609,8 @@ def create_ui():
outputs=[]
)
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
restart_gradio.click(
- fn=request_restart,
+ fn=shared.state.request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index d7a0f685..4ba3bdd7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
-
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
def save_config_state(name):
@@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
return ""
diff --git a/webui.py b/webui.py
index 293a16cc..39dec3ca 100644
--- a/webui.py
+++ b/webui.py
@@ -234,7 +234,10 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
- signal.signal(signal.SIGINT, sigint_handler)
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app):
@@ -255,19 +258,6 @@ def create_api(app):
return api
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
-
- modules.script_callbacks.app_reload_callback()
- break
-
-
def api_only():
initialize()
@@ -328,6 +318,7 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
+
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -359,8 +350,26 @@ def webui():
redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
- wait_on_server(shared.demo)
+ try:
+ while True:
+ server_command = shared.state.wait_for_server_command(timeout=5)
+ if server_command:
+ if server_command in ("stop", "restart"):
+ break
+ else:
+ print(f"Unknown server command: {server_command}")
+ except KeyboardInterrupt:
+ server_command = "stop"
+
+ if server_command == "stop":
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
+ print('Caught KeyboardInterrupt, stopping...')
+ shared.demo.close()
+ break
print('Restarting UI...')
+ shared.demo.close()
+ time.sleep(0.5)
+ modules.script_callbacks.app_reload_callback()
startup_timer.reset()
--
cgit v1.2.3
From 875990a23213c63c19b8fdd3c87345f7a8ea2ceb Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Tue, 16 May 2023 20:58:35 +0300
Subject: Add option for /_stop route (for graceful shutdown)
---
modules/cmd_args.py | 1 +
webui.py | 13 +++++++++++--
2 files changed, 12 insertions(+), 2 deletions(-)
(limited to 'webui.py')
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index f4a4ab36..6144db5c 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
diff --git a/webui.py b/webui.py
index 39dec3ca..5172f049 100644
--- a/webui.py
+++ b/webui.py
@@ -8,7 +8,7 @@ import warnings
import json
from threading import Thread
-from fastapi import FastAPI
+from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
@@ -270,6 +270,12 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+
+def stop_route(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
+
+
def webui():
launch_api = cmd_opts.api
initialize()
@@ -318,6 +324,8 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
+ if cmd_opts.add_stop_route:
+ app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -359,11 +367,12 @@ def webui():
else:
print(f"Unknown server command: {server_command}")
except KeyboardInterrupt:
+ print('Caught KeyboardInterrupt, stopping...')
server_command = "stop"
if server_command == "stop":
+ print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
- print('Caught KeyboardInterrupt, stopping...')
shared.demo.close()
break
print('Restarting UI...')
--
cgit v1.2.3
From f8ca37b9035dc8cb09e15afc5ade6976b927e923 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 17 May 2023 17:06:45 +0300
Subject: fix inability to run with --freeze-settings
---
webui.py | 38 +++++++++++++++++---------------------
1 file changed, 17 insertions(+), 21 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 293a16cc..2cd551bd 100644
--- a/webui.py
+++ b/webui.py
@@ -144,16 +144,11 @@ Use --skip-version-check commandline argument to disable this check.
""".strip())
-def initialize():
- fix_asyncio_event_loop_policy()
-
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
- startup_timer.record("list extensions")
-
+def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
@@ -166,6 +161,18 @@ def initialize():
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
+
+def initialize():
+ fix_asyncio_event_loop_policy()
+
+ check_versions()
+
+ extensions.list_extensions()
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list extensions")
+
+ restore_config_state_file()
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -370,18 +377,7 @@ def webui():
extensions.list_extensions()
startup_timer.record("list extensions")
- config_state_file = shared.opts.restore_config_state_file
- shared.opts.restore_config_state_file = ""
- shared.opts.save(shared.config_filename)
-
- if os.path.isfile(config_state_file):
- print(f"*** About to restore extension state from file: {config_state_file}")
- with open(config_state_file, "r", encoding="utf-8") as f:
- config_state = json.load(f)
- config_states.restore_extension_config(config_state)
- startup_timer.record("restore extension config")
- elif config_state_file:
- print(f"!!! Config state backup not found: {config_state_file}")
+ restore_config_state_file()
localization.list_localizations(cmd_opts.localizations_dir)
--
cgit v1.2.3
From ae252cd5bc6daa8295ed1ded8ca101812d0df43b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 18 May 2023 10:36:11 +0300
Subject: add --gradio-allowed-path commandline option
---
modules/cmd_args.py | 1 +
webui.py | 3 ++-
2 files changed, 3 insertions(+), 1 deletion(-)
(limited to 'webui.py')
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index a533a454..7bde161e 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -77,6 +77,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
+parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
diff --git a/webui.py b/webui.py
index cebfba96..b4a21e73 100644
--- a/webui.py
+++ b/webui.py
@@ -329,7 +329,8 @@ def webui():
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
inbrowser=cmd_opts.autolaunch,
- prevent_thread_lock=True
+ prevent_thread_lock=True,
+ allowed_paths=cmd_opts.gradio_allowed_path,
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
--
cgit v1.2.3
From 2582a0fd3b3e91c5fba9e5e561cbdf5fee835063 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 18 May 2023 22:48:28 +0300
Subject: make it possible for scripts to add cross attention optimizations add
UI selection for cross attention optimization
---
modules/cmd_args.py | 14 ++--
modules/script_callbacks.py | 21 ++++++
modules/sd_hijack.py | 90 ++++++++++++++-----------
modules/sd_hijack_optimizations.py | 135 ++++++++++++++++++++++++++++++++++++-
modules/shared.py | 1 +
modules/shared_items.py | 8 +++
webui.py | 10 +++
7 files changed, 228 insertions(+), 51 deletions(-)
(limited to 'webui.py')
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 7bde161e..85db93f3 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
-parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 3c21a362..40f388a5 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -110,6 +110,7 @@ callback_map = dict(
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_on_reload=[],
+ callbacks_list_optimizers=[],
)
@@ -258,6 +259,18 @@ def before_ui_callback():
report_exception(c, 'before_ui')
+def list_optimizers_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_optimizers']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_optimizers')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -409,3 +422,11 @@ def on_before_ui(callback):
"""register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback)
+
+
+def on_list_optimizers(callback):
+ """register a function to be called when UI is making a list of cross attention optimization options.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
+ to it."""
+
+ add_callback(callback_map['callbacks_list_optimizers'], callback)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 14e7f799..39193be8 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,8 +3,9 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
+from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
+optimizers = []
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
+
+
+def list_optimizers():
+ new_optimizers = script_callbacks.list_optimizers_callback()
+
+ new_optimizers = [x for x in new_optimizers if x.is_available()]
+
+ new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True)
+
+ optimizers.clear()
+ optimizers.extend(new_optimizers)
+
def apply_optimizations():
+ global current_optimizer
+
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
- optimization_method = None
+ if current_optimizer is not None:
+ current_optimizer.undo()
+ current_optimizer = None
+
+ selection = shared.opts.cross_attention_optimization
+ if selection == "Automatic" and len(optimizers) > 0:
+ matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
+ else:
+ matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
- print("Applying xformers cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
- optimization_method = 'xformers'
- elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
- optimization_method = 'sdp-no-mem'
- elif cmd_opts.opt_sdp_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
- optimization_method = 'sdp'
- elif cmd_opts.opt_sub_quad_attention:
- print("Applying sub-quadratic cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
- optimization_method = 'sub-quadratic'
- elif cmd_opts.opt_split_attention_v1:
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization (Doggettx).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- optimization_method = 'Doggettx'
-
- return optimization_method
+ if selection == "None":
+ matching_optimizer = None
+ elif matching_optimizer is None:
+ matching_optimizer = optimizers[0]
+
+ if matching_optimizer is not None:
+ print(f"Applying optimization: {matching_optimizer.name}")
+ matching_optimizer.apply()
+ current_optimizer = matching_optimizer
+ return current_optimizer.name
+ else:
+ return ''
def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
@@ -169,7 +169,11 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
- self.optimization_method = apply_optimizations()
+ try:
+ self.optimization_method = apply_optimizations()
+ except Exception as e:
+ errors.display(e, "applying cross attention optimization")
+ undo_optimizations()
self.clip = m.cond_stage_model
@@ -223,6 +227,10 @@ class StableDiffusionModelHijack:
return token_count, self.clip.get_target_prompt_token_count(token_count)
+ def redo_hijack(self, m):
+ self.undo_hijack(m)
+ self.hijack(m)
+
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f00fe55c..1c5b709b 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,10 +9,139 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+ def __init__(self, name, label=None, cmd_opt=None):
+ self.name = name
+ self.label = label
+ self.cmd_opt = cmd_opt
+
+ def title(self):
+ if self.label is None:
+ return self.name
+
+ return f"{self.name} - {self.label}"
+
+ def is_available(self):
+ return True
+
+ def priority(self):
+ return 0
+
+ def apply(self):
+ pass
+
+ def undo(self):
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+ def __init__(self):
+ super().__init__("xformers", cmd_opt="xformers")
+
+ def is_available(self):
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+ def priority(self):
+ return 100
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+ def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
+ super().__init__(name, label, cmd_opt)
+
+ def is_available(self):
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+ def priority(self):
+ return 90
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+ def __init__(self):
+ super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
+
+ def priority(self):
+ return 80
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+ def __init__(self):
+ super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
+
+ def priority(self):
+ return 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+ def __init__(self):
+ super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
+
+ def priority(self):
+ return 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+ def __init__(self):
+ super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
+
+ def priority(self):
+ return 1000 if not torch.cuda.is_available() else 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+ def __init__(self):
+ super().__init__("Doggettx", cmd_opt="opt_split_attention")
+
+ def priority(self):
+ return 20
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+ res.extend([
+ SdOptimizationXformers(),
+ SdOptimizationSdpNoMem(),
+ SdOptimizationSdp(),
+ SdOptimizationSubQuad(),
+ SdOptimizationV1(),
+ SdOptimizationInvokeAI(),
+ SdOptimizationDoggettx(),
+ ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
+ return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,
diff --git a/modules/shared.py b/modules/shared.py
index fdbab5c4..7cfbaa0c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -417,6 +417,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
+ "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index e792a134..2a8713c8 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -21,3 +21,11 @@ def refresh_vae_list():
import modules.sd_vae
modules.sd_vae.refresh_vae_list()
+
+
+def cross_attention_optimizations():
+ import modules.sd_hijack
+
+ return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+
+
diff --git a/webui.py b/webui.py
index b4a21e73..afe3c5fa 100644
--- a/webui.py
+++ b/webui.py
@@ -52,6 +52,7 @@ import modules.img2img
import modules.lowvram
import modules.scripts
import modules.sd_hijack
+import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.txt2img
@@ -200,6 +201,10 @@ def initialize():
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
+ modules.sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
+
# load model in parallel to other startup stuff
Thread(target=lambda: shared.sd_model).start()
@@ -208,6 +213,7 @@ def initialize():
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
@@ -428,6 +434,10 @@ def webui():
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("initialize extra networks")
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
+ modules.sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
+
if __name__ == "__main__":
if cmd_opts.nowebui:
--
cgit v1.2.3
From de3abc29ae91d002789c49c836df9c8d8b35cada Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 15:27:23 +0300
Subject: Fix typo "intialize"
---
modules/ui_extra_networks.py | 2 +-
webui.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
(limited to 'webui.py')
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 80cfa841..24eeef0e 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -232,7 +232,7 @@ class ExtraNetworksPage:
return None
-def intialize():
+def initialize():
extra_pages.clear()
diff --git a/webui.py b/webui.py
index b4a21e73..30e4f239 100644
--- a/webui.py
+++ b/webui.py
@@ -213,7 +213,7 @@ def initialize():
shared.reload_hypernetworks()
startup_timer.record("reload hypernets")
- ui_extra_networks.intialize()
+ ui_extra_networks.initialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
@@ -419,7 +419,7 @@ def webui():
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
- ui_extra_networks.intialize()
+ ui_extra_networks.initialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
--
cgit v1.2.3
From 21ee46eea791d83b3b49cedd2306c7f0f1807250 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 15:35:16 +0300
Subject: Deduplicate default extra network registration
---
modules/extra_networks.py | 5 +++++
modules/ui_extra_networks.py | 9 +++++++++
webui.py | 16 ++++++----------
3 files changed, 20 insertions(+), 10 deletions(-)
(limited to 'webui.py')
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index f9db41bc..94347275 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -14,6 +14,11 @@ def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_default_extra_networks():
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
+
+
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 24eeef0e..19fbaae5 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -236,6 +236,15 @@ def initialize():
extra_pages.clear()
+def register_default_pages():
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
+ register_page(ExtraNetworksPageTextualInversion())
+ register_page(ExtraNetworksPageHypernetworks())
+ register_page(ExtraNetworksPageCheckpoints())
+
+
class ExtraNetworksUi:
def __init__(self):
self.pages = None
diff --git a/webui.py b/webui.py
index 30e4f239..ad6be239 100644
--- a/webui.py
+++ b/webui.py
@@ -34,8 +34,7 @@ startup_timer.record("import gradio")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
-from modules import extra_networks, ui_extra_networks_checkpoints
-from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
+from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
@@ -214,12 +213,11 @@ def initialize():
startup_timer.record("reload hypernets")
ui_extra_networks.initialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+ ui_extra_networks.register_default_pages()
extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ extra_networks.register_default_extra_networks()
+
startup_timer.record("extra networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
@@ -420,12 +418,10 @@ def webui():
startup_timer.record("reload hypernetworks")
ui_extra_networks.initialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+ ui_extra_networks.register_default_pages()
extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ extra_networks.register_default_extra_networks()
startup_timer.record("initialize extra networks")
--
cgit v1.2.3
From a0005121aee9db3b65e55891b4490ed3555b4b09 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 15:37:13 +0300
Subject: Simplify CORS middleware configuration
---
webui.py | 24 ++++++++++++++++--------
1 file changed, 16 insertions(+), 8 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index ad6be239..198f4f1a 100644
--- a/webui.py
+++ b/webui.py
@@ -246,15 +246,23 @@ def initialize():
def setup_middleware(app):
- app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
- if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- app.build_middleware_stack() # rebuild middleware stack on-the-fly
+ configure_cors_middleware(app)
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
+
+
+def configure_cors_middleware(app):
+ cors_options = {
+ "allow_methods": ["*"],
+ "allow_headers": ["*"],
+ "allow_credentials": True,
+ }
+ if cmd_opts.cors_allow_origins:
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
+ if cmd_opts.cors_allow_origins_regex:
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
+ app.add_middleware(CORSMiddleware, **cors_options)
def create_api(app):
--
cgit v1.2.3
From 1482c89376037896da1873712bae4b4795cc7b4b Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 15:52:29 +0300
Subject: Refactor validate_tls_options out, fix typo (keyfile was there twice)
---
webui.py | 33 ++++++++++++++++++---------------
1 file changed, 18 insertions(+), 15 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 198f4f1a..2110b31b 100644
--- a/webui.py
+++ b/webui.py
@@ -161,9 +161,26 @@ def restore_config_state_file():
print(f"!!! Config state backup not found: {config_state_file}")
+def validate_tls_options():
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
+ return
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+ startup_timer.record("TLS")
+
+
def initialize():
fix_asyncio_event_loop_policy()
-
+ validate_tls_options()
check_versions()
extensions.list_extensions()
@@ -220,20 +237,6 @@ def initialize():
startup_timer.record("extra networks")
- if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
-
- try:
- if not os.path.exists(cmd_opts.tls_keyfile):
- print("Invalid path to TLS keyfile given")
- if not os.path.exists(cmd_opts.tls_certfile):
- print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
- except TypeError:
- cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
- print("TLS setup invalid, running webui without TLS")
- else:
- print("Running with TLS")
- startup_timer.record("TLS")
-
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
--
cgit v1.2.3
From 8200e0c27bb4b7f00b30f5b96186413b752d1f84 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 15:54:47 +0300
Subject: Refactor configure_sigint_handler out
---
webui.py | 22 +++++++++++++---------
1 file changed, 13 insertions(+), 9 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 2110b31b..ace38123 100644
--- a/webui.py
+++ b/webui.py
@@ -178,9 +178,22 @@ def validate_tls_options():
startup_timer.record("TLS")
+def configure_sigint_handler():
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
+
+
def initialize():
fix_asyncio_event_loop_policy()
validate_tls_options()
+ configure_sigint_handler()
check_versions()
extensions.list_extensions()
@@ -237,15 +250,6 @@ def initialize():
startup_timer.record("extra networks")
- # make the program just exit at ctrl+c without waiting for anything
- def sigint_handler(sig, frame):
- print(f'Interrupted with signal {sig} in {frame}')
- os._exit(0)
-
- if not os.environ.get("COVERAGE_RUN"):
- # Don't install the immediate-quit handler when running under coverage,
- # as then the coverage report won't be generated.
- signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app):
--
cgit v1.2.3
From 8a178e67172f4677cad747b3364db3f2a0636911 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 16:00:53 +0300
Subject: Refactor configure opts_onchange out
---
webui.py | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index ace38123..53fe260e 100644
--- a/webui.py
+++ b/webui.py
@@ -190,6 +190,15 @@ def configure_sigint_handler():
signal.signal(signal.SIGINT, sigint_handler)
+def configure_opts_onchange():
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ startup_timer.record("opts onchange")
+
+
def initialize():
fix_asyncio_event_loop_policy()
validate_tls_options()
@@ -232,13 +241,6 @@ def initialize():
# load model in parallel to other startup stuff
Thread(target=lambda: shared.sd_model).start()
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
- shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
- shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
- startup_timer.record("opts onchange")
-
shared.reload_hypernetworks()
startup_timer.record("reload hypernets")
--
cgit v1.2.3
From 674e80c6255655b9163477913742b36cc5e05003 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 17:22:32 +0300
Subject: Note pending PR for app_kwargs
---
webui.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index 53fe260e..e568ef42 100644
--- a/webui.py
+++ b/webui.py
@@ -326,6 +326,7 @@ def webui():
# this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'):
+ # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
def fastapi_setup(self):
self.docs_url = "/docs"
self.redoc_url = "/redoc"
--
cgit v1.2.3
From 0f28aee9cd12b8294df80506e6466cd90a9ae195 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 17:28:41 +0300
Subject: Refactor gradio auth
---
webui.py | 37 +++++++++++++++++++++++++++++--------
1 file changed, 29 insertions(+), 8 deletions(-)
(limited to 'webui.py')
diff --git a/webui.py b/webui.py
index e568ef42..64b113dd 100644
--- a/webui.py
+++ b/webui.py
@@ -7,6 +7,7 @@ import re
import warnings
import json
from threading import Thread
+from typing import Iterable
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
@@ -178,6 +179,32 @@ def validate_tls_options():
startup_timer.record("TLS")
+def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
+ """
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
+ an iterable of (username, password) tuples.
+ """
+ def process_credential_line(s) -> tuple[str, ...] | None:
+ s = s.strip()
+ if not s:
+ return None
+ return tuple(s.split(':', 1))
+
+ if cmd_opts.gradio_auth:
+ for cred in cmd_opts.gradio_auth.split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+ if cmd_opts.gradio_auth_path:
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ for cred in line.strip().split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
@@ -316,13 +343,7 @@ def webui():
if not cmd_opts.no_gradio_queue:
shared.demo.queue(64)
- gradio_auth_creds = []
- if cmd_opts.gradio_auth:
- gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
- if cmd_opts.gradio_auth_path:
- with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
- for line in file.readlines():
- gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ gradio_auth_creds = list(get_gradio_auth_creds()) or None
# this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'):
@@ -343,7 +364,7 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
- auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
+ auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
--
cgit v1.2.3
From 71f4a4afdfe2da8cbf23a74b82c32b4d113d996e Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 19 May 2023 16:08:40 +0300
Subject: Deduplicate webui.py initial-load/reload code
---
modules/sd_models.py | 1 -
webui.py | 84 +++++++++++++++++++++-------------------------------
2 files changed, 34 insertions(+), 51 deletions(-)
(limited to 'webui.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8e42bfea..b1afbaa7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -98,7 +98,6 @@ def setup_model():
if not os.path.exists(model_path):
os.makedirs(model_path)
- list_models()
enable_midas_autodownload()
diff --git a/webui.py b/webui.py
index 64b113dd..5c89a3b8 100644
--- a/webui.py
+++ b/webui.py
@@ -15,6 +15,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors # noqa: F401
@@ -231,9 +232,27 @@ def initialize():
validate_tls_options()
configure_sigint_handler()
check_versions()
+ modelloader.cleanup_models()
+ configure_opts_onchange()
+
+ modules.sd_models.setup_model()
+ startup_timer.record("setup SD model")
+
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
+ startup_timer.record("setup codeformer")
+
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+ startup_timer.record("setup gfpgan")
+
+ initialize_rest(reload_script_modules=False)
+
+def initialize_rest(*, reload_script_modules=False):
+ """
+ Called both from initialize() and when reloading the webui.
+ """
+ sd_samplers.set_samplers()
extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
restore_config_state_file()
@@ -243,42 +262,40 @@ def initialize():
modules.scripts.load_scripts()
return
- modelloader.cleanup_models()
- modules.sd_models.setup_model()
+ modules.sd_models.list_models()
startup_timer.record("list SD models")
- codeformer.setup_model(cmd_opts.codeformer_models_path)
- startup_timer.record("setup codeformer")
-
- gfpgan.setup_model(cmd_opts.gfpgan_models_path)
- startup_timer.record("setup gfpgan")
+ localization.list_localizations(cmd_opts.localizations_dir)
modules.scripts.load_scripts()
startup_timer.record("load scripts")
+ if reload_script_modules:
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+ startup_timer.record("reload script modules")
+
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
-
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
# load model in parallel to other startup stuff
+ # (when reloading, this does nothing)
Thread(target=lambda: shared.sd_model).start()
shared.reload_hypernetworks()
- startup_timer.record("reload hypernets")
+ startup_timer.record("reload hypernetworks")
ui_extra_networks.initialize()
ui_extra_networks.register_default_pages()
extra_networks.initialize()
extra_networks.register_default_extra_networks()
-
- startup_timer.record("extra networks")
-
+ startup_timer.record("initialize extra networks")
def setup_middleware(app):
@@ -423,45 +440,12 @@ def webui():
print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
- modules.script_callbacks.app_reload_callback()
-
startup_timer.reset()
-
- sd_samplers.set_samplers()
-
+ modules.script_callbacks.app_reload_callback()
+ startup_timer.record("app reload callback")
modules.script_callbacks.script_unloaded_callback()
- extensions.list_extensions()
- startup_timer.record("list extensions")
-
- restore_config_state_file()
-
- localization.list_localizations(cmd_opts.localizations_dir)
-
- modules.scripts.reload_scripts()
- startup_timer.record("load scripts")
-
- modules.script_callbacks.model_loaded_callback(shared.sd_model)
- startup_timer.record("model loaded callback")
-
- modelloader.load_upscalers()
- startup_timer.record("load upscalers")
-
- for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
- importlib.reload(module)
- startup_timer.record("reload script modules")
-
- modules.sd_models.list_models()
- startup_timer.record("list SD models")
-
- shared.reload_hypernetworks()
- startup_timer.record("reload hypernetworks")
-
- ui_extra_networks.initialize()
- ui_extra_networks.register_default_pages()
-
- extra_networks.initialize()
- extra_networks.register_default_extra_networks()
- startup_timer.record("initialize extra networks")
+ startup_timer.record("scripts unloaded callback")
+ initialize_rest(reload_script_modules=True)
if __name__ == "__main__":
--
cgit v1.2.3
From df004be2fc4b2c68adfb75565d97551a1a5e7ed6 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Sun, 21 May 2023 00:26:16 +0300
Subject: Add a couple `from __future__ import annotations`es for Py3.9 compat
---
modules/sd_hijack_optimizations.py | 1 +
webui.py | 2 ++
2 files changed, 3 insertions(+)
(limited to 'webui.py')
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0eb4c525..2ec0b049 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import math
import sys
import traceback
diff --git a/webui.py b/webui.py
index a76e377c..d4402f55 100644
--- a/webui.py
+++ b/webui.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import sys
import time
--
cgit v1.2.3