From 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:28:15 +0300 Subject: Autofix Ruff W (not W605) (mostly whitespace) --- modules/script_callbacks.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 17109732..7d9dd736 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -32,22 +32,22 @@ class CFGDenoiserParams: def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x """Latent image representation in the process of being denoised""" - + self.image_cond = image_cond """Conditioning image""" - + self.sigma = sigma """Current sigma noise step value""" - + self.sampling_step = sampling_step """Current Sampling step number""" - + self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - + self.text_cond = text_cond """ Encoder hidden states of text conditioning from prompt""" - + self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" @@ -240,7 +240,7 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) - + def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' -- cgit v1.2.3 From 3078001439d25b66ef5627c9e3d431aa23bbed73 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 14 May 2023 01:49:41 +0000 Subject: Add/modify CFG callbacks Required by self-attn guidance extension https://github.com/ashen-sensored/sd_webui_SAG --- modules/script_callbacks.py | 35 +++++++++++++++++++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 8 +++++++- 2 files changed, 42 insertions(+), 1 deletion(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 7d9dd736..e83c6ecf 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -53,6 +53,21 @@ class CFGDenoiserParams: class CFGDenoisedParams: + def __init__(self, x, sampling_step, total_sampling_steps, inner_model): + self.x = x + """Latent image representation in the process of being denoised""" + + self.sampling_step = sampling_step + """Current Sampling step number""" + + self.total_sampling_steps = total_sampling_steps + """Total number of sampling steps planned""" + + self.inner_model = inner_model + """Inner model reference that is being used for denoising""" + + +class AfterCFGCallbackParams: def __init__(self, x, sampling_step, total_sampling_steps): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +78,9 @@ class CFGDenoisedParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" + self.output_altered = False + """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" + class UiTrainTabParams: def __init__(self, txt2img_preview_params): @@ -87,6 +105,7 @@ callback_map = dict( callbacks_image_saved=[], callbacks_cfg_denoiser=[], callbacks_cfg_denoised=[], + callbacks_cfg_after_cfg=[], callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], @@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): report_exception(c, 'cfg_denoised_callback') +def cfg_after_cfg_callback(params: AfterCFGCallbackParams): + for c in callback_map['callbacks_cfg_after_cfg']: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_after_cfg_callback') + + def before_component_callback(component, **kwargs): for c in callback_map['callbacks_before_component']: try: @@ -332,6 +359,14 @@ def on_cfg_denoised(callback): add_callback(callback_map['callbacks_cfg_denoised'], callback) +def on_cfg_after_cfg(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed. + The callback is called with one argument: + - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + + def on_before_component(callback): """register a function to be called before a component is created. The callback is called with arguments: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e9e41818..55f0d3a3 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -8,6 +8,7 @@ from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), @@ -160,7 +161,7 @@ class CFGDenoiser(torch.nn.Module): fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) cfg_denoised_callback(denoised_params) devices.test_for_nans(x_out, "unet") @@ -180,6 +181,11 @@ class CFGDenoiser(torch.nn.Module): if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + if after_cfg_callback_params.output_altered: + denoised = after_cfg_callback_params.x + self.step += 1 return denoised -- cgit v1.2.3 From 8abfc95013d247c8a863d048574bc1f9d1eb0443 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 14 May 2023 12:56:34 +0800 Subject: Update script_callbacks.py --- modules/script_callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index e83c6ecf..57dfd457 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -64,7 +64,7 @@ class CFGDenoisedParams: """Total number of sampling steps planned""" self.inner_model = inner_model - """Inner model reference that is being used for denoising""" + """Inner model reference used for denoising""" class AfterCFGCallbackParams: @@ -79,7 +79,7 @@ class AfterCFGCallbackParams: """Total number of sampling steps planned""" self.output_altered = False - """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" + """A flag for CFGDenoiser indicating whether the output has been altered by the callback""" class UiTrainTabParams: @@ -360,9 +360,9 @@ def on_cfg_denoised(callback): def on_cfg_after_cfg(callback): - """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed. + """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. The callback is called with one argument: - - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. """ add_callback(callback_map['callbacks_cfg_after_cfg'], callback) -- cgit v1.2.3 From 005849331e82cded96f6f3e5ff828037c672c38d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 14 May 2023 08:15:22 +0300 Subject: remove output_altered flag from AfterCFGCallbackParams --- modules/script_callbacks.py | 3 --- modules/sd_samplers_kdiffusion.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 57dfd457..3c21a362 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -78,9 +78,6 @@ class AfterCFGCallbackParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - self.output_altered = False - """A flag for CFGDenoiser indicating whether the output has been altered by the callback""" - class UiTrainTabParams: def __init__(self, txt2img_preview_params): diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 55f0d3a3..61f23ad7 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -183,8 +183,7 @@ class CFGDenoiser(torch.nn.Module): after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) cfg_after_cfg_callback(after_cfg_callback_params) - if after_cfg_callback_params.output_altered: - denoised = after_cfg_callback_params.x + denoised = after_cfg_callback_params.x self.step += 1 return denoised -- cgit v1.2.3 From 2582a0fd3b3e91c5fba9e5e561cbdf5fee835063 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 18 May 2023 22:48:28 +0300 Subject: make it possible for scripts to add cross attention optimizations add UI selection for cross attention optimization --- modules/cmd_args.py | 14 ++-- modules/script_callbacks.py | 21 ++++++ modules/sd_hijack.py | 90 ++++++++++++++----------- modules/sd_hijack_optimizations.py | 135 ++++++++++++++++++++++++++++++++++++- modules/shared.py | 1 + modules/shared_items.py | 8 +++ webui.py | 10 +++ 7 files changed, 228 insertions(+), 51 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 7bde161e..85db93f3 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") -parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) -parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") -parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") -parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization") +parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization") +parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 3c21a362..40f388a5 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -110,6 +110,7 @@ callback_map = dict( callbacks_script_unloaded=[], callbacks_before_ui=[], callbacks_on_reload=[], + callbacks_list_optimizers=[], ) @@ -258,6 +259,18 @@ def before_ui_callback(): report_exception(c, 'before_ui') +def list_optimizers_callback(): + res = [] + + for c in callback_map['callbacks_list_optimizers']: + try: + c.callback(res) + except Exception: + report_exception(c, 'list_optimizers') + + return res + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -409,3 +422,11 @@ def on_before_ui(callback): """register a function to be called before the UI is created.""" add_callback(callback_map['callbacks_before_ui'], callback) + + +def on_list_optimizers(callback): + """register a function to be called when UI is making a list of cross attention optimization options. + The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization + to it.""" + + add_callback(callback_map['callbacks_list_optimizers'], callback) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 14e7f799..39193be8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,8 +3,9 @@ from torch.nn.functional import silu from types import MethodType import modules.textual_inversion.textual_inversion -from modules import devices, sd_hijack_optimizations, shared +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules.hypernetworks import hypernetwork +from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] ldm.modules.attention.print = lambda *args: None ldm.modules.diffusionmodules.model.print = lambda *args: None +optimizers = [] +current_optimizer: sd_hijack_optimizations.SdOptimization = None + + +def list_optimizers(): + new_optimizers = script_callbacks.list_optimizers_callback() + + new_optimizers = [x for x in new_optimizers if x.is_available()] + + new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True) + + optimizers.clear() + optimizers.extend(new_optimizers) + def apply_optimizations(): + global current_optimizer + undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th - optimization_method = None + if current_optimizer is not None: + current_optimizer.undo() + current_optimizer = None + + selection = shared.opts.cross_attention_optimization + if selection == "Automatic" and len(optimizers) > 0: + matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) + else: + matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None) - can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): - print("Applying xformers cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward - optimization_method = 'xformers' - elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: - print("Applying scaled dot product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward - optimization_method = 'sdp-no-mem' - elif cmd_opts.opt_sdp_attention and can_use_sdp: - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward - optimization_method = 'sdp' - elif cmd_opts.opt_sub_quad_attention: - print("Applying sub-quadratic cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward - optimization_method = 'sub-quadratic' - elif cmd_opts.opt_split_attention_v1: - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI - optimization_method = 'InvokeAI' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization (Doggettx).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - optimization_method = 'Doggettx' - - return optimization_method + if selection == "None": + matching_optimizer = None + elif matching_optimizer is None: + matching_optimizer = optimizers[0] + + if matching_optimizer is not None: + print(f"Applying optimization: {matching_optimizer.name}") + matching_optimizer.apply() + current_optimizer = matching_optimizer + return current_optimizer.name + else: + return '' def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward @@ -169,7 +169,11 @@ class StableDiffusionModelHijack: if m.cond_stage_key == "edit": sd_hijack_unet.hijack_ddpm_edit() - self.optimization_method = apply_optimizations() + try: + self.optimization_method = apply_optimizations() + except Exception as e: + errors.display(e, "applying cross attention optimization") + undo_optimizations() self.clip = m.cond_stage_model @@ -223,6 +227,10 @@ class StableDiffusionModelHijack: return token_count, self.clip.get_target_prompt_token_count(token_count) + def redo_hijack(self, m): + self.undo_hijack(m) + self.hijack(m) + class EmbeddingsWithFixes(torch.nn.Module): def __init__(self, wrapped, embeddings): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f00fe55c..1c5b709b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,10 +9,139 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors, devices +from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks from modules.hypernetworks import hypernetwork -from .sub_quadratic_attention import efficient_dot_product_attention +import ldm.modules.attention +import ldm.modules.diffusionmodules.model + +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + + +class SdOptimization: + def __init__(self, name, label=None, cmd_opt=None): + self.name = name + self.label = label + self.cmd_opt = cmd_opt + + def title(self): + if self.label is None: + return self.name + + return f"{self.name} - {self.label}" + + def is_available(self): + return True + + def priority(self): + return 0 + + def apply(self): + pass + + def undo(self): + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + + +class SdOptimizationXformers(SdOptimization): + def __init__(self): + super().__init__("xformers", cmd_opt="xformers") + + def is_available(self): + return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) + + def priority(self): + return 100 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + + +class SdOptimizationSdpNoMem(SdOptimization): + def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"): + super().__init__(name, label, cmd_opt) + + def is_available(self): + return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) + + def priority(self): + return 90 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + + +class SdOptimizationSdp(SdOptimizationSdpNoMem): + def __init__(self): + super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention") + + def priority(self): + return 80 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + + +class SdOptimizationSubQuad(SdOptimization): + def __init__(self): + super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention") + + def priority(self): + return 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + + +class SdOptimizationV1(SdOptimization): + def __init__(self): + super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1") + + def priority(self): + return 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + + +class SdOptimizationInvokeAI(SdOptimization): + def __init__(self): + super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai") + + def priority(self): + return 1000 if not torch.cuda.is_available() else 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + + +class SdOptimizationDoggettx(SdOptimization): + def __init__(self): + super().__init__("Doggettx", cmd_opt="opt_split_attention") + + def priority(self): + return 20 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + + +def list_optimizers(res): + res.extend([ + SdOptimizationXformers(), + SdOptimizationSdpNoMem(), + SdOptimizationSdp(), + SdOptimizationSubQuad(), + SdOptimizationV1(), + SdOptimizationInvokeAI(), + SdOptimizationDoggettx(), + ]) if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: @@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ kv_chunk_size = k_tokens with devices.without_autocast(disable=q.dtype == v.dtype): - return efficient_dot_product_attention( + return sub_quadratic_attention.efficient_dot_product_attention( q, k, v, diff --git a/modules/shared.py b/modules/shared.py index fdbab5c4..7cfbaa0c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -417,6 +417,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { })) options_templates.update(options_section(('optimizations', "Optimizations"), { + "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), diff --git a/modules/shared_items.py b/modules/shared_items.py index e792a134..2a8713c8 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -21,3 +21,11 @@ def refresh_vae_list(): import modules.sd_vae modules.sd_vae.refresh_vae_list() + + +def cross_attention_optimizations(): + import modules.sd_hijack + + return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] + + diff --git a/webui.py b/webui.py index b4a21e73..afe3c5fa 100644 --- a/webui.py +++ b/webui.py @@ -52,6 +52,7 @@ import modules.img2img import modules.lowvram import modules.scripts import modules.sd_hijack +import modules.sd_hijack_optimizations import modules.sd_models import modules.sd_vae import modules.txt2img @@ -200,6 +201,10 @@ def initialize(): modules.textual_inversion.textual_inversion.list_textual_inversion_templates() startup_timer.record("refresh textual inversion templates") + modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) + modules.sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + # load model in parallel to other startup stuff Thread(target=lambda: shared.sd_model).start() @@ -208,6 +213,7 @@ def initialize(): shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) startup_timer.record("opts onchange") shared.reload_hypernetworks() @@ -428,6 +434,10 @@ def webui(): extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) startup_timer.record("initialize extra networks") + modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) + modules.sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From 0cc05fc492a9360d3b2f1b3f64c7d74f9041f74e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 21 May 2023 00:41:41 +0300 Subject: work on startup profile display --- html/footer.html | 2 + javascript/profilerVisualization.js | 91 +++++++++++++++++++++++++++++++++++++ javascript/ui_settings_hints.js | 2 +- modules/script_callbacks.py | 3 ++ modules/scripts.py | 3 +- modules/timer.py | 46 +++++++++++++++++-- modules/ui.py | 4 +- style.css | 8 +++- webui.py | 14 ++++-- 9 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 javascript/profilerVisualization.js (limited to 'modules/script_callbacks.py') diff --git a/html/footer.html b/html/footer.html index bad87ff6..1ce13295 100644 --- a/html/footer.html +++ b/html/footer.html @@ -5,6 +5,8 @@  •  Gradio  •  + Startup profile +  •  Reload UI
diff --git a/javascript/profilerVisualization.js b/javascript/profilerVisualization.js new file mode 100644 index 00000000..1bd75986 --- /dev/null +++ b/javascript/profilerVisualization.js @@ -0,0 +1,91 @@ + +function createRow(table, cellName, items) { + var tr = document.createElement('tr'); + var res = []; + + items.forEach(function(x) { + var td = document.createElement(cellName); + td.textContent = x; + tr.appendChild(td); + res.push(td); + }); + + table.appendChild(tr); + + return res; +} + +function showProfile(path, cutoff = 0.0005) { + requestGet(path, {}, function(data) { + var table = document.createElement('table'); + table.className = 'popup-table'; + + data.records['total'] = data.total; + var keys = Object.keys(data.records).sort(function(a, b) { + return data.records[b] - data.records[a]; + }); + var items = keys.map(function(x) { + return {key: x, parts: x.split('/'), time: data.records[x]}; + }); + var maxLength = items.reduce(function(a, b) { + return Math.max(a, b.parts.length); + }, 0); + + var cols = createRow(table, 'th', ['record', 'seconds']); + cols[0].colSpan = maxLength; + + function arraysEqual(a, b) { + return !(a < b || b < a); + } + + var addLevel = function(level, parent) { + var matching = items.filter(function(x) { + return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent); + }); + var sorted = matching.sort(function(a, b) { + return b.time - a.time; + }); + var othersTime = 0; + var othersList = []; + sorted.forEach(function(x) { + if (x.time < cutoff) { + othersTime += x.time; + othersList.push(x.parts[level]); + return; + } + + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(x.parts[i]); + } + cells.push(x.time.toFixed(3)); + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } + + addLevel(level + 1, parent.concat([x.parts[level]])); + }); + + if (othersTime > 0) { + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(parent[i]); + } + cells.push(othersTime.toFixed(3)); + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } + + cols[level].textContent = 'others'; + cols[level].title = othersList.join(", "); + } + }; + + addLevel(0, []); + + popup(table); + }); +} + diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js index e216852b..d088f949 100644 --- a/javascript/ui_settings_hints.js +++ b/javascript/ui_settings_hints.js @@ -42,7 +42,7 @@ onOptionsChanged(function() { function settingsHintsShowQuicksettings() { requestGet("./internal/quicksettings-hint", {}, function(data) { var table = document.createElement('table'); - table.className = 'settings-value-table'; + table.className = 'popup-table'; data.forEach(function(obj) { var tr = document.createElement('tr'); diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 40f388a5..ecffc206 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -7,6 +7,8 @@ from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules import timer + def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) @@ -123,6 +125,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: c.callback(demo, app) + timer.startup_timer.record(c.script) except Exception: report_exception(c, 'app_started_callback') diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..7ef1a8f8 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,7 +6,7 @@ from collections import namedtuple import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, timer AlwaysVisible = object() @@ -270,6 +270,7 @@ def load_scripts(): finally: sys.path = syspath current_basedir = paths.script_path + timer.startup_timer.record(scriptfile.filename) global scripts_txt2img, scripts_img2img, scripts_postproc diff --git a/modules/timer.py b/modules/timer.py index ba92be33..da99e49f 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -1,11 +1,30 @@ import time +class TimerSubcategory: + def __init__(self, timer, category): + self.timer = timer + self.category = category + self.start = None + self.original_base_category = timer.base_category + + def __enter__(self): + self.start = time.time() + self.timer.base_category = self.original_base_category + self.category + "/" + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_for_subcategroy = time.time() - self.start + self.timer.base_category = self.original_base_category + self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy) + self.timer.record(self.category) + + class Timer: def __init__(self): self.start = time.time() self.records = {} self.total = 0 + self.base_category = '' def elapsed(self): end = time.time() @@ -13,18 +32,29 @@ class Timer: self.start = end return res - def record(self, category, extra_time=0): - e = self.elapsed() + def add_time_to_record(self, category, amount): if category not in self.records: self.records[category] = 0 - self.records[category] += e + extra_time + self.records[category] += amount + + def record(self, category, extra_time=0): + e = self.elapsed() + + self.add_time_to_record(self.base_category + category, e + extra_time) + self.total += e + extra_time + def subcategory(self, name): + self.elapsed() + + subcat = TimerSubcategory(self, name) + return subcat + def summary(self): res = f"{self.total:.1f}s" - additions = [x for x in self.records.items() if x[1] >= 0.1] + additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category] if not additions: return res @@ -34,5 +64,13 @@ class Timer: return res + def dump(self): + return {'total': self.total, 'records': self.records} + def reset(self): self.__init__() + + +startup_timer = Timer() + +startup_record = None diff --git a/modules/ui.py b/modules/ui.py index 82820ab5..5174da63 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, timer from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -1901,3 +1901,5 @@ def setup_ui_api(app): app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) + + app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"]) diff --git a/style.css b/style.css index ba12723a..f2491726 100644 --- a/style.css +++ b/style.css @@ -403,19 +403,23 @@ div#extras_scale_to_tab div.form{ margin: 0 1.2em; } -table.settings-value-table{ +table.popup-table{ background: white; border-collapse: collapse; margin: 1em; border: 4px solid white; } -table.settings-value-table td{ +table.popup-table td{ padding: 0.4em; border: 1px solid #ccc; max-width: 36em; } +table.popup-table .muted{ + color: #aaa; +} + .ui-defaults-none{ color: #aaa !important; } diff --git a/webui.py b/webui.py index a76e377c..940966eb 100644 --- a/webui.py +++ b/webui.py @@ -20,7 +20,7 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not from modules import paths, timer, import_hook, errors # noqa: F401 -startup_timer = timer.Timer() +startup_timer = timer.startup_timer import torch import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them @@ -269,8 +269,8 @@ def initialize_rest(*, reload_script_modules=False): localization.list_localizations(cmd_opts.localizations_dir) - modules.scripts.load_scripts() - startup_timer.record("load scripts") + with startup_timer.subcategory("load scripts"): + modules.scripts.load_scripts() if reload_script_modules: for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: @@ -416,9 +416,12 @@ def webui(): ui_extra_networks.add_pages_to_demo(app) - modules.script_callbacks.app_started_callback(shared.demo, app) - startup_timer.record("scripts app_started_callback") + startup_timer.record("add APIs") + + with startup_timer.subcategory("app_started_callback"): + modules.script_callbacks.app_started_callback(shared.demo, app) + timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") if cmd_opts.subpath: @@ -443,6 +446,7 @@ def webui(): # If we catch a keyboard interrupt, we want to stop the server and exit. shared.demo.close() break + print('Restarting UI...') shared.demo.close() time.sleep(0.5) -- cgit v1.2.3 From 339b5315700a469f4a9f0d5afc08ca2aca60c579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 27 May 2023 15:47:33 +0300 Subject: custom unet support --- modules/processing.py | 4 +- modules/script_callbacks.py | 20 ++++++++++ modules/sd_hijack.py | 20 +++++++--- modules/sd_models.py | 4 +- modules/sd_unet.py | 92 +++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + modules/shared_items.py | 11 ++++++ webui.py | 4 ++ 8 files changed, 148 insertions(+), 8 deletions(-) create mode 100644 modules/sd_unet.py (limited to 'modules/script_callbacks.py') diff --git a/modules/processing.py b/modules/processing.py index 29a3743f..b75f2515 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() + sd_unet.apply_unet() + if state.job_count == -1: state.job_count = p.n_iter diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 40f388a5..d2728e12 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -111,6 +111,7 @@ callback_map = dict( callbacks_before_ui=[], callbacks_on_reload=[], callbacks_list_optimizers=[], + callbacks_list_unets=[], ) @@ -271,6 +272,18 @@ def list_optimizers_callback(): return res +def list_unets_callback(): + res = [] + + for c in callback_map['callbacks_list_unets']: + try: + c.callback(res) + except Exception: + report_exception(c, 'list_unets') + + return res + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -430,3 +443,10 @@ def on_list_optimizers(callback): to it.""" add_callback(callback_map['callbacks_list_optimizers'], callback) + + +def on_list_unets(callback): + """register a function to be called when UI is making a list of alternative options for unet. + The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" + + add_callback(callback_map['callbacks_list_unets'], callback) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f93df0a6..487dfd60 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,7 +3,7 @@ from torch.nn.functional import silu from types import MethodType import modules.textual_inversion.textual_inversion -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -43,7 +43,7 @@ def list_optimizers(): optimizers.extend(new_optimizers) -def apply_optimizations(): +def apply_optimizations(option=None): global current_optimizer undo_optimizations() @@ -60,7 +60,7 @@ def apply_optimizations(): current_optimizer.undo() current_optimizer = None - selection = shared.opts.cross_attention_optimization + selection = option or shared.opts.cross_attention_optimization if selection == "Automatic" and len(optimizers) > 0: matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) else: @@ -72,12 +72,13 @@ def apply_optimizations(): matching_optimizer = optimizers[0] if matching_optimizer is not None: - print(f"Applying optimization: {matching_optimizer.name}... ", end='') + print(f"Applying attention optimization: {matching_optimizer.name}... ", end='') matching_optimizer.apply() print("done.") current_optimizer = matching_optimizer return current_optimizer.name else: + print("Disabling attention optimization") return '' @@ -155,9 +156,9 @@ class StableDiffusionModelHijack: def __init__(self): self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) - def apply_optimizations(self): + def apply_optimizations(self, option=None): try: - self.optimization_method = apply_optimizations() + self.optimization_method = apply_optimizations(option) except Exception as e: errors.display(e, "applying cross attention optimization") undo_optimizations() @@ -194,6 +195,11 @@ class StableDiffusionModelHijack: self.layers = flatten(m) + if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'): + ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward + + ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward + def undo_hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped @@ -215,6 +221,8 @@ class StableDiffusionModelHijack: self.layers = None self.clip = None + ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui + def apply_circular(self, enable): if self.circular_enabled == enable: return diff --git a/modules/sd_models.py b/modules/sd_models.py index 91b3eb11..835bc016 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -532,6 +532,8 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return + sd_unet.apply_unet("None") + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: diff --git a/modules/sd_unet.py b/modules/sd_unet.py new file mode 100644 index 00000000..6d708ad2 --- /dev/null +++ b/modules/sd_unet.py @@ -0,0 +1,92 @@ +import torch.nn +import ldm.modules.diffusionmodules.openaimodel + +from modules import script_callbacks, shared, devices + +unet_options = [] +current_unet_option = None +current_unet = None + + +def list_unets(): + new_unets = script_callbacks.list_unets_callback() + + unet_options.clear() + unet_options.extend(new_unets) + + +def get_unet_option(option=None): + option = option or shared.opts.sd_unet + + if option == "None": + return None + + if option == "Automatic": + name = shared.sd_model.sd_checkpoint_info.model_name + + options = [x for x in unet_options if x.model_name == name] + + option = options[0].label if options else "None" + + return next(iter([x for x in unet_options if x.label == option]), None) + + +def apply_unet(option=None): + global current_unet_option + global current_unet + + new_option = get_unet_option(option) + if new_option == current_unet_option: + return + + if current_unet is not None: + print(f"Dectivating unet: {current_unet.option.label}") + current_unet.deactivate() + + current_unet_option = new_option + if current_unet_option is None: + current_unet = None + + if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + shared.sd_model.model.diffusion_model.to(devices.device) + + return + + shared.sd_model.model.diffusion_model.to(devices.cpu) + devices.torch_gc() + + current_unet = current_unet_option.create_unet() + current_unet.option = current_unet_option + print(f"Activating unet: {current_unet.option.label}") + current_unet.activate() + + +class SdUnetOption: + model_name = None + """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" + + label = None + """name of the unet in UI""" + + def create_unet(self): + """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" + raise NotImplementedError() + + +class SdUnet(torch.nn.Module): + def forward(self, x, timesteps, context, *args, **kwargs): + raise NotImplementedError() + + def activate(self): + pass + + def deactivate(self): + pass + + +def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) + diff --git a/modules/shared.py b/modules/shared.py index 0897f937..a5e7824a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -403,6 +403,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), diff --git a/modules/shared_items.py b/modules/shared_items.py index 2a8713c8..7f306a06 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -29,3 +29,14 @@ def cross_attention_optimizations(): return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] +def sd_unet_items(): + import modules.sd_unet + + return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"] + + +def refresh_unet_list(): + import modules.sd_unet + + modules.sd_unet.list_unets() + diff --git a/webui.py b/webui.py index f9210f41..1e3ff061 100644 --- a/webui.py +++ b/webui.py @@ -58,6 +58,7 @@ import modules.sd_hijack import modules.sd_hijack_optimizations import modules.sd_models import modules.sd_vae +import modules.sd_unet import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion @@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False): modules.sd_hijack.list_optimizers() startup_timer.record("scripts list_optimizers") + modules.sd_unet.list_unets() + startup_timer.record("scripts list_unets") + def load_model(): """ Accesses shared.sd_model property to load model. -- cgit v1.2.3 From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 08:54:13 +0300 Subject: Add & use modules.errors.print_error where currently printing exception info by hand --- extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++-- modules/api/api.py | 7 +++-- modules/call_queue.py | 22 ++++++-------- modules/codeformer_model.py | 10 +++---- modules/config_states.py | 12 +++----- modules/errors.py | 16 +++++++++++ modules/extensions.py | 10 +++---- modules/gfpgan_model.py | 6 ++-- modules/hypernetworks/hypernetwork.py | 14 ++++----- modules/images.py | 9 ++---- modules/interrogate.py | 5 ++-- modules/launch_utils.py | 7 +++-- modules/localization.py | 6 ++-- modules/processing.py | 2 +- modules/realesrgan_model.py | 14 ++++----- modules/safe.py | 26 +++++++++-------- modules/script_callbacks.py | 9 +++--- modules/script_loading.py | 7 ++--- modules/scripts.py | 35 ++++++++--------------- modules/sd_hijack_optimizations.py | 6 ++-- modules/textual_inversion/textual_inversion.py | 9 ++---- modules/ui.py | 10 +++---- modules/ui_extensions.py | 9 ++---- scripts/prompts_from_file.py | 6 ++-- 25 files changed, 117 insertions(+), 153 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index c4da79f3..95f1669d 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,9 +1,8 @@ import os -import sys -import traceback from basicsr.utils.download_util import load_file_from_url +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks @@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 45d9297b..dd1b822e 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,6 +1,5 @@ import os.path import sys -import traceback import PIL.Image import numpy as np @@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net + +from modules.errors import print_error from modules.shared import opts @@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index 6a456861..79ce9228 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,6 +16,7 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api import models +from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -108,7 +109,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +139,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + print_error(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb764..dba2a9b4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,10 +1,9 @@ import html -import sys import threading -import traceback import time from modules import shared, progress +from modules.errors import print_error queue_lock = threading.Lock() @@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): try: res = list(func(*args, **kwargs)) except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {args} {kwargs}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) + # When printing out our debug argument list, + # do not print out more than a 100 KB of text + max_debug_str_len = 131072 + message = "Error completing request" + arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] + if len(arg_str) > max_debug_str_len: + arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" + print_error(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 @@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f - diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ececdbae..76143e9f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import cv2 import torch @@ -8,6 +6,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader +from modules.errors import print_error from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,8 +104,8 @@ def setup_model(dirname): restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) + except Exception: + print_error('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -135,7 +134,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print("Error setting up CodeFormer:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index db65bcdb..faeaf28b 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c """ import os -import sys -import traceback import json import time import tqdm @@ -14,6 +12,7 @@ from collections import OrderedDict import git from modules import shared, extensions +from modules.errors import print_error from modules.paths_internal import script_path, config_states_dir @@ -53,8 +52,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -134,8 +132,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -143,8 +140,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index da4694f8..41d8dc93 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -1,7 +1,23 @@ import sys +import textwrap import traceback +def print_error( + message: str, + *, + exc_info: bool = False, +) -> None: + """ + Print an error message to stderr, with optional traceback. + """ + for line in message.splitlines(): + print("***", line, file=sys.stderr) + if exc_info: + print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) + print("---") + + def print_error_explanation(message): lines = message.strip().split("\n") max_len = max([len(x) for x in lines]) diff --git a/modules/extensions.py b/modules/extensions.py index 624832a0..369d2584 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,10 @@ import os -import sys import threading -import traceback import git from modules import shared +from modules.errors import print_error from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -56,8 +55,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception: - print(f"Error reading github repository info from {self.path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -72,8 +70,8 @@ class Extension: self.commit_hash = commit.hexsha self.version = self.commit_hash[:8] - except Exception as ex: - print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) + except Exception: + print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 0131dea4..d2f647fe 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import facexlib import gfpgan import modules.face_restoration from modules import paths, shared, devices, modelloader +from modules.errors import print_error model_dir = "GFPGAN" user_path = None @@ -112,5 +111,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print("Error setting up GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 570b5603..fcc1ef20 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -2,8 +2,6 @@ import datetime import glob import html import os -import sys -import traceback import inspect import modules.textual_inversion.dataset @@ -12,6 +10,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint +from modules.errors import print_error from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -325,17 +324,14 @@ def load_hypernetwork(name): if path is None: return None - hypernetwork = Hypernetwork() - try: + hypernetwork = Hypernetwork() hypernetwork.load(path) + return hypernetwork except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading hypernetwork {path}", exc_info=True) return None - return hypernetwork - def load_hypernetworks(names, multipliers=None): already_loaded = {} @@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print(traceback.format_exc(), file=sys.stderr) + print_error("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index e21e554c..69151bec 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,6 +1,4 @@ import datetime -import sys -import traceback import pytz import io @@ -18,6 +16,7 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors +from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -464,8 +463,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print(f"Error adding [{pattern}] to filename", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -697,8 +695,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print("Error parsing NovelAI image generation parameters:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index 111b1322..d36e1a5a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,6 +1,5 @@ import os import sys -import traceback from collections import namedtuple from pathlib import Path import re @@ -12,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,8 +216,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print("Error interrogating", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310..22edc106 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -8,6 +8,7 @@ import json from functools import lru_cache from modules import cmd_args +from modules.errors import print_error from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -188,7 +189,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print(e, file=sys.stderr) + print_error(str(e)) def list_extensions(settings_file): @@ -198,8 +199,8 @@ def list_extensions(settings_file): if os.path.isfile(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) - except Exception as e: - print(e, file=sys.stderr) + except Exception: + print_error("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index ee9c65e7..9a1df343 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,8 +1,7 @@ import json import os -import sys -import traceback +from modules.errors import print_error localizations = {} @@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print(f"Error loading localization from {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/processing.py b/modules/processing.py index b75f2515..5c9bcce8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,5 @@ import json +import logging import math import os import sys @@ -23,7 +24,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -import logging from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 99983678..c8d0c64f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import numpy as np from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader @@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) return info - except Exception as e: - print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + except Exception: + print_error("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print("Error making Real-ESRGAN models list:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index e8f50774..b596f565 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -2,8 +2,6 @@ import pickle import collections -import sys -import traceback import torch import numpy @@ -11,6 +9,8 @@ import _codecs import zipfile import re +from modules.errors import print_error + # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage @@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + "-----> !!!! The file is most likely corrupted !!!! <-----\n" + "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + exc_info=True, + ) return None - except Exception: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + f"The file may be malicious, so the program is not going to read it.\n" + f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", + exc_info=True, + ) return None return unsafe_torch_load(filename, *args, **kwargs) @@ -190,4 +193,3 @@ with safe.Extra(handler): unsafe_torch_load = torch.load torch.load = load global_extra_handler = None - diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d2728e12..6aa9c3b6 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,16 +1,15 @@ -import sys -import traceback -from collections import namedtuple import inspect +from collections import namedtuple from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules.errors import print_error + def report_exception(c, job): - print(f"Error executing callback {job} for {c.script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 57b15862..26efffcb 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,8 +1,8 @@ import os -import sys -import traceback import importlib.util +from modules.errors import print_error + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) @@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print(f"Error running preload() for {preload_script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..a7168fd1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,12 +1,12 @@ import os import re import sys -import traceback from collections import namedtuple import gradio as gr from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules.errors import print_error AlwaysVisible = object() @@ -264,8 +264,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -280,11 +279,9 @@ def load_scripts(): def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) except Exception: - print(f"Error calling: {filename}/{funcname}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -450,8 +447,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running process: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -459,8 +455,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -468,8 +463,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -477,8 +471,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print(f"Error running postprocess: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -486,8 +479,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -495,24 +487,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print(f"Error running before_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print(f"Error running after_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 2ec0b049..fd186fa2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,5 @@ from __future__ import annotations import math -import sys -import traceback import psutil import torch @@ -11,6 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention +from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print("Cannot import xformers", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e..a040a988 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,4 @@ import os -import sys -import traceback from collections import namedtuple import torch @@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset +from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -207,8 +206,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print(traceback.format_exc(), file=sys.stderr) - pass + print_error("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index 001b9792..1ad94f02 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,7 +2,6 @@ import json import mimetypes import os import sys -import traceback from functools import reduce import warnings @@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules.errors import print_error from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: res = all_seeds[index if 0 <= index < len(all_seeds) else 0] except json.decoder.JSONDecodeError: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) + if gen_info_string: + print_error(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1753,8 +1752,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..cadf56be 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,10 +1,8 @@ import json import os.path -import sys import threading import time from datetime import datetime -import traceback import git @@ -14,6 +12,7 @@ import shutil import errno from modules import extensions, shared, paths, config_states +from modules.errors import print_error from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print(f"Error getting updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -113,8 +111,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print(f"Error checking updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b918a764..4dc24615 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,13 +1,12 @@ import copy import random -import sys -import traceback import shlex import modules.scripts as scripts import gradio as gr from modules import sd_samplers +from modules.errors import print_error from modules.processing import Processed, process_images from modules.shared import state @@ -136,8 +135,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line {line} as commandline:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 05933840f0676dd1a90a7e2ad3f2a0672624b2cd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 May 2023 19:56:37 +0300 Subject: rename print_error to report, use it with together with package name --- extensions-builtin/LDSR/scripts/ldsr_model.py | 5 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 5 ++--- modules/api/api.py | 5 ++--- modules/call_queue.py | 5 ++--- modules/codeformer_model.py | 7 +++---- modules/config_states.py | 9 ++++----- modules/errors.py | 8 ++------ modules/extensions.py | 7 +++---- modules/gfpgan_model.py | 5 ++--- modules/hypernetworks/hypernetwork.py | 7 +++---- modules/images.py | 5 ++--- modules/interrogate.py | 3 +-- modules/launch_utils.py | 7 +++---- modules/localization.py | 4 ++-- modules/realesrgan_model.py | 10 +++++----- modules/safe.py | 7 ++++--- modules/script_callbacks.py | 4 ++-- modules/script_loading.py | 4 ++-- modules/scripts.py | 23 +++++++++++------------ modules/sd_hijack_optimizations.py | 3 +-- modules/textual_inversion/textual_inversion.py | 7 +++---- modules/ui.py | 7 +++---- modules/ui_extensions.py | 7 +++---- scripts/prompts_from_file.py | 5 ++--- 24 files changed, 69 insertions(+), 90 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 95f1669d..dbd6d331 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -2,10 +2,9 @@ import os from basicsr.utils.download_util import load_file_from_url -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR -from modules import shared, script_callbacks +from modules import shared, script_callbacks, errors import sd_hijack_autoencoder # noqa: F401 import sd_hijack_ddpm_v1 # noqa: F401 @@ -51,7 +50,7 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) except Exception: - print_error("Error importing LDSR", exc_info=True) + errors.report("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index dd1b822e..85b4505f 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -9,10 +9,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader, script_callbacks +from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net -from modules.errors import print_error from modules.shared import opts @@ -39,7 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print_error(f"Error loading ScuNET model: {file}", exc_info=True) + errors.report(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index fbd616a3..d34ab422 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,9 +14,8 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules.api import models -from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -145,7 +144,7 @@ def api_middleware(app: FastAPI): print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - print_error(message, exc_info=True) + errors.report(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index dba2a9b4..53af6d70 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -2,8 +2,7 @@ import html import threading import time -from modules import shared, progress -from modules.errors import print_error +from modules import shared, progress, errors queue_lock = threading.Lock() @@ -62,7 +61,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] if len(arg_str) > max_debug_str_len: arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" - print_error(f"{message}\n{arg_str}", exc_info=True) + errors.report(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 76143e9f..4260b016 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,8 +5,7 @@ import torch import modules.face_restoration import modules.shared -from modules import shared, devices, modelloader -from modules.errors import print_error +from modules import shared, devices, modelloader, errors from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,7 +104,7 @@ def setup_model(dirname): del output torch.cuda.empty_cache() except Exception: - print_error('Failed inference for CodeFormer', exc_info=True) + errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -134,6 +133,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print_error("Error setting up CodeFormer", exc_info=True) + errors.report("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index faeaf28b..6f1ab53f 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -11,8 +11,7 @@ from datetime import datetime from collections import OrderedDict import git -from modules import shared, extensions -from modules.errors import print_error +from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir @@ -52,7 +51,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -132,7 +131,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -140,7 +139,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print_error(f"Error restoring webui to commit{webui_commit_hash}") + errors.report(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index 41d8dc93..e408f500 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -3,11 +3,7 @@ import textwrap import traceback -def print_error( - message: str, - *, - exc_info: bool = False, -) -> None: +def report(message: str, *, exc_info: bool = False) -> None: """ Print an error message to stderr, with optional traceback. """ @@ -15,7 +11,7 @@ def print_error( print("***", line, file=sys.stderr) if exc_info: print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) - print("---") + print("---", file=sys.stderr) def print_error_explanation(message): diff --git a/modules/extensions.py b/modules/extensions.py index 92f93ad9..8608584b 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,8 +1,7 @@ import os import threading -from modules import shared -from modules.errors import print_error +from modules import shared, errors from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -54,7 +53,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = Repo(self.path) except Exception: - print_error(f"Error reading github repository info from {self.path}", exc_info=True) + errors.report(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -70,7 +69,7 @@ class Extension: self.version = self.commit_hash[:8] except Exception: - print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) + errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index d2f647fe..e239a09d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -4,8 +4,7 @@ import facexlib import gfpgan import modules.face_restoration -from modules import paths, shared, devices, modelloader -from modules.errors import print_error +from modules import paths, shared, devices, modelloader, errors model_dir = "GFPGAN" user_path = None @@ -111,4 +110,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print_error("Error setting up GFPGAN", exc_info=True) + errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fcc1ef20..5d12b449 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -9,8 +9,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint -from modules.errors import print_error +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -329,7 +328,7 @@ def load_hypernetwork(name): hypernetwork.load(path) return hypernetwork except Exception: - print_error(f"Error loading hypernetwork {path}", exc_info=True) + errors.report(f"Error loading hypernetwork {path}", exc_info=True) return None @@ -766,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print_error("Exception in training hypernetwork", exc_info=True) + errors.report("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index 09f728df..30e9ffc5 100644 --- a/modules/images.py +++ b/modules/images.py @@ -16,7 +16,6 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors -from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -463,7 +462,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print_error(f"Error adding [{pattern}] to filename", exc_info=True) + errors.report(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -698,7 +697,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print_error("Error parsing NovelAI image generation parameters", exc_info=True) + errors.report("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index d36e1a5a..9b2c5b60 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,7 +11,6 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,7 +215,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print_error("Error interrogating", exc_info=True) + errors.report("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0bf4cb7e..6e9bb770 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -7,8 +7,7 @@ import platform import json from functools import lru_cache -from modules import cmd_args -from modules.errors import print_error +from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -189,7 +188,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print_error(str(e)) + errors.report(str(e)) def list_extensions(settings_file): @@ -200,7 +199,7 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - print_error("Could not load settings", exc_info=True) + errors.report("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index 9a1df343..e8f585da 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules.errors import print_error +from modules import errors localizations = {} @@ -30,6 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print_error(f"Error loading localization from {fn}", exc_info=True) + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c8d0c64f..2d27b321 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -5,10 +5,10 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts -from modules import modelloader +from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): @@ -35,7 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print_error("Error importing Real-ESRGAN", exc_info=True) + errors.report("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -75,7 +75,7 @@ class UpscalerRealESRGAN(Upscaler): return info except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -132,4 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index b596f565..b1d08a79 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -9,9 +9,10 @@ import _codecs import zipfile import re -from modules.errors import print_error # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +from modules import errors + TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage def encode(*args): @@ -136,7 +137,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" "-----> !!!! The file is most likely corrupted !!!! <-----\n" "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", @@ -144,7 +145,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ) return None except Exception: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" f"The file may be malicious, so the program is not going to read it.\n" f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6aa9c3b6..ec1469d0 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -5,11 +5,11 @@ from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks -from modules.errors import print_error +from modules import errors def report_exception(c, job): - print_error(f"Error executing callback {job} for {c.script}", exc_info=True) + errors.report(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 26efffcb..306a1f35 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,7 +1,7 @@ import os import importlib.util -from modules.errors import print_error +from modules import errors def load_module(path): @@ -27,4 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print_error(f"Error running preload() for {preload_script}", exc_info=True) + errors.report(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index a7168fd1..0970f38e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -5,8 +5,7 @@ from collections import namedtuple import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing -from modules.errors import print_error +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors AlwaysVisible = object() @@ -264,7 +263,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) + errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -281,7 +280,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: return func(*args, **kwargs) except Exception: - print_error(f"Error calling: {filename}/{funcname}", exc_info=True) + errors.report(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -447,7 +446,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print_error(f"Error running process: {script.filename}", exc_info=True) + errors.report(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -455,7 +454,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -463,7 +462,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -471,7 +470,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print_error(f"Error running postprocess: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -479,7 +478,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -487,21 +486,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print_error(f"Error running before_component: {script.filename}", exc_info=True) + errors.report(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print_error(f"Error running after_component: {script.filename}", exc_info=True) + errors.report(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index fd186fa2..5f0ff513 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,6 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention -from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -139,7 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print_error("Cannot import xformers", exc_info=True) + errors.report("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b3dcb140..8da050ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,9 +12,8 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors import modules.textual_inversion.dataset -from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -219,7 +218,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print_error(f"Error loading embedding {fn}", exc_info=True) + errors.report(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -643,7 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print_error("Error training embedding", exc_info=True) + errors.report("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index fb6b2498..f361264c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,8 +12,7 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave -from modules.errors import print_error +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -232,7 +231,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: except json.decoder.JSONDecodeError: if gen_info_string: - print_error(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1752,7 +1751,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print_error("Error loading/saving model file", exc_info=True) + errors.report("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index e2ee9d72..3140ed64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -11,8 +11,7 @@ import html import shutil import errno -from modules import extensions, shared, paths, config_states -from modules.errors import print_error +from modules import extensions, shared, paths, config_states, errors from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -45,7 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print_error(f"Error getting updates for {ext.name}", exc_info=True) + errors.report(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -111,7 +110,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print_error(f"Error checking updates for {ext.name}", exc_info=True) + errors.report(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 4dc24615..83a2f220 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -5,8 +5,7 @@ import shlex import modules.scripts as scripts import gradio as gr -from modules import sd_samplers -from modules.errors import print_error +from modules import sd_samplers, errors from modules.processing import Processed, process_images from modules.shared import state @@ -135,7 +134,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print_error(f"Error parsing line {line} as commandline", exc_info=True) + errors.report(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From a5e851028e23e411c392a66e7c791e388b3e4aba Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 1 Jun 2023 10:01:42 +0300 Subject: add hiding and a colspans to startup profile table --- javascript/profilerVisualization.js | 84 ++++++++++++++++++++++++++++++++----- modules/script_callbacks.py | 3 +- style.css | 6 +++ 3 files changed, 81 insertions(+), 12 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/javascript/profilerVisualization.js b/javascript/profilerVisualization.js index 1bd75986..9d8e5f42 100644 --- a/javascript/profilerVisualization.js +++ b/javascript/profilerVisualization.js @@ -3,11 +3,29 @@ function createRow(table, cellName, items) { var tr = document.createElement('tr'); var res = []; - items.forEach(function(x) { + items.forEach(function(x, i) { + if (x === undefined) { + res.push(null); + return; + } + var td = document.createElement(cellName); td.textContent = x; tr.appendChild(td); res.push(td); + + var colspan = 1; + for (var n = i + 1; n < items.length; n++) { + if (items[n] !== undefined) { + break; + } + + colspan += 1; + } + + if (colspan > 1) { + td.colSpan = colspan; + } }); table.appendChild(tr); @@ -15,7 +33,7 @@ function createRow(table, cellName, items) { return res; } -function showProfile(path, cutoff = 0.0005) { +function showProfile(path, cutoff = 0.05) { requestGet(path, {}, function(data) { var table = document.createElement('table'); table.className = 'popup-table'; @@ -38,7 +56,7 @@ function showProfile(path, cutoff = 0.0005) { return !(a < b || b < a); } - var addLevel = function(level, parent) { + var addLevel = function(level, parent, hide) { var matching = items.filter(function(x) { return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent); }); @@ -47,12 +65,10 @@ function showProfile(path, cutoff = 0.0005) { }); var othersTime = 0; var othersList = []; + var othersRows = []; + var childrenRows = []; sorted.forEach(function(x) { - if (x.time < cutoff) { - othersTime += x.time; - othersList.push(x.parts[level]); - return; - } + var visible = x.time >= cutoff && !hide; var cells = []; for (var i = 0; i < maxLength; i++) { @@ -64,7 +80,32 @@ function showProfile(path, cutoff = 0.0005) { cols[i].className = 'muted'; } - addLevel(level + 1, parent.concat([x.parts[level]])); + var tr = cols[0].parentNode; + if (!visible) { + tr.classList.add("hidden"); + } + + if (x.time >= cutoff) { + childrenRows.push(tr); + } else { + othersTime += x.time; + othersList.push(x.parts[level]); + othersRows.push(tr); + } + + var children = addLevel(level + 1, parent.concat([x.parts[level]]), true); + if (children.length > 0) { + var cell = cols[level]; + var onclick = function() { + cell.classList.remove("link"); + cell.removeEventListener("click", onclick); + children.forEach(function(x) { + x.classList.remove("hidden"); + }); + }; + cell.classList.add("link"); + cell.addEventListener("click", onclick); + } }); if (othersTime > 0) { @@ -73,14 +114,35 @@ function showProfile(path, cutoff = 0.0005) { cells.push(parent[i]); } cells.push(othersTime.toFixed(3)); + cells[level] = 'others'; var cols = createRow(table, 'td', cells); for (i = 0; i < level; i++) { cols[i].className = 'muted'; } - cols[level].textContent = 'others'; - cols[level].title = othersList.join(", "); + var cell = cols[level]; + var tr = cell.parentNode; + var onclick = function() { + tr.classList.add("hidden"); + cell.classList.remove("link"); + cell.removeEventListener("click", onclick); + othersRows.forEach(function(x) { + x.classList.remove("hidden"); + }); + }; + + cell.title = othersList.join(", "); + cell.classList.add("link"); + cell.addEventListener("click", onclick); + + if (hide) { + tr.classList.add("hidden"); + } + + childrenRows.push(tr); } + + return childrenRows; }; addLevel(0, []); diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 54824582..f755283c 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,4 +1,5 @@ import inspect +import os from collections import namedtuple from typing import Optional, Dict, Any @@ -123,7 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: c.callback(demo, app) - timer.startup_timer.record(c.script) + timer.startup_timer.record(os.path.basename(c.script)) except Exception: report_exception(c, 'app_started_callback') diff --git a/style.css b/style.css index cd7491ba..34b85b80 100644 --- a/style.css +++ b/style.css @@ -420,6 +420,12 @@ table.popup-table .muted{ color: #aaa; } +table.popup-table .link{ + text-decoration: underline; + cursor: pointer; + font-weight: bold; +} + .ui-defaults-none{ color: #aaa !important; } -- cgit v1.2.3 From 51864790fd72386fbbbb015d24a43ce501ecaa4b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 2 Jun 2023 14:58:10 +0300 Subject: Simplify a bunch of `len(x) > 0`/`len(x) == 0` style expressions --- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 3 ++- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 4 ++-- extensions-builtin/Lora/extra_networks_lora.py | 4 ++-- extensions-builtin/Lora/lora.py | 4 ++-- .../extra-options-section/scripts/extra_options_section.py | 2 +- modules/api/api.py | 2 +- modules/call_queue.py | 2 +- modules/extra_networks_hypernet.py | 4 ++-- modules/generation_parameters_copypaste.py | 6 ++---- modules/images.py | 6 +++--- modules/img2img.py | 3 +-- modules/models/diffusion/ddpm_edit.py | 4 ++-- modules/processing.py | 3 ++- modules/prompt_parser.py | 6 +++--- modules/script_callbacks.py | 4 ++-- modules/sd_hijack_clip.py | 2 +- modules/sd_hijack_clip_old.py | 2 +- modules/textual_inversion/autocrop.py | 14 +++++++------- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- modules/ui.py | 2 +- modules/ui_extensions.py | 5 +++-- modules/ui_settings.py | 2 +- scripts/prompts_from_file.py | 3 +-- 25 files changed, 47 insertions(+), 48 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 27a86e13..c29d274d 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 631a08ef..04adc5eb 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index b5fea4d2..66ee9c85 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index eec14712..af93991c 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): else: raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") - if len(keys_failed_to_match) > 0: + if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): lora.multiplier = multipliers[i] if multipliers else 1.0 loaded_loras.append(lora) - if len(failed_to_load_loras) > 0: + if failed_to_load_loras: sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 17f84184..a05e10d8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -21,7 +21,7 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): for setting_name in shared.opts.extra_options: with FormColumn(): comp = ui_settings.create_setting_component(setting_name) diff --git a/modules/api/api.py b/modules/api/api.py index d34ab422..555eefdb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -280,7 +280,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: diff --git a/modules/call_queue.py b/modules/call_queue.py index 53af6d70..1b5e5273 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -21,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id - if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): id_task = args[0] progress.add_task_to_queue(id_task) else: diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index aa2a14ef..b6a6dc0e 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_hypernetwork - if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): hypernet_prompt_text = f"" p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 071bd9ea..237401a1 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -55,7 +55,7 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] if type(filedata) == dict and filedata.get("is_file", False): @@ -437,7 +437,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, vals_pairs = [f"{k}: {v}" for k, v in vals.items()] - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) paste_fields = paste_fields + [(override_settings_component, paste_settings)] @@ -454,5 +454,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) - - diff --git a/modules/images.py b/modules/images.py index a12d252b..7bbfc3e0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -406,7 +406,7 @@ class FilenameGenerator: prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): - if len(style) > 0: + if style: for part in style.split("{prompt}"): prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') @@ -415,7 +415,7 @@ class FilenameGenerator: return sanitize_filename_part(prompt_no_style, replace_spaces=False) def prompt_words(self): - words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0] + words = [x for x in re_nonletters.split(self.prompt or "") if x] if len(words) == 0: words = ["empty"] return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False) @@ -423,7 +423,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format + time_format = args[0] if (args and args[0] != "") else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError: diff --git a/modules/img2img.py b/modules/img2img.py index 4c12c2c5..35c4facc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: + is_inpaint_batch = bool(inpaint_masks) print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 3fb76b65..b892d5fc 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..9ebdb549 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -975,7 +975,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: - assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b4aff704..0069d8b0 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -336,11 +336,11 @@ def parse_prompt_attention(text): round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: + elif weight is not None and round_brackets: multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and len(round_brackets) > 0: + elif text == ')' and round_brackets: multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and len(square_brackets) > 0: + elif text == ']' and square_brackets: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: parts = re.split(re_break, text) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index f755283c..77ee55ee 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -287,14 +287,14 @@ def list_unets_callback(): def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' callbacks.append(ScriptCallback(filename, fun)) def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' if filename == 'unknown file': return for callback_list in callback_map.values(): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index cc6e8c21..3b5a7666 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0 or len(chunks) == 0: + if chunk.tokens or not chunks: next_chunk(is_last=True) return chunks, token_count diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index a3476e95..c5c6270b 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments - if len(used_custom_terms) > 0: + if used_custom_terms: embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) self.hijack.comments.append(f"Used embeddings: {embedding_names}") diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 8e667a4d..75705459 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -77,27 +77,27 @@ def focal_point(im, settings): pois = [] weight_pref_total = 0 - if len(corner_points) > 0: + if corner_points: weight_pref_total += settings.corner_points_weight - if len(entropy_points) > 0: + if entropy_points: weight_pref_total += settings.entropy_points_weight - if len(face_points) > 0: + if face_points: weight_pref_total += settings.face_points_weight corner_centroid = None - if len(corner_points) > 0: + if corner_points: corner_centroid = centroid(corner_points) corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None - if len(entropy_points) > 0: + if entropy_points: entropy_centroid = centroid(entropy_points) entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total pois.append(entropy_centroid) face_centroid = None - if len(face_points) > 0: + if face_points: face_centroid = centroid(face_points) face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) @@ -187,7 +187,7 @@ def image_face_points(im, settings): except Exception: continue - if len(faces) > 0: + if faces: rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b9621fc9..7ee05061 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -32,7 +32,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None self.placeholder_token = placeholder_token diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a009d8e8..0d4c3f84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption += shared.interrogator.generate_caption(image) if params.process_caption_deepbooru: - if len(caption) > 0: + if caption: caption += ", " caption += deepbooru.model.tag_multi(image) @@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption = caption.strip() - if len(caption) > 0: + if caption: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8da050ca..bb6f211c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,7 +251,7 @@ class EmbeddingDatabase: if self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: + if self.skipped_embeddings: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): diff --git a/modules/ui.py b/modules/ui.py index b7459f08..9a025cca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -398,7 +398,7 @@ def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) dropdown.change( - fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + fn=lambda x: gr.Dropdown.update(visible=bool(x)), inputs=[dropdown], outputs=[dropdown], ) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3140ed64..65173e06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -333,7 +333,8 @@ def install_extension_from_url(dirname, url, branch_name=None): assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' normalized_url = normalize_git_url(url) - assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') tmpdir = os.path.join(paths.data_path, "tmp", dirname) @@ -449,7 +450,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = installed_extension_urls.get(normalize_git_url(url), None) extension_tags = extension_tags + ["installed"] if existing else extension_tags - if len([x for x in extension_tags if x in tags_to_hide]) > 0: + if any(x for x in extension_tags if x in tags_to_hide): hidden += 1 continue diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 7874298e..2688d8c2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -81,7 +81,7 @@ class UiSettings: opts.save(shared.config_filename) except RuntimeError: return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' def run_settings_single(self, value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 83a2f220..50320d55 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -121,8 +121,7 @@ class Script(scripts.Script): return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] + lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True -- cgit v1.2.3