From 96d6ca4199e7c5eee8d451618de5161cea317c40 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 08:25:25 +0300
Subject: manual fixes for ruff
---
modules/interrogate.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9f7d657f..22df9216 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,7 +11,6 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384
--
cgit v1.2.3
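
The line removed above was redundant rather than wrong: `from modules import devices, paths, shared, ...` already binds the name `shared`, so `import modules.shared as shared` re-imported the same module under the same name. A minimal sketch of that equivalence, using the standard library's `os.path` in place of the webui package:

import importlib

import os.path as mod_a        # the "import pkg.sub as name" form (the style of the deleted line)
from os import path as mod_b   # the "from pkg import sub" form that remains in interrogate.py

# Both statements bind the very same module object, so one import suffices.
assert mod_a is mod_b
assert mod_a is importlib.import_module("os.path")
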
From 028d3f6425d85f122027c127fba8bcbf4f66ee75 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 11:05:02 +0300
Subject: ruff auto fixes
---
extensions-builtin/LDSR/sd_hijack_autoencoder.py | 4 ++--
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 12 ++++++------
extensions-builtin/Lora/lora.py | 12 ++++++------
extensions-builtin/Lora/scripts/lora_script.py | 2 +-
modules/config_states.py | 2 +-
modules/deepbooru.py | 2 +-
modules/devices.py | 2 +-
modules/hypernetworks/hypernetwork.py | 2 +-
modules/hypernetworks/ui.py | 4 ++--
modules/interrogate.py | 2 +-
modules/modelloader.py | 2 +-
modules/models/diffusion/ddpm_edit.py | 4 ++--
modules/scripts_auto_postprocessing.py | 2 +-
modules/sd_hijack.py | 2 +-
modules/sd_hijack_optimizations.py | 14 +++++++-------
modules/sd_samplers_compvis.py | 2 +-
modules/sd_samplers_kdiffusion.py | 2 +-
modules/shared.py | 6 +++---
modules/textual_inversion/textual_inversion.py | 2 +-
modules/ui.py | 8 ++++----
modules/ui_extra_networks.py | 4 ++--
modules/ui_tempdir.py | 2 +-
22 files changed, 47 insertions(+), 47 deletions(-)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 6303fed5..f457ca93 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -288,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
index 4d3f6c56..d8fc30e3 100644
--- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -1116,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1215,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1437,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 0ab43229..9795540f 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -172,7 +172,7 @@ def load_lora(name, filename):
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -184,7 +184,7 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -202,7 +202,7 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
@@ -309,7 +309,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -343,8 +343,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 7db971fd..b70e2de7 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras),
}))
diff --git a/modules/config_states.py b/modules/config_states.py
index 8f1ff428..75da862a 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -35,7 +35,7 @@ def list_config_states():
j["filepath"] = path
config_states.append(j)
- config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"]))
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 1c4554a2..547e1b4c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -78,7 +78,7 @@ class DeepDanbooru:
res = []
- filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
diff --git a/modules/devices.py b/modules/devices.py
index c705a3cb..d8a34a0f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -65,7 +65,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 9fe749b7..6ef0bfdf 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -403,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
k = self.to_k(context_k)
v = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index be168736..e3f9eb13 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 22df9216..a1c8e537 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -159,7 +159,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 92ada694..25612bf8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -39,7 +39,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
continue
if full_path not in output:
output.append(full_path)
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index 611c2b69..09432117 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -1130,7 +1130,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1229,7 +1229,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
index 30d6d658..d63078de 100644
--- a/modules/scripts_auto_postprocessing.py
+++ b/modules/scripts_auto_postprocessing.py
@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
- args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+ args_dict = dict(zip(self.postprocessing_controls, args))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 81573b78..e374aeb8 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -37,7 +37,7 @@ def apply_optimizations():
optimization_method = None
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b623d53d..a174bbe1 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -98,7 +98,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
del context, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@@ -229,7 +229,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -334,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -460,7 +460,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -482,7 +482,7 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -506,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bfcc5574..7427648f 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 3b8e9622..2f733cf5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -86,7 +86,7 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
diff --git a/modules/shared.py b/modules/shared.py
index 7d70f041..e2691585 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -381,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(hypernetworks.keys())}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -403,7 +403,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
@@ -583,7 +583,7 @@ class Options:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
- self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9ed9ba45..c37bb2ad 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -167,7 +167,7 @@ class EmbeddingDatabase:
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
diff --git a/modules/ui.py b/modules/ui.py
index 782b569d..84d661b2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1222,7 +1222,7 @@ def create_ui():
)
def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
+ return sorted(textual_inversion.textual_inversion_templates)
with gr.Tab(label="Train", id="train"):
gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
@@ -1230,8 +1230,8 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=list(shared.hypernetworks.keys()))
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks.keys())}, "refresh_train_hypernetwork_name")
with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
@@ -1808,7 +1808,7 @@ def create_ui():
if type(x) == gr.Dropdown:
def check_dropdown(val):
if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
+ return all(value in x.choices for value in val)
else:
return val in x.choices
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 800e467a..ab585917 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -26,7 +26,7 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
@@ -326,7 +326,7 @@ def setup_ui(ui, gallery):
is_allowed = False
for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
is_allowed = True
break
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 46fa9cb0..cac73c51 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
- return any([filename in fileset for fileset in gradio.temp_file_sets])
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
--
cgit v1.2.3
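
For reference, a compact sketch of the rewrite patterns the auto-fixer applied throughout this commit. The rule codes in the comments are attributed from memory to ruff's flake8-comprehensions (C4xx) and flake8-bugbear (B0xx) rule sets; they are not named in the commit itself:

values = ["a", "b ", " c"]

# C419: any()/all() consume generators lazily and short-circuit,
# so wrapping the condition in a list builds the whole list for no benefit.
any([v.strip() == "c" for v in values])     # before
any(v.strip() == "c" for v in values)       # after

# C403: set() around a list comprehension becomes a set comprehension.
filtertags = set([v.strip() for v in values])   # before
filtertags = {v.strip() for v in values}        # after

# C417: map(lambda ...) becomes a comprehension or generator expression.
q, k, v = map(lambda t: t.upper(), ("q", "k", "v"))   # before
q, k, v = (t.upper() for t in ("q", "k", "v"))        # after

# C416: a dict comprehension that only repackages pairs becomes dict().
args_dict = {name: value for name, value in zip(["steps", "cfg"], [20, 7])}   # before
args_dict = dict(zip(["steps", "cfg"], [20, 7]))                              # after

# B009/B010: getattr()/setattr() with a constant attribute name
# become plain attribute access.
class Box:
    pass

box = Box()
setattr(box, "value", 1)    # before
box.value = 1               # after

# B011: `assert False` disappears under `python -O`; raise explicitly.
def unsupported(layer):
    raise AssertionError(f"unsupported layer type: {layer}")   # before: assert False, ...
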
From a5121e7a0623db328a9462d340d389ed6737374a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 10 May 2023 11:37:18 +0300
Subject: fixes for B007
---
extensions-builtin/LDSR/ldsr_model_arch.py | 2 +-
extensions-builtin/Lora/lora.py | 2 +-
extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +-
extensions-builtin/SwinIR/swinir_model_arch.py | 2 +-
extensions-builtin/SwinIR/swinir_model_arch_v2.py | 2 +-
modules/codeformer_model.py | 2 +-
modules/esrgan_model.py | 8 ++------
modules/extra_networks.py | 2 +-
modules/generation_parameters_copypaste.py | 2 +-
modules/hypernetworks/hypernetwork.py | 12 ++++++------
modules/images.py | 2 +-
modules/interrogate.py | 4 ++--
modules/prompt_parser.py | 14 +++++++-------
modules/safe.py | 4 ++--
modules/scripts.py | 10 +++++-----
modules/scripts_postprocessing.py | 8 ++++----
modules/sd_hijack_clip.py | 2 +-
modules/shared.py | 6 +++---
modules/textual_inversion/learn_schedule.py | 2 +-
modules/textual_inversion/textual_inversion.py | 10 +++++-----
modules/ui.py | 6 +++---
modules/ui_extra_networks.py | 2 +-
modules/ui_tempdir.py | 2 +-
modules/upscaler.py | 2 +-
pyproject.toml | 1 -
scripts/prompts_from_file.py | 2 +-
scripts/sd_upscale.py | 4 ++--
scripts/xyz_grid.py | 2 +-
28 files changed, 57 insertions(+), 62 deletions(-)
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index a5fb8907..27e38549 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
- for n in range(n_runs):
+ for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 9795540f..7b56136f 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -418,7 +418,7 @@ def infotext_pasted(infotext, params):
added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index aa2fdb3a..1f5ea0d3 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -132,7 +132,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
- for k, v in model.named_parameters():
+ for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 75f7bedc..de195d9b 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
@@ -848,7 +848,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index d4c0b0da..15777af9 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
@@ -1001,7 +1001,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8e56cb89..ececdbae 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -94,7 +94,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 85aa6934..a009eb42 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -16,9 +16,7 @@ def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1978673d..f9db41bc 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
- for extra_network_name, extra_network_args in extra_network_data.items():
+ for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 7fbbe707..b0e945a1 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
lines.append(lastline)
lastline = ''
- for i, line in enumerate(lines):
+ for line in lines:
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6ef0bfdf..38ef074f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -177,34 +177,34 @@ class Hypernetwork:
def weights(self):
res = []
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
res += layer.parameters()
return res
def train(self, mode=True):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.train(mode=mode)
for param in layer.parameters():
param.requires_grad = mode
def to(self, device):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.eval()
for param in layer.parameters():
@@ -619,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
diff --git a/modules/images.py b/modules/images.py
index 7392cb8b..c4e98c75 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -149,7 +149,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
return ImageFont.truetype(Roboto, fontsize)
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for i, line in enumerate(lines):
+ for line in lines:
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index a1c8e537..111b1322 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -207,8 +207,8 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
- for name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
+ for cat in self.categories():
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 3a720721..b4aff704 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -143,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
conds = model.get_learned_conditioning(texts)
cond_schedule = []
- for i, (end_at_step, text) in enumerate(prompt_schedule):
+ for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
@@ -219,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
- for current, (end_at, cond) in enumerate(cond_schedule):
- if current_step <= end_at:
+ for current, entry in enumerate(cond_schedule):
+ if current_step <= entry.end_at_step:
target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -234,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
tensors = []
conds_list = []
- for batch_no, composable_prompts in enumerate(c.batch):
+ for composable_prompts in c.batch:
conds_for_batch = []
- for cond_index, composable_prompt in enumerate(composable_prompts):
+ for composable_prompt in composable_prompts:
target_index = 0
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
- if current_step <= end_at:
+ for current, entry in enumerate(composable_prompt.schedules):
+ if current_step <= entry.end_at_step:
target_index = current
break
diff --git a/modules/safe.py b/modules/safe.py
index 2d5b972f..1e791c5b 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -95,11 +95,11 @@ def check_pt(filename, extra_handler):
except zipfile.BadZipfile:
- # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
- for i in range(5):
+ for _ in range(5):
unpickler.load()
diff --git a/modules/scripts.py b/modules/scripts.py
index d945b89f..0c12ebd5 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -231,7 +231,7 @@ def load_scripts():
syspath = sys.path
def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@@ -295,9 +295,9 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
+ for script_data in auto_processing_scripts + scripts_data:
+ script = script_data.script_class()
+ script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -492,7 +492,7 @@ class ScriptRunner:
module = script_loading.load_module(script.filename)
cache[filename] = module
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index b11568c0..6751406c 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
def initialize_scripts(self, scripts_data):
self.scripts = []
- for script_class, path, basedir, script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
+ for script_data in scripts_data:
+ script: ScriptPostprocessing = script_data.script_class()
+ script.filename = script_data.path
if script.name == "Simple Upscale":
continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, component), value in zip(script.controls.items(), script_args):
+ for (name, component), value in zip(script.controls.items(), script_args): # noqa: B007
process_args[name] = value
script.process(pp, **process_args)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9fa5c5c5..c0c350f6 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes:
- for position, embedding in fixes:
+ for position, embedding in fixes: # noqa: B007
used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers)
diff --git a/modules/shared.py b/modules/shared.py
index e2691585..913c9e63 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -211,7 +211,7 @@ class OptionInfo:
def options_section(section_identifier, options_dict):
- for k, v in options_dict.items():
+ for v in options_dict.values():
v.section = section_identifier
return options_dict
@@ -579,7 +579,7 @@ class Options:
section_ids = {}
settings_items = self.data_labels.items()
- for k, item in settings_items:
+ for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
@@ -740,7 +740,7 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, dirs, files in os.walk(path):
+ for root, _, files in os.walk(path):
for filename in files:
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index fda58898..c56bea45 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -12,7 +12,7 @@ class LearnScheduleIterator:
self.it = 0
self.maxit = 0
try:
- for i, pair in enumerate(pairs):
+ for pair in pairs:
if not pair.strip():
continue
tmp = pair.split(':')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c37bb2ad..47035332 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -29,7 +29,7 @@ textual_inversion_templates = {}
def list_textual_inversion_templates():
textual_inversion_templates.clear()
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)
@@ -198,7 +198,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -215,7 +215,7 @@ class EmbeddingDatabase:
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -228,7 +228,7 @@ class EmbeddingDatabase:
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()
@@ -469,7 +469,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
diff --git a/modules/ui.py b/modules/ui.py
index 84d661b2..83bfb7d8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -416,7 +416,7 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -1646,7 +1646,7 @@ def create_ui():
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1673,7 +1673,7 @@ def create_ui():
outputs=[text_settings, result],
)
- for i, k, item in quicksettings_list:
+ for _i, k, _item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index ab585917..2fd82e8e 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -90,7 +90,7 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, files in os.walk(parentdir):
+ for root, dirs, _ in os.walk(parentdir):
for dirname in dirs:
x = os.path.join(root, dirname)
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index cac73c51..f05049e1 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -72,7 +72,7 @@ def cleanup_tmpdr():
if temp_dir == "" or not os.path.isdir(temp_dir):
return
- for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for root, _, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
diff --git a/modules/upscaler.py b/modules/upscaler.py
index e145be30..8acb6e96 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -55,7 +55,7 @@ class Upscaler:
dest_w = int(img.width * scale)
dest_h = int(img.height * scale)
- for i in range(3):
+ for _ in range(3):
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
diff --git a/pyproject.toml b/pyproject.toml
index 346a0cde..c88907be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,6 @@ ignore = [
"I001", # Import block is un-sorted or un-formatted
"C901", # Function is too complex
"C408", # Rewrite as a literal
- "B007", # Loop control variable not used within loop body
]
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 149bc85f..27af5ff6 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -156,7 +156,7 @@ class Script(scripts.Script):
images = []
all_prompts = []
infotexts = []
- for n, args in enumerate(jobs):
+ for args in jobs:
state.job = f"{state.job_no + 1} out of {state.job_count}"
copy_p = copy.copy(p)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index d873a09c..0b1d3096 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -56,7 +56,7 @@ class Script(scripts.Script):
work = []
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
work.append(tiledata[2])
@@ -85,7 +85,7 @@ class Script(scripts.Script):
work_results += processed.images
image_index = 0
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 332e0ecd..38a20381 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -704,7 +704,7 @@ class Script(scripts.Script):
if not include_sub_grids:
# Done with sub-grids, drop all related information:
- for sg in range(z_count):
+ for _ in range(z_count):
del processed.images[1]
del processed.all_prompts[1]
del processed.all_seeds[1]
--
cgit v1.2.3
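
The B007 changes above all follow one pattern: a loop variable the body never reads is renamed to `_` (or given a leading underscore), and where that leaves `enumerate()` or `.items()` doing nothing, the iteration itself is simplified. A short sketch of the variants:

import os

grid_tiles = [("row-a", [1, 2]), ("row-b", [3])]

# Unused index: drop enumerate() entirely instead of keeping a dead variable.
for name, row in grid_tiles:       # before: for i, (name, row) in enumerate(grid_tiles):
    print(name, len(row))

# Unused value: iterate the keys (or .values()) instead of .items().
options = {"steps": 20, "cfg": 7}
for key in options:                # before: for key, value in options.items():
    print(key)

# Positional unpacking where some slots are irrelevant: underscore them.
for _name, row in grid_tiles:
    print(row)

# os.walk() yields (root, dirs, files); underscore whichever part is unused.
for root, _, files in os.walk("."):
    pass
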
From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Mon, 29 May 2023 08:54:13 +0300
Subject: Add & use modules.errors.print_error where currently printing
exception info by hand
---
extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++---
extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++--
modules/api/api.py | 7 +++--
modules/call_queue.py | 22 ++++++--------
modules/codeformer_model.py | 10 +++----
modules/config_states.py | 12 +++-----
modules/errors.py | 16 +++++++++++
modules/extensions.py | 10 +++----
modules/gfpgan_model.py | 6 ++--
modules/hypernetworks/hypernetwork.py | 14 ++++-----
modules/images.py | 9 ++----
modules/interrogate.py | 5 ++--
modules/launch_utils.py | 7 +++--
modules/localization.py | 6 ++--
modules/processing.py | 2 +-
modules/realesrgan_model.py | 14 ++++-----
modules/safe.py | 26 +++++++++--------
modules/script_callbacks.py | 9 +++---
modules/script_loading.py | 7 ++---
modules/scripts.py | 35 ++++++++---------------
modules/sd_hijack_optimizations.py | 6 ++--
modules/textual_inversion/textual_inversion.py | 9 ++----
modules/ui.py | 10 +++----
modules/ui_extensions.py | 9 ++----
scripts/prompts_from_file.py | 6 ++--
25 files changed, 117 insertions(+), 153 deletions(-)
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index c4da79f3..95f1669d 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,9 +1,8 @@
import os
-import sys
-import traceback
from basicsr.utils.download_util import load_file_from_url
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
@@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler):
try:
return LDSR(model, yaml)
-
except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 45d9297b..dd1b822e 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,6 +1,5 @@
import os.path
import sys
-import traceback
import PIL.Image
import numpy as np
@@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
+
+from modules.errors import print_error
from modules.shared import opts
@@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
diff --git a/modules/api/api.py b/modules/api/api.py
index 6a456861..79ce9228 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,6 +16,7 @@ from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api import models
+from modules.errors import print_error
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
@@ -108,7 +109,6 @@ def api_middleware(app: FastAPI):
from rich.console import Console
console = Console()
except Exception:
- import traceback
rich_available = False
@app.middleware("http")
@@ -139,11 +139,12 @@ def api_middleware(app: FastAPI):
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
- print(f"API error: {request.method}: {request.url} {err}")
+ message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
+ print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- traceback.print_exc()
+ print_error(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..dba2a9b4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,9 @@
import html
-import sys
import threading
-import traceback
import time
from modules import shared, progress
+from modules.errors import print_error
queue_lock = threading.Lock()
@@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try:
res = list(func(*args, **kwargs))
except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {args} {kwargs}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
+ # When printing out our debug argument list,
+ # do not print out more than 128 KB of text
+ max_debug_str_len = 131072
+ message = "Error completing request"
+ arg_str = f"Arguments: {args} {kwargs}"
+ if len(arg_str) > max_debug_str_len:
+ arg_str = arg_str[:max_debug_str_len] + f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+ print_error(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
@@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f
-
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ececdbae..76143e9f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import cv2
import torch
@@ -8,6 +6,7 @@ import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
+from modules.errors import print_error
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,8 +104,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ except Exception:
+ print_error('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -135,7 +134,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index db65bcdb..faeaf28b 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""
import os
-import sys
-import traceback
import json
import time
import tqdm
@@ -14,6 +12,7 @@ from collections import OrderedDict
import git
from modules import shared, extensions
+from modules.errors import print_error
from modules.paths_internal import script_path, config_states_dir
@@ -53,8 +52,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -134,8 +132,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -143,8 +140,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/errors.py b/modules/errors.py
index da4694f8..41d8dc93 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,7 +1,23 @@
import sys
+import textwrap
import traceback
+def print_error(
+ message: str,
+ *,
+ exc_info: bool = False,
+) -> None:
+ """
+ Print an error message to stderr, with optional traceback.
+ """
+ for line in message.splitlines():
+ print("***", line, file=sys.stderr)
+ if exc_info:
+ print(textwrap.indent(traceback.format_exc(), "    "), file=sys.stderr)
+ print("---", file=sys.stderr)
+
+
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
diff --git a/modules/extensions.py b/modules/extensions.py
index 624832a0..369d2584 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,11 +1,10 @@
import os
-import sys
import threading
-import traceback
import git
from modules import shared
+from modules.errors import print_error
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -56,8 +55,7 @@ class Extension:
if os.path.exists(os.path.join(self.path, ".git")):
repo = git.Repo(self.path)
except Exception:
- print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -72,8 +70,8 @@ class Extension:
self.commit_hash = commit.hexsha
self.version = self.commit_hash[:8]
- except Exception as ex:
- print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+ except Exception:
+ print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 0131dea4..d2f647fe 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import facexlib
import gfpgan
import modules.face_restoration
from modules import paths, shared, devices, modelloader
+from modules.errors import print_error
model_dir = "GFPGAN"
user_path = None
@@ -112,5 +111,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 570b5603..fcc1ef20 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -2,8 +2,6 @@ import datetime
import glob
import html
import os
-import sys
-import traceback
import inspect
import modules.textual_inversion.dataset
@@ -12,6 +10,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules.errors import print_error
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -325,17 +324,14 @@ def load_hypernetwork(name):
if path is None:
return None
- hypernetwork = Hypernetwork()
-
try:
+ hypernetwork = Hypernetwork()
hypernetwork.load(path)
+ return hypernetwork
except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading hypernetwork {path}", exc_info=True)
return None
- return hypernetwork
-
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
@@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index e21e554c..69151bec 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,4 @@
import datetime
-import sys
-import traceback
import pytz
import io
@@ -18,6 +16,7 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
+from modules.errors import print_error
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
@@ -464,8 +463,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -697,8 +695,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 111b1322..d36e1a5a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,6 +1,5 @@
import os
import sys
-import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -12,6 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.errors import print_error
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -216,8 +216,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error interrogating", exc_info=True)
res += ""
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 35a52310..22edc106 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -8,6 +8,7 @@ import json
from functools import lru_cache
from modules import cmd_args
+from modules.errors import print_error
from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args()
@@ -188,7 +189,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
- print(e, file=sys.stderr)
+ print_error(str(e))
def list_extensions(settings_file):
@@ -198,8 +199,8 @@ def list_extensions(settings_file):
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
- except Exception as e:
- print(e, file=sys.stderr)
+ except Exception:
+ print_error("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
diff --git a/modules/localization.py b/modules/localization.py
index ee9c65e7..9a1df343 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,8 +1,7 @@
import json
import os
-import sys
-import traceback
+from modules.errors import print_error
localizations = {}
@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/processing.py b/modules/processing.py
index b75f2515..5c9bcce8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,5 @@
import json
+import logging
import math
import os
import sys
@@ -23,7 +24,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 99983678..c8d0c64f 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader
@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ except Exception:
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index e8f50774..b596f565 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -2,8 +2,6 @@
import pickle
import collections
-import sys
-import traceback
import torch
import numpy
@@ -11,6 +9,8 @@ import _codecs
import zipfile
import re
+from modules.errors import print_error
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
@@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+ exc_info=True,
+ )
return None
-
except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ f"The file may be malicious, so the program is not going to read it.\n"
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+ exc_info=True,
+ )
return None
return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +193,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
-
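Because print_error splits its message on newlines, the safe.py hunks above can fold what used to be several separate print() calls into a single newline-joined string. A hedged sketch of that usage, with a hypothetical filename and a stand-in for the failed verification step:

    from modules.errors import print_error

    filename = "model.ckpt"  # hypothetical path, for illustration only
    try:
        raise ValueError("bad pickle header")  # stand-in for a failed check_pt()
    except Exception:
        # Each line of the message gets its own "***" prefix on stderr,
        # followed by the indented traceback.
        print_error(
            f"Error verifying pickled file from {filename}\n"
            "The file may be malicious, so the program is not going to read it.",
            exc_info=True,
        )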
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index d2728e12..6aa9c3b6 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,16 +1,15 @@
-import sys
-import traceback
-from collections import namedtuple
import inspect
+from collections import namedtuple
from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
+from modules.errors import print_error
+
def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 57b15862..26efffcb 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,8 +1,8 @@
import os
-import sys
-import traceback
import importlib.util
+from modules.errors import print_error
+
def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index c902804b..a7168fd1 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,12 @@
import os
import re
import sys
-import traceback
from collections import namedtuple
import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules.errors import print_error
AlwaysVisible = object()
@@ -264,8 +264,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -280,11 +279,9 @@ def load_scripts():
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -450,8 +447,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -459,8 +455,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -468,8 +463,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -477,8 +471,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -486,8 +479,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -495,24 +487,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 2ec0b049..fd186fa2 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,5 @@
from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -11,6 +9,7 @@ from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices, sub_quadratic_attention
+from modules.errors import print_error
from modules.hypernetworks import hypernetwork
import ldm.modules.attention
@@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d489ed1e..a040a988 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
from collections import namedtuple
import torch
@@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
+from modules.errors import print_error
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
@@ -207,8 +206,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ print_error("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index 001b9792..1ad94f02 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -2,7 +2,6 @@ import json
import mimetypes
import os
import sys
-import traceback
from functools import reduce
import warnings
@@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.errors import print_error
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
+ if gen_info_string:
+ print_error(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -1753,8 +1752,7 @@ def create_ui():
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 515ec262..cadf56be 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,10 +1,8 @@
import json
import os.path
-import sys
import threading
import time
from datetime import datetime
-import traceback
import git
@@ -14,6 +12,7 @@ import shutil
import errno
from modules import extensions, shared, paths, config_states
+from modules.errors import print_error
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -113,8 +111,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index b918a764..4dc24615 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,13 +1,12 @@
import copy
import random
-import sys
-import traceback
import shlex
import modules.scripts as scripts
import gradio as gr
from modules import sd_samplers
+from modules.errors import print_error
from modules.processing import Processed, process_images
from modules.shared import state
@@ -136,8 +135,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print(f"Error parsing line {line} as commandline:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
--
cgit v1.2.3
From 05933840f0676dd1a90a7e2ad3f2a0672624b2cd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 31 May 2023 19:56:37 +0300
Subject: rename print_error to report, use it together with the package name
---
extensions-builtin/LDSR/scripts/ldsr_model.py | 5 ++---
extensions-builtin/ScuNET/scripts/scunet_model.py | 5 ++---
modules/api/api.py | 5 ++---
modules/call_queue.py | 5 ++---
modules/codeformer_model.py | 7 +++----
modules/config_states.py | 9 ++++-----
modules/errors.py | 8 ++------
modules/extensions.py | 7 +++----
modules/gfpgan_model.py | 5 ++---
modules/hypernetworks/hypernetwork.py | 7 +++----
modules/images.py | 5 ++---
modules/interrogate.py | 3 +--
modules/launch_utils.py | 7 +++----
modules/localization.py | 4 ++--
modules/realesrgan_model.py | 10 +++++-----
modules/safe.py | 7 ++++---
modules/script_callbacks.py | 4 ++--
modules/script_loading.py | 4 ++--
modules/scripts.py | 23 +++++++++++------------
modules/sd_hijack_optimizations.py | 3 +--
modules/textual_inversion/textual_inversion.py | 7 +++----
modules/ui.py | 7 +++----
modules/ui_extensions.py | 7 +++----
scripts/prompts_from_file.py | 5 ++---
24 files changed, 69 insertions(+), 90 deletions(-)
(limited to 'modules/interrogate.py')
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index 95f1669d..dbd6d331 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -2,10 +2,9 @@ import os
from basicsr.utils.download_util import load_file_from_url
-from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
-from modules import shared, script_callbacks
+from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
@@ -51,7 +50,7 @@ class UpscalerLDSR(Upscaler):
try:
return LDSR(model, yaml)
except Exception:
- print_error("Error importing LDSR", exc_info=True)
+ errors.report("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index dd1b822e..85b4505f 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -9,10 +9,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import devices, modelloader, script_callbacks
+from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
-from modules.errors import print_error
from modules.shared import opts
@@ -39,7 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print_error(f"Error loading ScuNET model: {file}", exc_info=True)
+ errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
diff --git a/modules/api/api.py b/modules/api/api.py
index fbd616a3..d34ab422 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,9 +14,8 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
from modules.api import models
-from modules.errors import print_error
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
@@ -145,7 +144,7 @@ def api_middleware(app: FastAPI):
print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- print_error(message, exc_info=True)
+ errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index dba2a9b4..53af6d70 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -2,8 +2,7 @@ import html
import threading
import time
-from modules import shared, progress
-from modules.errors import print_error
+from modules import shared, progress, errors
queue_lock = threading.Lock()
@@ -62,7 +61,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
if len(arg_str) > max_debug_str_len:
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
- print_error(f"{message}\n{arg_str}", exc_info=True)
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 76143e9f..4260b016 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -5,8 +5,7 @@ import torch
import modules.face_restoration
import modules.shared
-from modules import shared, devices, modelloader
-from modules.errors import print_error
+from modules import shared, devices, modelloader, errors
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,7 +104,7 @@ def setup_model(dirname):
del output
torch.cuda.empty_cache()
except Exception:
- print_error('Failed inference for CodeFormer', exc_info=True)
+ errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -134,6 +133,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print_error("Error setting up CodeFormer", exc_info=True)
+ errors.report("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index faeaf28b..6f1ab53f 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -11,8 +11,7 @@ from datetime import datetime
from collections import OrderedDict
import git
-from modules import shared, extensions
-from modules.errors import print_error
+from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
@@ -52,7 +51,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print_error(f"Error reading webui git info from {script_path}", exc_info=True)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -132,7 +131,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print_error(f"Error reading webui git info from {script_path}", exc_info=True)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -140,7 +139,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print_error(f"Error restoring webui to commit{webui_commit_hash}")
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/errors.py b/modules/errors.py
index 41d8dc93..e408f500 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -3,11 +3,7 @@ import textwrap
import traceback
-def print_error(
- message: str,
- *,
- exc_info: bool = False,
-) -> None:
+def report(message: str, *, exc_info: bool = False) -> None:
"""
Print an error message to stderr, with optional traceback.
"""
@@ -15,7 +11,7 @@ def print_error(
print("***", line, file=sys.stderr)
if exc_info:
print(textwrap.indent(traceback.format_exc(), "    "), file=sys.stderr)
- print("---")
+ print("---", file=sys.stderr)
def print_error_explanation(message):
diff --git a/modules/extensions.py b/modules/extensions.py
index 92f93ad9..8608584b 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,8 +1,7 @@
import os
import threading
-from modules import shared
-from modules.errors import print_error
+from modules import shared, errors
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -54,7 +53,7 @@ class Extension:
if os.path.exists(os.path.join(self.path, ".git")):
repo = Repo(self.path)
except Exception:
- print_error(f"Error reading github repository info from {self.path}", exc_info=True)
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -70,7 +69,7 @@ class Extension:
self.version = self.commit_hash[:8]
except Exception:
- print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index d2f647fe..e239a09d 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -4,8 +4,7 @@ import facexlib
import gfpgan
import modules.face_restoration
-from modules import paths, shared, devices, modelloader
-from modules.errors import print_error
+from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
@@ -111,4 +110,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print_error("Error setting up GFPGAN", exc_info=True)
+ errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index fcc1ef20..5d12b449 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -9,8 +9,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
-from modules.errors import print_error
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -329,7 +328,7 @@ def load_hypernetwork(name):
hypernetwork.load(path)
return hypernetwork
except Exception:
- print_error(f"Error loading hypernetwork {path}", exc_info=True)
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
return None
@@ -766,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
except Exception:
- print_error("Exception in training hypernetwork", exc_info=True)
+ errors.report("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index 09f728df..30e9ffc5 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,7 +16,6 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.errors import print_error
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
@@ -463,7 +462,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print_error(f"Error adding [{pattern}] to filename", exc_info=True)
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -698,7 +697,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print_error("Error parsing NovelAI image generation parameters", exc_info=True)
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/interrogate.py b/modules/interrogate.py
index d36e1a5a..9b2c5b60 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,7 +11,6 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
-from modules.errors import print_error
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -216,7 +215,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print_error("Error interrogating", exc_info=True)
+ errors.report("Error interrogating", exc_info=True)
res += ""
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 0bf4cb7e..6e9bb770 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -7,8 +7,7 @@ import platform
import json
from functools import lru_cache
-from modules import cmd_args
-from modules.errors import print_error
+from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args()
@@ -189,7 +188,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
- print_error(str(e))
+ errors.report(str(e))
def list_extensions(settings_file):
@@ -200,7 +199,7 @@ def list_extensions(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except Exception:
- print_error("Could not load settings", exc_info=True)
+ errors.report("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
diff --git a/modules/localization.py b/modules/localization.py
index 9a1df343..e8f585da 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,7 +1,7 @@
import json
import os
-from modules.errors import print_error
+from modules import errors
localizations = {}
@@ -30,6 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print_error(f"Error loading localization from {fn}", exc_info=True)
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index c8d0c64f..2d27b321 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -5,10 +5,10 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
-from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-from modules import modelloader
+from modules import modelloader, errors
+
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -35,7 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print_error("Error importing Real-ESRGAN", exc_info=True)
+ errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -75,7 +75,7 @@ class UpscalerRealESRGAN(Upscaler):
return info
except Exception:
- print_error("Error making Real-ESRGAN models list", exc_info=True)
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -132,4 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print_error("Error making Real-ESRGAN models list", exc_info=True)
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index b596f565..b1d08a79 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -9,9 +9,10 @@ import _codecs
import zipfile
import re
-from modules.errors import print_error
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+from modules import errors
+
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@@ -136,7 +137,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print_error(
+ errors.report(
f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
@@ -144,7 +145,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
)
return None
except Exception:
- print_error(
+ errors.report(
f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6aa9c3b6..ec1469d0 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -5,11 +5,11 @@ from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
-from modules.errors import print_error
+from modules import errors
def report_exception(c, job):
- print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
+ errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 26efffcb..306a1f35 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,7 +1,7 @@
import os
import importlib.util
-from modules.errors import print_error
+from modules import errors
def load_module(path):
@@ -27,4 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print_error(f"Error running preload() for {preload_script}", exc_info=True)
+ errors.report(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index a7168fd1..0970f38e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -5,8 +5,7 @@ from collections import namedtuple
import gradio as gr
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
-from modules.errors import print_error
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors
AlwaysVisible = object()
@@ -264,7 +263,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -281,7 +280,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
- print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -447,7 +446,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print_error(f"Error running process: {script.filename}", exc_info=True)
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -455,7 +454,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -463,7 +462,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print_error(f"Error running process_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -471,7 +470,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print_error(f"Error running postprocess: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -479,7 +478,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -487,21 +486,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print_error(f"Error running before_component: {script.filename}", exc_info=True)
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print_error(f"Error running after_component: {script.filename}", exc_info=True)
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index fd186fa2..5f0ff513 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,6 @@ from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices, sub_quadratic_attention
-from modules.errors import print_error
from modules.hypernetworks import hypernetwork
import ldm.modules.attention
@@ -139,7 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print_error("Cannot import xformers", exc_info=True)
+ errors.report("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b3dcb140..8da050ca 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,9 +12,8 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
-from modules.errors import print_error
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
@@ -219,7 +218,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print_error(f"Error loading embedding {fn}", exc_info=True)
+ errors.report(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -643,7 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print_error("Error training embedding", exc_info=True)
+ errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index fb6b2498..f361264c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,8 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
-from modules.errors import print_error
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -232,7 +231,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
except json.decoder.JSONDecodeError:
if gen_info_string:
- print_error(f"Error parsing JSON generation info: {gen_info_string}")
+ errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -1752,7 +1751,7 @@ def create_ui():
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print_error("Error loading/saving model file", exc_info=True)
+ errors.report("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e2ee9d72..3140ed64 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -11,8 +11,7 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths, config_states
-from modules.errors import print_error
+from modules import extensions, shared, paths, config_states, errors
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -45,7 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print_error(f"Error getting updates for {ext.name}", exc_info=True)
+ errors.report(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -111,7 +110,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print_error(f"Error checking updates for {ext.name}", exc_info=True)
+ errors.report(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 4dc24615..83a2f220 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -5,8 +5,7 @@ import shlex
import modules.scripts as scripts
import gradio as gr
-from modules import sd_samplers
-from modules.errors import print_error
+from modules import sd_samplers, errors
from modules.processing import Processed, process_images
from modules.shared import state
@@ -135,7 +134,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print_error(f"Error parsing line {line} as commandline", exc_info=True)
+ errors.report(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
--
cgit v1.2.3
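After this commit the convention is to import the errors module itself and call the package-qualified name, rather than importing the function directly. A minimal sketch of the resulting call-site pattern, assuming the module layout above (do_work is a hypothetical stand-in):

    from modules import errors

    def do_work():  # hypothetical call site, for illustration only
        raise RuntimeError("boom")

    try:
        do_work()
    except Exception:
        errors.report("Error doing work", exc_info=True)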
From f44feb6a10aacc6a5ff4c9275fba2546b2858935 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 30 Jun 2023 13:11:31 +0300
Subject: Add job argument to State.begin()
---
modules/api/api.py | 14 +++++++-------
modules/call_queue.py | 2 +-
modules/extras.py | 3 +--
modules/interrogate.py | 3 +--
modules/postprocessing.py | 3 +--
modules/shared.py | 4 ++--
6 files changed, 13 insertions(+), 16 deletions(-)
(limited to 'modules/interrogate.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 279c384a..3ea099ad 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -327,7 +327,7 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_txt2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
@@ -384,7 +384,7 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_img2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
@@ -599,7 +599,7 @@ class Api:
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
@@ -610,7 +610,7 @@ class Api:
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
@@ -620,7 +620,7 @@ class Api:
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete')
@@ -636,7 +636,7 @@ class Api:
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -657,7 +657,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 69bf63d2..3b94f8a4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
diff --git a/modules/extras.py b/modules/extras.py
index 830b53aa..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9b2c5b60..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 736315e2..544b2f72 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
diff --git a/modules/shared.py b/modules/shared.py
index 203ee1b9..7df2879c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -173,7 +173,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,7 +187,7 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
def end(self):
--
cgit v1.2.3
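A minimal usage sketch of the new signature, mirroring the call sites updated above; begin() now records the job name itself, so callers no longer assign shared.state.job in a separate statement:

    from modules import shared

    shared.state.begin(job="interrogate")
    try:
        pass  # the actual work, e.g. running the interrogation
    finally:
        shared.state.end()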