From 762265eab58cdb8f2d6398769bab43d8b8db0075 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 07:52:45 +0300 Subject: autofixes from ruff --- modules/ui_extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index d9faf85a..ed70abe5 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -490,7 +490,7 @@ def create_ui(): config_states.list_config_states() with gr.Blocks(analytics_enabled=False) as ui: - with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.Tabs(elem_id="tabs_extensions"): with gr.TabItem("Installed", id="installed"): with gr.Row(elem_id="extensions_installed_top"): -- cgit v1.2.3 From 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:28:15 +0300 Subject: Autofix Ruff W (not W605) (mostly whitespace) --- extensions-builtin/LDSR/ldsr_model_arch.py | 4 +- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 6 +-- extensions-builtin/ScuNET/scunet_model_arch.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- extensions-builtin/SwinIR/swinir_model_arch.py | 2 +- extensions-builtin/SwinIR/swinir_model_arch_v2.py | 52 +++++++++++------------ launch.py | 2 +- modules/api/api.py | 4 +- modules/api/models.py | 2 +- modules/cmd_args.py | 2 +- modules/codeformer/codeformer_arch.py | 14 +++--- modules/codeformer/vqgan_arch.py | 38 ++++++++--------- modules/esrgan_model_arch.py | 4 +- modules/extras.py | 2 +- modules/hypernetworks/hypernetwork.py | 12 +++--- modules/images.py | 2 +- modules/mac_specific.py | 4 +- modules/masking.py | 2 +- modules/ngrok.py | 4 +- modules/processing.py | 2 +- modules/script_callbacks.py | 14 +++--- modules/sd_hijack.py | 12 +++--- modules/sd_hijack_optimizations.py | 32 +++++++------- modules/sd_models.py | 4 +- modules/sd_samplers_kdiffusion.py | 18 ++++---- modules/sub_quadratic_attention.py | 2 +- modules/textual_inversion/dataset.py | 4 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 16 +++---- modules/ui.py | 18 ++++---- modules/ui_extensions.py | 6 +-- modules/xlmr.py | 6 +-- pyproject.toml | 5 ++- scripts/img2imgalt.py | 14 +++--- scripts/loopback.py | 8 ++-- scripts/poor_mans_outpainting.py | 2 +- scripts/prompt_matrix.py | 2 +- scripts/prompts_from_file.py | 4 +- scripts/sd_upscale.py | 2 +- 39 files changed, 167 insertions(+), 166 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 2173de79..7f450086 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -130,11 +130,11 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 57c02d12..631a08ef 100644 --- 
a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -890,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 8028918a..b51a8806 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -265,4 +265,4 @@ class SCUNet(nn.Module): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) \ No newline at end of file + nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 55dd94ab..0ba50487 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -150,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for w_idx in w_idx_list: if state.interrupted or state.skipped: break - + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 73e37cfa..93b93274 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -805,7 +805,7 @@ class SwinIR(nn.Module): def forward(self, x): H, W = x.shape[2:] x = self.check_image_size(x) - + self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 3ca9be78..dad22cca 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = None self.register_buffer("attn_mask", attn_mask) - + def calculate_mask(self, x_size): # calculate attention mask for SW-MSA H, W = x_size @@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - return attn_mask + return attn_mask def forward(self, x, x_size): H, W = x_size @@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module): attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 
else: attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - + # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C @@ -369,7 +369,7 @@ class PatchMerging(nn.Module): H, W = self.input_resolution flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops += H * W * self.dim // 2 - return flops + return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. @@ -447,7 +447,7 @@ class BasicLayer(nn.Module): nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.weight, 0) - + class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: @@ -492,7 +492,7 @@ class PatchEmbed(nn.Module): flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim - return flops + return flops class RSTB(nn.Module): """Residual Swin Transformer Block (RSTB). @@ -531,7 +531,7 @@ class RSTB(nn.Module): num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=norm_layer, @@ -622,7 +622,7 @@ class Upsample(nn.Sequential): else: raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') super(Upsample, self).__init__(*m) - + class Upsample_hf(nn.Sequential): """Upsample module. @@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential): m.append(nn.PixelShuffle(3)) else: raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) + super(Upsample_hf, self).__init__(*m) class UpsampleOneStep(nn.Sequential): @@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential): H, W = self.input_resolution flops = H * W * self.num_feat * 3 * 9 return flops - - + + class Swin2SR(nn.Module): r""" Swin2SR @@ -699,7 +699,7 @@ class Swin2SR(nn.Module): def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, + window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', @@ -764,7 +764,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results norm_layer=norm_layer, @@ -776,7 +776,7 @@ class Swin2SR(nn.Module): ) self.layers.append(layer) - + if self.upsampler == 'pixelshuffle_hf': self.layers_hf = nn.ModuleList() for i_layer in range(self.num_layers): @@ -787,7 +787,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results norm_layer=norm_layer, @@ -799,7 +799,7 @@ class Swin2SR(nn.Module): ) self.layers_hf.append(layer) - + self.norm = norm_layer(self.num_features) # build the last conv layer in deep feature extraction @@ -829,10 +829,10 @@ class Swin2SR(nn.Module): self.conv_aux = 
nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_after_aux = nn.Sequential( nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) + nn.LeakyReLU(inplace=True)) self.upsample = Upsample(upscale, num_feat) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - + elif self.upsampler == 'pixelshuffle_hf': self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) @@ -846,7 +846,7 @@ class Swin2SR(nn.Module): nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - + elif self.upsampler == 'pixelshuffledirect': # for lightweight SR (to save parameters) self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, @@ -905,7 +905,7 @@ class Swin2SR(nn.Module): x = self.patch_unembed(x, x_size) return x - + def forward_features_hf(self, x): x_size = (x.shape[2], x.shape[3]) x = self.patch_embed(x) @@ -919,7 +919,7 @@ class Swin2SR(nn.Module): x = self.norm(x) # B L C x = self.patch_unembed(x, x_size) - return x + return x def forward(self, x): H, W = x.shape[2:] @@ -951,7 +951,7 @@ class Swin2SR(nn.Module): x = self.conv_after_body(self.forward_features(x)) + x x_before = self.conv_before_upsample(x) x_out = self.conv_last(self.upsample(x_before)) - + x_hf = self.conv_first_hf(x_before) x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf x_hf = self.conv_before_upsample_hf(x_hf) @@ -977,15 +977,15 @@ class Swin2SR(nn.Module): x_first = self.conv_first(x) res = self.conv_after_body(self.forward_features(x_first)) + x_first x = x + self.conv_last(res) - + x = x / self.img_range + self.mean if self.upsampler == "pixelshuffle_aux": return x[:, :, :H*self.upscale, :W*self.upscale], aux - + elif self.upsampler == "pixelshuffle_hf": x_out = x_out / self.img_range + self.mean return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - + else: return x[:, :, :H*self.upscale, :W*self.upscale] @@ -1014,4 +1014,4 @@ if __name__ == '__main__': x = torch.randn((1, 3, height, width)) x = model(x) - print(x.shape) \ No newline at end of file + print(x.shape) diff --git a/launch.py b/launch.py index 670af87c..62b33f14 100644 --- a/launch.py +++ b/launch.py @@ -327,7 +327,7 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) - + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) diff --git a/modules/api/api.py b/modules/api/api.py index 594fa655..165985c3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -227,7 +227,7 @@ class Api: script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx - + def get_scripts_list(self): t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] @@ -237,7 +237,7 @@ class Api: def get_script(self, script_name, script_runner): if script_name is None or script_name == "": return None, None - + script_idx = script_name_to_index(script_name, script_runner.scripts) return script_runner.scripts[script_idx] diff --git a/modules/api/models.py b/modules/api/models.py index 4d291076..006ccdb7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -289,4 +289,4 @@ class MemoryResponse(BaseModel): class ScriptsList(BaseModel): txt2img: list = 
Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") - img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file + img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e01ca655..f4a4ab36 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -102,4 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) -parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') \ No newline at end of file +parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 45c70f84..12db6814 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module): tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): - + # self attention tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) @@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module): @ARCH_REGISTRY.register() class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, + def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256, connect_list=('32', '64', '128', '256'), fix_modules=('quantize', 'generator')): @@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder): self.feat_emb = nn.Linear(256, self.dim_embd) # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) for _ in range(self.n_layers)]) # logits_predict head self.idx_pred_layer = nn.Sequential( nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)) - + self.channels = { '16': 512, '32': 256, @@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder): enc_feat_dict = {} out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] for i, block in enumerate(self.encoder.blocks): - x = block(x) + x = block(x) if i in out_list: enc_feat_dict[str(x.shape[-1])] = x.clone() @@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder): fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] for i, block in enumerate(self.generator.blocks): - x = block(x) + x = block(x) if i in fuse_list: # fuse after i-th block f_size = str(x.shape[-1]) if w>0: x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) out = x # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat \ No newline at end of file + return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index b24a0394..09ee6660 100644 --- a/modules/codeformer/vqgan_arch.py +++ 
b/modules/codeformer/vqgan_arch.py @@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY def normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - + @torch.jit.script def swish(x): @@ -210,15 +210,15 @@ class AttnBlock(nn.Module): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) + q = q.permute(0, 2, 1) k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) + w_ = torch.bmm(q, k) w_ = w_ * (int(c)**(-0.5)) w_ = F.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) + w_ = w_.permute(0, 2, 1) h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) @@ -270,18 +270,18 @@ class Encoder(nn.Module): def forward(self, x): for block in self.blocks: x = block(x) - + return x class Generator(nn.Module): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): super().__init__() - self.nf = nf - self.ch_mult = ch_mult + self.nf = nf + self.ch_mult = ch_mult self.num_resolutions = len(self.ch_mult) self.num_res_blocks = res_blocks - self.resolution = img_size + self.resolution = img_size self.attn_resolutions = attn_resolutions self.in_channels = emb_dim self.out_channels = 3 @@ -315,24 +315,24 @@ class Generator(nn.Module): blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) self.blocks = nn.ModuleList(blocks) - + def forward(self, x): for block in self.blocks: x = block(x) - + return x - + @ARCH_REGISTRY.register() class VQAutoEncoder(nn.Module): def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks self.codebook_size = codebook_size self.embed_dim = emb_dim self.ch_mult = ch_mult @@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module): self.kl_weight ) self.generator = Generator( - self.nf, + self.nf, self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, + self.ch_mult, + self.n_blocks, + self.resolution, self.attn_resolutions ) @@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module): raise ValueError('Wrong params!') def forward(self, x): - return self.main(x) \ No newline at end of file + return self.main(x) diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py index 4de9dd8d..2b9888ba 100644 --- a/modules/esrgan_model_arch.py +++ b/modules/esrgan_model_arch.py @@ -105,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module): Modified options that can be used: - "Partial Convolution based Padding" arXiv:1811.11718 - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. 
{Rasoanaivo} """ @@ -170,7 +170,7 @@ class GaussianNoise(nn.Module): scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x sampled_noise = self.noise.repeat(*x.size()).normal_() * scale x = x + sampled_noise - return x + return x def conv1x1(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) diff --git a/modules/extras.py b/modules/extras.py index eb4f0b42..bdf9b3b7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ result_is_inpainting_model = True else: theta_0[key] = theta_func2(a, b, multiplier) - + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38ef074f..570b5603 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -540,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None if clip_grad: clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) @@ -593,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi print(e) scaler = torch.cuda.amp.GradScaler() - + batch_size = ds.batch_size gradient_step = ds.gradient_step # n steps = batch_size * gradient_step * n image processed @@ -636,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi if clip_grad: clip_grad_sched.step(hypernetwork.step) - + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if use_weight: @@ -657,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step += loss.item() scaler.scale(loss).backward() - + # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue loss_logging.append(_loss_step) if clip_grad: clip_grad(weights, clip_grad_sched.learn_rate) - + scaler.step(optimizer) scaler.update() hypernetwork.step += 1 @@ -674,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step = 0 steps_done = hypernetwork.step + 1 - + epoch_num = hypernetwork.step // steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch diff --git a/modules/images.py b/modules/images.py index 3b8b62d9..b2de3662 100644 --- a/modules/images.py +++ b/modules/images.py @@ -367,7 +367,7 @@ class FilenameGenerator: self.seed = seed self.prompt = prompt self.image = image - + def hasprompt(self, *args): lower = self.prompt.lower() if self.p is None or self.prompt is None: diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 5c2f92a1..d74c6b95 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -42,7 +42,7 @@ if has_mps: # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and 
kwargs['device'].type == 'mps')) - # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 @@ -60,4 +60,4 @@ if has_mps: # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 if platform.processor() == 'i386': for funcName in ['torch.argmax', 'torch.Tensor.argmax']: - CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') \ No newline at end of file + CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') diff --git a/modules/masking.py b/modules/masking.py index a5c4d2da..be9f84c7 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps def get_crop_region(mask, pad=0): """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" - + h, w = mask.shape crop_left = 0 diff --git a/modules/ngrok.py b/modules/ngrok.py index 7a7b4b26..67a74e85 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -13,7 +13,7 @@ def connect(token, port, region): config = conf.PyngrokConfig( auth_token=token, region=region ) - + # Guard for existing tunnels existing = ngrok.get_tunnels(pyngrok_config=config) if existing: @@ -24,7 +24,7 @@ def connect(token, port, region): print(f'ngrok has already been connected to localhost:{port}! 
URL: {public_url}\n' 'You can use this link after the launch is complete.') return - + try: if account is None: public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url diff --git a/modules/processing.py b/modules/processing.py index c3932d6b..f902b9df 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -164,7 +164,7 @@ class StableDiffusionProcessing: self.all_subseeds = None self.iteration = 0 self.is_hr_pass = False - + @property def sd_model(self): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 17109732..7d9dd736 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -32,22 +32,22 @@ class CFGDenoiserParams: def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x """Latent image representation in the process of being denoised""" - + self.image_cond = image_cond """Conditioning image""" - + self.sigma = sigma """Current sigma noise step value""" - + self.sampling_step = sampling_step """Current Sampling step number""" - + self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - + self.text_cond = text_cond """ Encoder hidden states of text conditioning from prompt""" - + self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" @@ -240,7 +240,7 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) - + def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e374aeb8..7e50f1ab 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -34,7 +34,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th - + optimization_method = None can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp @@ -92,12 +92,12 @@ def fix_checkpoint(): def weighted_loss(sd_model, pred, target, mean=True): #Calculate the weight normally, but ignore the mean loss = sd_model._old_get_loss(pred, target, mean=False) - + #Check if we have weights available weight = getattr(sd_model, '_custom_loss_weight', None) if weight is not None: loss *= weight - + #Return the loss, as mean if specified return loss.mean() if mean else loss @@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): try: #Temporarily append weights to a place accessible during loss calc sd_model._custom_loss_weight = w - + #Replace 'get_loss' with a weight-aware one. 
Otherwise we need to reimplement 'forward' completely #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set if not hasattr(sd_model, '_old_get_loss'): @@ -120,7 +120,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): del sd_model._custom_loss_weight except AttributeError: pass - + #If we have an old loss function, reset the loss function to the original one if hasattr(sd_model, '_old_get_loss'): sd_model.get_loss = sd_model._old_get_loss @@ -184,7 +184,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: - m.cond_stage_model = m.cond_stage_model.wrapped + m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a174bbe1..f00fe55c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): end = i + 2 s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) s1 *= self.scale - + s2 = s1.softmax(dim=-1) del s1 - + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) del s2 del q, k, v @@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None): with devices.without_autocast(disable=not shared.opts.upcast_attn): k_in = k_in * self.scale - + del context, x - + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) del q_in, k_in, v_in - + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - + mem_free_total = get_available_vram() - + gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() modifier = 3 if q.element_size() == 2 else 2.5 mem_required = tensor_size * modifier steps = 1 - + if mem_required > mem_free_total: steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - + if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - + s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 - + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 - + del q, k, v r1 = r1.to(dtype) @@ -228,7 +228,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): with devices.without_autocast(disable=not shared.opts.upcast_attn): k = k * self.scale - + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) r = einsum_op(q, k, v) r = r.to(dtype) @@ -369,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - + del q_in, k_in, v_in dtype = q.dtype @@ -451,7 +451,7 @@ def cross_attention_attnblock_forward(self, x): h3 += x return h3 - + def xformers_attnblock_forward(self, x): try: h_ = x diff --git a/modules/sd_models.py b/modules/sd_models.py index d1e946a5..3316d021 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,7 +165,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -372,7 +372,7 @@ def enable_midas_autodownload(): if not os.path.exists(path): if not os.path.exists(midas_path): mkdir(midas_path) - + print(f"Downloading midas model weights for {model_type} to {path}") request.urlretrieve(midas_urls[model_type], path) print(f"{model_type} downloaded") diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 2f733cf5..e9e41818 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -93,10 +93,10 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -316,7 +316,7 @@ class KDiffusionSampler: sigma_sched = sigmas[steps - t_enc - 1:] xi = x + noise * sigma_sched[0] - + extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters @@ -339,9 +339,9 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond } @@ -374,9 +374,9 @@ class KDiffusionSampler: self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': 
conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond }, disable=False, callback=self.callback_state, **extra_params_kwargs)) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index cc38debd..497568eb 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -179,7 +179,7 @@ def efficient_dot_product_attention( chunk_idx, min(query_chunk_size, q_tokens) ) - + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 41610e03..b9621fc9 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,7 +118,7 @@ class PersonalizedBase(Dataset): weight = torch.ones(latent_sample.shape) else: weight = None - + if latent_sampling_method == "random": entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) else: @@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader): return self def collate_wrapper_random(batch): - return BatchLoaderRandom(batch) \ No newline at end of file + return BatchLoaderRandom(batch) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d0cad09e..a009d8e8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr default=None ) return wh and center_crop(image, *wh) - + def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): width = process_width diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9e1b2b9a..d489ed1e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) def tensorboard_add_scaler(tensorboard_writer, tag, value, step): - tensorboard_writer.add_scalar(tag=tag, + tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step) def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): # Convert a pil image to a torch tensor img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) - img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) 
img_tensor = img_tensor.permute((2, 0, 1)) - + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): @@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ @@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed - + if shared.opts.training_enable_tensorboard: tensorboard_writer = tensorboard_setup(log_directory) @@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') if embedding.checksum() == optimizer_saved_dict.get('hash', None): optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) - + if optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) print("Loaded existing optimizer from checkpoint") @@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if clip_grad: clip_grad_sched.step(embedding.step) - + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if use_weight: @@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - + if clip_grad: clip_grad(embedding.vec, clip_grad_sched.learn_rate) diff --git a/modules/ui.py b/modules/ui.py index 1efb656a..ff82fff6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1171,7 +1171,7 @@ def create_ui(): process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - + with gr.Column(visible=False) as process_multicrop_col: gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') with gr.Row(): @@ -1183,7 +1183,7 @@ def create_ui(): with gr.Row(): process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1226,7 +1226,7 @@ def create_ui(): with FormRow(): embedding_learn_rate = gr.Textbox(label='Embedding Learning 
rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - + with FormRow(): clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) @@ -1565,7 +1565,7 @@ def create_ui(): gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - + def unload_sd_weights(): modules.sd_models.unload_model_weights() @@ -1841,15 +1841,15 @@ def versions_html(): return f""" version: {tag} - •  + • python: {python_version} - •  + • torch: {getattr(torch, '__long_version__',torch.__version__)} - •  + • xformers: {xformers_version} - •  + • gradio: {gr.__version__} - •  + • checkpoint: N/A """ diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ed70abe5..af497733 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -467,7 +467,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" {html.escape(description)}

<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p>

{install_code} - + """ for tag in [x for x in extension_tags if x not in tags]: @@ -535,9 +535,9 @@ def create_ui(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") - with gr.Row(): + with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) - + install_result = gr.HTML() available_extensions_table = gr.HTML() diff --git a/modules/xlmr.py b/modules/xlmr.py index e056c3f6..a407a3ca 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): config_class = BertSeriesConfig def __init__(self, config=None, **kargs): - # modify initialization for autoloading + # modify initialization for autoloading if config is None: config = XLMRobertaConfig() config.attention_probs_dropout_prob= 0.1 @@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): text["attention_mask"] = torch.tensor( text['attention_mask']).to(device) features = self(**text) - return features['projection_state'] + return features['projection_state'] def forward( self, @@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig \ No newline at end of file + config_class= RobertaSeriesConfig diff --git a/pyproject.toml b/pyproject.toml index c88907be..d4a1bbf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ extend-select = [ "B", "C", "I", + "W", ] exclude = [ @@ -20,7 +21,7 @@ ignore = [ "I001", # Import block is un-sorted or un-formatted "C901", # Function is too complex "C408", # Rewrite as a literal - + "W605", # invalid escape sequence, messes with some docstrings ] [tool.ruff.per-file-ignores] @@ -28,4 +29,4 @@ ignore = [ [tool.ruff.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. 
-extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] \ No newline at end of file +extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index bb00fb3f..1e833fa8 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -149,9 +149,9 @@ class Script(scripts.Script): sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) return [ - info, + info, override_sampler, - override_prompt, original_prompt, original_negative_prompt, + override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment, @@ -191,17 +191,17 @@ class Script(scripts.Script): self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) - + combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) - + sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) sigmas = sampler.model_wrap.get_sigmas(p.steps) - + noise_dt = combined_noise - (p.init_latent / sigmas[0]) - + p.seed = p.seed + 1 - + return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) p.sample = sample_extra diff --git a/scripts/loopback.py b/scripts/loopback.py index ad6609be..2d5feaf9 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -14,7 +14,7 @@ class Script(scripts.Script): def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): + def ui(self, is_img2img): loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength")) denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear") @@ -104,7 +104,7 @@ class Script(scripts.Script): p.seed = processed.seed + 1 p.denoising_strength = calculate_denoising_strength(i + 1) - + if state.skipped: break @@ -121,7 +121,7 @@ class Script(scripts.Script): all_images.append(last_image) p.inpainting_fill = original_inpainting_fill - + if state.interrupted: break @@ -132,7 +132,7 @@ class Script(scripts.Script): if opts.return_grid: grids.append(grid) - + all_images = grids + all_images processed = Processed(p, all_images, initial_seed, initial_info) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index c0bbecc1..ea0632b6 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -19,7 +19,7 @@ class Script(scripts.Script): def ui(self, is_img2img): if not is_img2img: return None - + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) 
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index fb06beab..88324fe6 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -96,7 +96,7 @@ class Script(scripts.Script): p.prompt_for_display = positive_prompt processed = process_images(p) - grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) processed.images.insert(0, grid) processed.index_of_first_image = 1 diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 9607077a..2378816f 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -109,7 +109,7 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def ui(self, is_img2img): + def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) @@ -166,7 +166,7 @@ class Script(scripts.Script): proc = process_images(copy_p) images += proc.images - + if checkbox_iterate: p.seed = p.seed + (p.batch_size * p.n_iter) all_prompts += proc.all_prompts diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 0b1d3096..e614c23b 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -16,7 +16,7 @@ class Script(scripts.Script): def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): + def ui(self, is_img2img): info = gr.HTML("

<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>

") overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) -- cgit v1.2.3 From 0d2a4b608c075daa3a4d1a1c9df01a763ae4793a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 15 May 2023 20:57:11 +0300 Subject: load extensions' git metadata in parallel to loading the main program to save a ton of time during startup --- modules/config_states.py | 2 ++ modules/extensions.py | 12 +++++++++++- modules/ui_extensions.py | 15 +++++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/config_states.py b/modules/config_states.py index 75da862a..db65bcdb 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -83,6 +83,8 @@ def get_extension_config(): ext_config = {} for ext in extensions.extensions: + ext.read_info_from_repo() + entry = { "name": ext.name, "path": ext.path, diff --git a/modules/extensions.py b/modules/extensions.py index bc2c0450..1053253e 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,6 @@ import os import sys +import threading import traceback import time @@ -24,6 +25,8 @@ def active(): class Extension: + lock = threading.Lock() + def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name self.path = path @@ -42,8 +45,13 @@ class Extension: if self.is_builtin or self.have_info_from_repo: return - self.have_info_from_repo = True + with self.lock: + if self.have_info_from_repo: + return + self.do_read_info_from_repo() + + def do_read_info_from_repo(self): repo = None try: if os.path.exists(os.path.join(self.path, ".git")): @@ -70,6 +78,8 @@ class Extension: print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) self.remote = None + self.have_info_from_repo = True + def list_files(self, subdir, extension): from modules import scripts diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index af497733..aaa7e571 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,6 +1,7 @@ import json import os.path import sys +import threading import time from datetime import datetime import traceback @@ -484,11 +485,18 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" return code, list(tags) +def preload_extensions_git_metadata(): + for extension in extensions.extensions: + extension.read_info_from_repo() + + def create_ui(): import modules.ui config_states.list_config_states() + threading.Thread(target=preload_extensions_git_metadata).start() + with gr.Blocks(analytics_enabled=False) as ui: with gr.Tabs(elem_id="tabs_extensions"): with gr.TabItem("Installed", id="installed"): @@ -508,7 +516,8 @@ def create_ui(): """ info = gr.HTML(html) - extensions_table = gr.HTML(lambda: extension_table()) + extensions_table = gr.HTML('Loading...') + ui.load(fn=extension_table, inputs=[], outputs=[extensions_table]) apply.click( fn=apply_and_restart, @@ -595,7 +604,8 @@ def create_ui(): config_save_button = gr.Button(value="Save Current Config") config_states_info = gr.HTML("") - config_states_table = gr.HTML(lambda: update_config_states_table("Current")) + config_states_table = gr.HTML("Loading...") + ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table]) config_save_button.click(fn=save_config_state, 
inputs=[config_save_name], outputs=[config_states_list, config_states_info]) @@ -608,4 +618,5 @@ def create_ui(): outputs=[config_states_table], ) + return ui -- cgit v1.2.3 From a47abe1b7b667374e9df1932172230132d3fe8db Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 15 May 2023 21:22:35 +0300 Subject: update extensions table: show branch, show date in separate column, and show version from tags if available --- modules/extensions.py | 6 ++---- modules/ui_extensions.py | 7 ++++++- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/extensions.py b/modules/extensions.py index 1053253e..f16f059e 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -66,13 +66,11 @@ class Extension: try: self.status = 'unknown' self.remote = next(repo.remote().urls, None) - head = repo.head.commit self.commit_date = repo.head.commit.committed_date - ts = time.asctime(time.gmtime(self.commit_date)) if repo.active_branch: self.branch = repo.active_branch.name - self.commit_hash = head.hexsha - self.version = f'{self.commit_hash[:8]} ({ts})' + self.commit_hash = repo.head.commit.hexsha + self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care except Exception as ex: print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index aaa7e571..6ad9a97c 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -141,7 +141,9 @@ def extension_table(): Extension URL - Version + Branch + Version + Date Update @@ -149,6 +151,7 @@ def extension_table(): """ for ext in extensions.extensions: + ext: extensions.Extension ext.read_info_from_repo() remote = f"""{html.escape("built-in" if ext.is_builtin else ext.remote or '')}""" @@ -170,7 +173,9 @@ def extension_table(): {html.escape(ext.name)} {remote} + {ext.branch} {version_link} + {time.asctime(time.gmtime(ext.commit_date))} {ext_status} """ -- cgit v1.2.3 From 3d76eabbca3adb711787d1802d6b61c0971b4bc0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 16 May 2023 07:59:43 +0300 Subject: add visual progress for extension installation from URL --- modules/extensions.py | 1 - modules/ui_extensions.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/extensions.py b/modules/extensions.py index f16f059e..359a7aa5 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -3,7 +3,6 @@ import sys import threading import traceback -import time import git from modules import shared diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6ad9a97c..d7a0f685 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -593,9 +593,9 @@ def create_ui(): install_result = gr.HTML(elem_id="extension_install_result") install_button.click( - fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), inputs=[install_dirname, install_url, install_branch], - outputs=[extensions_table, install_result], + outputs=[install_url, extensions_table, install_result], ) with gr.TabItem("Backup/Restore"): -- cgit v1.2.3 From 85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b Mon Sep 17 00:00:00 2001 
From: Aarni Koskela Date: Thu, 11 May 2023 23:46:45 +0300 Subject: Replace state.need_restart with state.server_command + replace poll loop with signal --- modules/shared.py | 42 +++++++++++++++++++++++++++++++++++++++++- modules/ui.py | 6 +----- modules/ui_extensions.py | 7 ++----- webui.py | 39 ++++++++++++++++++++++++--------------- 4 files changed, 68 insertions(+), 26 deletions(-) (limited to 'modules/ui_extensions.py')
diff --git a/modules/shared.py b/modules/shared.py index 3abf71c0..648a2a19 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -2,6 +2,7 @@ import datetime import json import os import sys +import threading import time import gradio as gr @@ -110,8 +111,47 @@ class State: id_live_preview = 0 textinfo = None time_start = None - need_restart = False server_start = None + _server_command_signal = threading.Event() + _server_command: str | None = None + + @property + def need_restart(self) -> bool: + # Compatibility getter for need_restart. + return self.server_command == "restart" + + @need_restart.setter + def need_restart(self, value: bool) -> None: + # Compatibility setter for need_restart. + if value: + self.server_command = "restart" + + @property + def server_command(self): + return self._server_command + + @server_command.setter + def server_command(self, value: str | None) -> None: + """ + Set the server command to `value` and signal that it's been set. + """ + self._server_command = value + self._server_command_signal.set() + + def wait_for_server_command(self, timeout: float | None = None) -> str | None: + """ + Wait for server command to get set; return and clear the value and signal. + """ + if self._server_command_signal.wait(timeout): + self._server_command_signal.clear() + req = self._server_command + self._server_command = None + return req + return None + + def request_restart(self) -> None: + self.interrupt() + self.server_command = "restart" def skip(self): self.skipped = True
diff --git a/modules/ui.py b/modules/ui.py index 8e51e782..bed8464e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1609,12 +1609,8 @@ def create_ui(): outputs=[] ) - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - restart_gradio.click( - fn=request_restart, + fn=shared.state.request_restart, _js='restart_reload', inputs=[], outputs=[], )
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index d7a0f685..4ba3bdd7 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all shared.opts.save(shared.config_filename) - - shared.state.interrupt() - shared.state.need_restart = True + shared.state.request_restart() def save_config_state(name): @@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type): if restore_type == "webui" or restore_type == "both": config_states.restore_webui_config(config_state) - shared.state.interrupt() - shared.state.need_restart = True + shared.state.request_restart() return ""
diff --git a/webui.py b/webui.py index 293a16cc..39dec3ca 100644 --- a/webui.py +++ b/webui.py @@ -234,7 +234,10 @@ def initialize(): print(f'Interrupted with signal {sig} in {frame}') os._exit(0) - signal.signal(signal.SIGINT, sigint_handler) + if not os.environ.get("COVERAGE_RUN"): + # Don't install the immediate-quit handler when running under coverage, + # as then the coverage report won't be generated.
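A minimal, runnable sketch of the handshake these State members implement (condensed from the shared.py hunk above; the webui.py loop that continues below consumes it the same way): one thread blocks in wait_for_server_command(), another sets the command and fires the event.

    import threading

    class MiniState:
        # condensed from the State hunk above
        def __init__(self):
            self._server_command_signal = threading.Event()
            self._server_command = None

        def set_command(self, value):
            # what the server_command setter does: store, then signal
            self._server_command = value
            self._server_command_signal.set()

        def wait_for_server_command(self, timeout=None):
            if self._server_command_signal.wait(timeout):
                self._server_command_signal.clear()
                req, self._server_command = self._server_command, None
                return req
            return None

    state = MiniState()

    def serve():
        while True:
            command = state.wait_for_server_command(timeout=5)
            if command in ("stop", "restart"):
                print(f"got {command!r}, leaving the serve loop")
                break

    thread = threading.Thread(target=serve)
    thread.start()
    state.set_command("restart")  # request_restart() boils down to this plus interrupt()
    thread.join()

Unlike the old wait_on_server() poll, the waiting thread wakes as soon as the event is set instead of on the next 0.5s tick.
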
+ signal.signal(signal.SIGINT, sigint_handler) def setup_middleware(app): @@ -255,19 +258,6 @@ def create_api(app): return api -def wait_on_server(demo=None): - while 1: - time.sleep(0.5) - if shared.state.need_restart: - shared.state.need_restart = False - time.sleep(0.5) - demo.close() - time.sleep(0.5) - - modules.script_callbacks.app_reload_callback() - break - - def api_only(): initialize() @@ -328,6 +318,7 @@ def webui(): inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) + # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False @@ -359,8 +350,26 @@ def webui(): redirector.get("/") gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}") - wait_on_server(shared.demo) + try: + while True: + server_command = shared.state.wait_for_server_command(timeout=5) + if server_command: + if server_command in ("stop", "restart"): + break + else: + print(f"Unknown server command: {server_command}") + except KeyboardInterrupt: + server_command = "stop" + + if server_command == "stop": + # If we catch a keyboard interrupt, we want to stop the server and exit. + print('Caught KeyboardInterrupt, stopping...') + shared.demo.close() + break print('Restarting UI...') + shared.demo.close() + time.sleep(0.5) + modules.script_callbacks.app_reload_callback() startup_timer.reset() -- cgit v1.2.3
From 379fd6204dfc27e16acc03f705ecd9ff23c2d1c0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 19 May 2023 08:14:38 +0300 Subject: make links to http://<...>.git git extensions work in the extension tab --- modules/ui_extensions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/ui_extensions.py')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4ba3bdd7..ef18f438 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -125,7 +125,9 @@ def make_commit_link(commit_hash, remote, text=None): if text is None: text = commit_hash[:8] if remote.startswith("https://github.com/"): - href = os.path.join(remote, "commit", commit_hash) + if remote.endswith(".git"): + remote = remote[:-4] + href = remote + "/commit/" + commit_hash return f'<a href="{href}" target="_blank">{text}</a>' else: return text -- cgit v1.2.3
From bf5e5f4269a263cce0abb993419302edb5b5b8d0 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 20 May 2023 15:08:08 +0900 Subject: extensions clone depth 1 --- modules/ui_extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ef18f438..88a7381b 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -345,12 +345,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir) as repo: + with git.Repo.clone_from(url, tmpdir, depth=1) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo: + with git.Repo.clone_from(url, tmpdir, depth=1, branch=branch_name) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() -- cgit v1.2.3
From cd03317c05b5e31d2d7b77d7ddd0e01f45c6064d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 May 2023 16:42:54 +0900 Subject: --filter=blob:none Co-Authored-By: Aarni Koskela
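The two clone strategies these commits trade between, as a hedged sketch (placeholder URL and paths; GitPython forwards keyword arguments as git command-line flags, so filter=['blob:none'] becomes --filter=blob:none):

    import git

    url = "https://github.com/user/extension.git"  # hypothetical extension repo

    # Shallow clone: a single commit, no history. Fast, but anything that
    # needs history (git describe, update checks against origin) can fail
    # or mislead, which is why depth=1 is replaced below.
    git.Repo.clone_from(url, "/tmp/ext-shallow", depth=1)

    # Blobless partial clone: the full commit graph, with file contents
    # fetched lazily on checkout. Equivalent CLI: git clone --filter=blob:none
    git.Repo.clone_from(url, "/tmp/ext-blobless", filter=['blob:none'])
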
Co-Authored-By: catboxanon <122327233+catboxanon@users.noreply.github.com> --- modules/ui_extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 88a7381b..515ec262 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -345,12 +345,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir, depth=1) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with git.Repo.clone_from(url, tmpdir, depth=1, branch=branch_name) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() -- cgit v1.2.3 From 77a10c62c9a44a27e8030eff6e5b3fb182be55ae Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 00:41:12 +0300 Subject: Patch GitPython to not use leaky persistent processes --- modules/extensions.py | 9 ++++----- modules/gitpython_hack.py | 42 ++++++++++++++++++++++++++++++++++++++++++ modules/ui_extensions.py | 6 ++++++ 3 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 modules/gitpython_hack.py (limited to 'modules/ui_extensions.py') diff --git a/modules/extensions.py b/modules/extensions.py index 624832a0..fb7250e6 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -3,9 +3,8 @@ import sys import threading import traceback -import git - from modules import shared +from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -54,7 +53,7 @@ class Extension: repo = None try: if os.path.exists(os.path.join(self.path, ".git")): - repo = git.Repo(self.path) + repo = Repo(self.path) except Exception: print(f"Error reading github repository info from {self.path}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -94,7 +93,7 @@ class Extension: return res def check_updates(self): - repo = git.Repo(self.path) + repo = Repo(self.path) for fetch in repo.remote().fetch(dry_run=True): if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True @@ -116,7 +115,7 @@ class Extension: self.status = "latest" def fetch_and_reset_hard(self, commit='origin'): - repo = git.Repo(self.path) + repo = Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. repo.git.fetch(all=True) diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py new file mode 100644 index 00000000..e537c1df --- /dev/null +++ b/modules/gitpython_hack.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import io +import subprocess + +import git + + +class Git(git.Git): + """ + Git subclassed to never use persistent processes. 
+ """ + + def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs): + raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})") + + def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]: + ret = subprocess.check_output( + [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"], + input=self._prepare_ref(ref), + cwd=self._working_dir, + timeout=2, + ) + return self._parse_object_header(ret) + + def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]: + # Not really streaming, per se; this buffers the entire object in memory. + # Shouldn't be a problem for our use case, since we're only using this for + # object headers (commit objects). + ret = subprocess.check_output( + [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"], + input=self._prepare_ref(ref), + cwd=self._working_dir, + timeout=30, + ) + bio = io.BytesIO(ret) + hexsha, typename, size = self._parse_object_header(bio.readline()) + return (hexsha, typename, size, self.CatFileContentStream(size, bio)) + + +class Repo(git.Repo): + GitCommandWrapperType = Git diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..1c3f5ed9 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -490,8 +490,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def preload_extensions_git_metadata(): + t0 = time.time() for extension in extensions.extensions: extension.read_info_from_repo() + print( + f"preload_extensions_git_metadata for " + f"{len(extensions.extensions)} extensions took " + f"{time.time() - t0:.2f}s" + ) def create_ui(): -- cgit v1.2.3 From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 08:54:13 +0300 Subject: Add & use modules.errors.print_error where currently printing exception info by hand --- extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++-- modules/api/api.py | 7 +++-- modules/call_queue.py | 22 ++++++-------- modules/codeformer_model.py | 10 +++---- modules/config_states.py | 12 +++----- modules/errors.py | 16 +++++++++++ modules/extensions.py | 10 +++---- modules/gfpgan_model.py | 6 ++-- modules/hypernetworks/hypernetwork.py | 14 ++++----- modules/images.py | 9 ++---- modules/interrogate.py | 5 ++-- modules/launch_utils.py | 7 +++-- modules/localization.py | 6 ++-- modules/processing.py | 2 +- modules/realesrgan_model.py | 14 ++++----- modules/safe.py | 26 +++++++++-------- modules/script_callbacks.py | 9 +++--- modules/script_loading.py | 7 ++--- modules/scripts.py | 35 ++++++++--------------- modules/sd_hijack_optimizations.py | 6 ++-- modules/textual_inversion/textual_inversion.py | 9 ++---- modules/ui.py | 10 +++---- modules/ui_extensions.py | 9 ++---- scripts/prompts_from_file.py | 6 ++-- 25 files changed, 117 insertions(+), 153 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index c4da79f3..95f1669d 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,9 +1,8 @@ import os -import sys -import traceback from basicsr.utils.download_util import load_file_from_url +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks @@ -51,10 
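What the patched get_object_header above boils down to, sketched as a one-shot subprocess call rather than GitPython's long-lived cat-file process:

    import subprocess

    def object_header(ref: str, cwd: str = ".") -> tuple[str, str, int]:
        # One-shot `git cat-file --batch-check`: write a ref to stdin,
        # read back "<hexsha> <type> <size>", and let the process exit
        # instead of keeping it alive for the rest of the program.
        out = subprocess.check_output(
            ["git", "cat-file", "--batch-check"],
            input=(ref + "\n").encode(),
            cwd=cwd,
            timeout=2,
        )
        hexsha, typename, size = out.split()
        return hexsha.decode(), typename.decode(), int(size)

    print(object_header("HEAD"))  # e.g. ('d5ae5640...', 'commit', 285)
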
+50,8 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 45d9297b..dd1b822e 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,6 +1,5 @@ import os.path import sys -import traceback import PIL.Image import numpy as np @@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net + +from modules.errors import print_error from modules.shared import opts @@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index 6a456861..79ce9228 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,6 +16,7 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api import models +from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -108,7 +109,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +139,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + print_error(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb764..dba2a9b4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,10 +1,9 @@ import html -import sys import threading -import traceback import time from modules import shared, progress +from modules.errors import print_error queue_lock = threading.Lock() @@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): try: res = list(func(*args, **kwargs)) except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) 
- argStr = f"Arguments: {args} {kwargs}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) + # When printing out our debug argument list, + # do not print out more than a 100 KB of text + max_debug_str_len = 131072 + message = "Error completing request" + arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] + if len(arg_str) > max_debug_str_len: + arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" + print_error(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 @@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f - diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ececdbae..76143e9f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import cv2 import torch @@ -8,6 +6,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader +from modules.errors import print_error from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,8 +104,8 @@ def setup_model(dirname): restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) + except Exception: + print_error('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -135,7 +134,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print("Error setting up CodeFormer:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index db65bcdb..faeaf28b 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c """ import os -import sys -import traceback import json import time import tqdm @@ -14,6 +12,7 @@ from collections import OrderedDict import git from modules import shared, extensions +from modules.errors import print_error from modules.paths_internal import script_path, config_states_dir @@ -53,8 +52,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -134,8 +132,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -143,8 +140,7 @@ def restore_webui_config(config): 
webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index da4694f8..41d8dc93 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -1,7 +1,23 @@ import sys +import textwrap import traceback +def print_error( + message: str, + *, + exc_info: bool = False, +) -> None: + """ + Print an error message to stderr, with optional traceback. + """ + for line in message.splitlines(): + print("***", line, file=sys.stderr) + if exc_info: + print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) + print("---") + + def print_error_explanation(message): lines = message.strip().split("\n") max_len = max([len(x) for x in lines]) diff --git a/modules/extensions.py b/modules/extensions.py index 624832a0..369d2584 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,10 @@ import os -import sys import threading -import traceback import git from modules import shared +from modules.errors import print_error from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -56,8 +55,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception: - print(f"Error reading github repository info from {self.path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -72,8 +70,8 @@ class Extension: self.commit_hash = commit.hexsha self.version = self.commit_hash[:8] - except Exception as ex: - print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) + except Exception: + print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 0131dea4..d2f647fe 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import facexlib import gfpgan import modules.face_restoration from modules import paths, shared, devices, modelloader +from modules.errors import print_error model_dir = "GFPGAN" user_path = None @@ -112,5 +111,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print("Error setting up GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 570b5603..fcc1ef20 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -2,8 +2,6 @@ import datetime import glob import html import os -import sys -import traceback import inspect import modules.textual_inversion.dataset @@ -12,6 +10,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint +from modules.errors import print_error from modules.textual_inversion import textual_inversion, 
logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -325,17 +324,14 @@ def load_hypernetwork(name): if path is None: return None - hypernetwork = Hypernetwork() - try: + hypernetwork = Hypernetwork() hypernetwork.load(path) + return hypernetwork except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading hypernetwork {path}", exc_info=True) return None - return hypernetwork - def load_hypernetworks(names, multipliers=None): already_loaded = {} @@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}

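For reference, the shape of what the new helper in the errors.py hunk above writes to stderr — each message line prefixed with ***, the traceback indented, then a --- separator — assuming a caller like this:

    from modules.errors import print_error

    try:
        1 / 0
    except Exception:
        print_error("Error loading embedding foo.pt", exc_info=True)

    # stderr (shape, not exact text):
    # *** Error loading embedding foo.pt
    #     Traceback (most recent call last):
    #       ...
    #     ZeroDivisionError: division by zero
    # ---
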
""" except Exception: - print(traceback.format_exc(), file=sys.stderr) + print_error("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index e21e554c..69151bec 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,6 +1,4 @@ import datetime -import sys -import traceback import pytz import io @@ -18,6 +16,7 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors +from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -464,8 +463,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print(f"Error adding [{pattern}] to filename", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -697,8 +695,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print("Error parsing NovelAI image generation parameters:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index 111b1322..d36e1a5a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,6 +1,5 @@ import os import sys -import traceback from collections import namedtuple from pathlib import Path import re @@ -12,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,8 +216,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print("Error interrogating", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310..22edc106 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -8,6 +8,7 @@ import json from functools import lru_cache from modules import cmd_args +from modules.errors import print_error from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -188,7 +189,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print(e, file=sys.stderr) + print_error(str(e)) def list_extensions(settings_file): @@ -198,8 +199,8 @@ def list_extensions(settings_file): if os.path.isfile(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) - except Exception as e: - print(e, file=sys.stderr) + except Exception: + print_error("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index ee9c65e7..9a1df343 100644 
--- a/modules/localization.py +++ b/modules/localization.py @@ -1,8 +1,7 @@ import json import os -import sys -import traceback +from modules.errors import print_error localizations = {} @@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print(f"Error loading localization from {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/processing.py b/modules/processing.py index b75f2515..5c9bcce8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,5 @@ import json +import logging import math import os import sys @@ -23,7 +24,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -import logging from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 99983678..c8d0c64f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import numpy as np from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader @@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) return info - except Exception as e: - print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + except Exception: + print_error("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print("Error making Real-ESRGAN models list:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index e8f50774..b596f565 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -2,8 +2,6 @@ import pickle import collections -import sys -import traceback import torch import numpy @@ -11,6 +9,8 @@ import _codecs import zipfile import re +from modules.errors import print_error + # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage @@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("-----> !!!! The file is most likely corrupted !!!! 
<-----", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + "-----> !!!! The file is most likely corrupted !!!! <-----\n" + "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + exc_info=True, + ) return None - except Exception: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + f"The file may be malicious, so the program is not going to read it.\n" + f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", + exc_info=True, + ) return None return unsafe_torch_load(filename, *args, **kwargs) @@ -190,4 +193,3 @@ with safe.Extra(handler): unsafe_torch_load = torch.load torch.load = load global_extra_handler = None - diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d2728e12..6aa9c3b6 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,16 +1,15 @@ -import sys -import traceback -from collections import namedtuple import inspect +from collections import namedtuple from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules.errors import print_error + def report_exception(c, job): - print(f"Error executing callback {job} for {c.script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 57b15862..26efffcb 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,8 +1,8 @@ import os -import sys -import traceback import importlib.util +from modules.errors import print_error + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) @@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print(f"Error running preload() for {preload_script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..a7168fd1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,12 +1,12 @@ import os import re import sys -import traceback from collections import namedtuple import gradio as gr from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules.errors import print_error AlwaysVisible = object() @@ -264,8 +264,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -280,11 +279,9 @@ def load_scripts(): def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: - res = func(*args, **kwargs) - return res + return 
func(*args, **kwargs) except Exception: - print(f"Error calling: {filename}/{funcname}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -450,8 +447,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running process: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -459,8 +455,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -468,8 +463,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -477,8 +471,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print(f"Error running postprocess: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -486,8 +479,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -495,24 +487,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print(f"Error running before_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print(f"Error running after_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + 
print_error(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 2ec0b049..fd186fa2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,5 @@ from __future__ import annotations import math -import sys -import traceback import psutil import torch @@ -11,6 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention +from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print("Cannot import xformers", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e..a040a988 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,4 @@ import os -import sys -import traceback from collections import namedtuple import torch @@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset +from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -207,8 +206,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print(traceback.format_exc(), file=sys.stderr) - pass + print_error("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index 001b9792..1ad94f02 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,7 +2,6 @@ import json import mimetypes import os import sys -import traceback from functools import reduce import warnings @@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules.errors import print_error from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: res = all_seeds[index if 0 <= index < len(all_seeds) else 0] except json.decoder.JSONDecodeError: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) + if gen_info_string: + print_error(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1753,8 +1752,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..cadf56be 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,10 +1,8 @@ import json import os.path -import sys import threading import time from datetime import datetime -import traceback import git @@ -14,6 +12,7 @@ import shutil import errno from modules import extensions, shared, paths, config_states +from modules.errors import print_error from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print(f"Error getting updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -113,8 +111,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print(f"Error checking updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b918a764..4dc24615 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,13 
+1,12 @@ import copy import random -import sys -import traceback import shlex import modules.scripts as scripts import gradio as gr from modules import sd_samplers +from modules.errors import print_error from modules.processing import Processed, process_images from modules.shared import state @@ -136,8 +135,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line {line} as commandline:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 05933840f0676dd1a90a7e2ad3f2a0672624b2cd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 May 2023 19:56:37 +0300 Subject: rename print_error to report, use it with together with package name --- extensions-builtin/LDSR/scripts/ldsr_model.py | 5 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 5 ++--- modules/api/api.py | 5 ++--- modules/call_queue.py | 5 ++--- modules/codeformer_model.py | 7 +++---- modules/config_states.py | 9 ++++----- modules/errors.py | 8 ++------ modules/extensions.py | 7 +++---- modules/gfpgan_model.py | 5 ++--- modules/hypernetworks/hypernetwork.py | 7 +++---- modules/images.py | 5 ++--- modules/interrogate.py | 3 +-- modules/launch_utils.py | 7 +++---- modules/localization.py | 4 ++-- modules/realesrgan_model.py | 10 +++++----- modules/safe.py | 7 ++++--- modules/script_callbacks.py | 4 ++-- modules/script_loading.py | 4 ++-- modules/scripts.py | 23 +++++++++++------------ modules/sd_hijack_optimizations.py | 3 +-- modules/textual_inversion/textual_inversion.py | 7 +++---- modules/ui.py | 7 +++---- modules/ui_extensions.py | 7 +++---- scripts/prompts_from_file.py | 5 ++--- 24 files changed, 69 insertions(+), 90 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 95f1669d..dbd6d331 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -2,10 +2,9 @@ import os from basicsr.utils.download_util import load_file_from_url -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR -from modules import shared, script_callbacks +from modules import shared, script_callbacks, errors import sd_hijack_autoencoder # noqa: F401 import sd_hijack_ddpm_v1 # noqa: F401 @@ -51,7 +50,7 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) except Exception: - print_error("Error importing LDSR", exc_info=True) + errors.report("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index dd1b822e..85b4505f 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -9,10 +9,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader, script_callbacks +from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net -from modules.errors import print_error from modules.shared import opts @@ -39,7 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) 
scalers.append(scaler_data) except Exception: - print_error(f"Error loading ScuNET model: {file}", exc_info=True) + errors.report(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index fbd616a3..d34ab422 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,9 +14,8 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules.api import models -from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -145,7 +144,7 @@ def api_middleware(app: FastAPI): print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - print_error(message, exc_info=True) + errors.report(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index dba2a9b4..53af6d70 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -2,8 +2,7 @@ import html import threading import time -from modules import shared, progress -from modules.errors import print_error +from modules import shared, progress, errors queue_lock = threading.Lock() @@ -62,7 +61,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] if len(arg_str) > max_debug_str_len: arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" - print_error(f"{message}\n{arg_str}", exc_info=True) + errors.report(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 76143e9f..4260b016 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,8 +5,7 @@ import torch import modules.face_restoration import modules.shared -from modules import shared, devices, modelloader -from modules.errors import print_error +from modules import shared, devices, modelloader, errors from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,7 +104,7 @@ def setup_model(dirname): del output torch.cuda.empty_cache() except Exception: - print_error('Failed inference for CodeFormer', exc_info=True) + errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -134,6 +133,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print_error("Error setting up CodeFormer", exc_info=True) + errors.report("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index faeaf28b..6f1ab53f 100644 --- 
a/modules/config_states.py +++ b/modules/config_states.py @@ -11,8 +11,7 @@ from datetime import datetime from collections import OrderedDict import git -from modules import shared, extensions -from modules.errors import print_error +from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir @@ -52,7 +51,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -132,7 +131,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -140,7 +139,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print_error(f"Error restoring webui to commit{webui_commit_hash}") + errors.report(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index 41d8dc93..e408f500 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -3,11 +3,7 @@ import textwrap import traceback -def print_error( - message: str, - *, - exc_info: bool = False, -) -> None: +def report(message: str, *, exc_info: bool = False) -> None: """ Print an error message to stderr, with optional traceback. 
""" @@ -15,7 +11,7 @@ def print_error( print("***", line, file=sys.stderr) if exc_info: print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) - print("---") + print("---", file=sys.stderr) def print_error_explanation(message): diff --git a/modules/extensions.py b/modules/extensions.py index 92f93ad9..8608584b 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,8 +1,7 @@ import os import threading -from modules import shared -from modules.errors import print_error +from modules import shared, errors from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -54,7 +53,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = Repo(self.path) except Exception: - print_error(f"Error reading github repository info from {self.path}", exc_info=True) + errors.report(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -70,7 +69,7 @@ class Extension: self.version = self.commit_hash[:8] except Exception: - print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) + errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index d2f647fe..e239a09d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -4,8 +4,7 @@ import facexlib import gfpgan import modules.face_restoration -from modules import paths, shared, devices, modelloader -from modules.errors import print_error +from modules import paths, shared, devices, modelloader, errors model_dir = "GFPGAN" user_path = None @@ -111,4 +110,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print_error("Error setting up GFPGAN", exc_info=True) + errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fcc1ef20..5d12b449 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -9,8 +9,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint -from modules.errors import print_error +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -329,7 +328,7 @@ def load_hypernetwork(name): hypernetwork.load(path) return hypernetwork except Exception: - print_error(f"Error loading hypernetwork {path}", exc_info=True) + errors.report(f"Error loading hypernetwork {path}", exc_info=True) return None @@ -766,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print_error("Exception in training hypernetwork", exc_info=True) + errors.report("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index 09f728df..30e9ffc5 100644 --- a/modules/images.py +++ b/modules/images.py @@ -16,7 +16,6 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors -from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -463,7 +462,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print_error(f"Error adding [{pattern}] to filename", exc_info=True) + errors.report(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -698,7 +697,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print_error("Error parsing NovelAI image generation parameters", exc_info=True) + errors.report("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index d36e1a5a..9b2c5b60 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,7 +11,6 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,7 +215,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print_error("Error interrogating", exc_info=True) + errors.report("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0bf4cb7e..6e9bb770 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -7,8 +7,7 @@ import platform import json from functools import lru_cache -from modules import cmd_args -from modules.errors import print_error +from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -189,7 +188,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print_error(str(e)) + errors.report(str(e)) def list_extensions(settings_file): @@ -200,7 +199,7 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - print_error("Could not load settings", exc_info=True) + errors.report("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index 9a1df343..e8f585da 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules.errors import print_error +from modules import errors localizations = {} @@ -30,6 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as 
file: data = json.load(file) except Exception: - print_error(f"Error loading localization from {fn}", exc_info=True) + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c8d0c64f..2d27b321 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -5,10 +5,10 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts -from modules import modelloader +from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): @@ -35,7 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print_error("Error importing Real-ESRGAN", exc_info=True) + errors.report("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -75,7 +75,7 @@ class UpscalerRealESRGAN(Upscaler): return info except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -132,4 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index b596f565..b1d08a79 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -9,9 +9,10 @@ import _codecs import zipfile import re -from modules.errors import print_error # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +from modules import errors + TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage def encode(*args): @@ -136,7 +137,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" "-----> !!!! The file is most likely corrupted !!!! 
<-----\n" "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", @@ -144,7 +145,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ) return None except Exception: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" f"The file may be malicious, so the program is not going to read it.\n" f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6aa9c3b6..ec1469d0 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -5,11 +5,11 @@ from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks -from modules.errors import print_error +from modules import errors def report_exception(c, job): - print_error(f"Error executing callback {job} for {c.script}", exc_info=True) + errors.report(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 26efffcb..306a1f35 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,7 +1,7 @@ import os import importlib.util -from modules.errors import print_error +from modules import errors def load_module(path): @@ -27,4 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print_error(f"Error running preload() for {preload_script}", exc_info=True) + errors.report(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index a7168fd1..0970f38e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -5,8 +5,7 @@ from collections import namedtuple import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing -from modules.errors import print_error +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors AlwaysVisible = object() @@ -264,7 +263,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) + errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -281,7 +280,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: return func(*args, **kwargs) except Exception: - print_error(f"Error calling: {filename}/{funcname}", exc_info=True) + errors.report(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -447,7 +446,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print_error(f"Error running process: {script.filename}", exc_info=True) + errors.report(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -455,7 +454,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -463,7 +462,7 @@ class ScriptRunner: script_args = 
p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -471,7 +470,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print_error(f"Error running postprocess: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -479,7 +478,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -487,21 +486,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print_error(f"Error running before_component: {script.filename}", exc_info=True) + errors.report(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print_error(f"Error running after_component: {script.filename}", exc_info=True) + errors.report(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index fd186fa2..5f0ff513 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,6 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention -from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -139,7 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print_error("Cannot import xformers", exc_info=True) + errors.report("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b3dcb140..8da050ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,9 +12,8 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint +from modules import shared, devices, sd_hijack, 
processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors import modules.textual_inversion.dataset -from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -219,7 +218,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print_error(f"Error loading embedding {fn}", exc_info=True) + errors.report(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -643,7 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print_error("Error training embedding", exc_info=True) + errors.report("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index fb6b2498..f361264c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,8 +12,7 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave -from modules.errors import print_error +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -232,7 +231,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: except json.decoder.JSONDecodeError: if gen_info_string: - print_error(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1752,7 +1751,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print_error("Error loading/saving model file", exc_info=True) + errors.report("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index e2ee9d72..3140ed64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -11,8 +11,7 @@ import html import shutil import errno -from modules import extensions, shared, paths, config_states -from modules.errors import print_error +from modules import extensions, shared, paths, config_states, errors from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -45,7 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print_error(f"Error getting updates for {ext.name}", exc_info=True) + errors.report(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -111,7 +110,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print_error(f"Error checking updates for {ext.name}", exc_info=True) + errors.report(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 4dc24615..83a2f220 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -5,8 +5,7 @@ import shlex import modules.scripts as scripts import gradio as gr -from modules import sd_samplers -from modules.errors import print_error +from modules import sd_samplers, errors from modules.processing import Processed, 
process_images from modules.shared import state @@ -135,7 +134,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print_error(f"Error parsing line {line} as commandline", exc_info=True) + errors.report(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 51864790fd72386fbbbb015d24a43ce501ecaa4b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 2 Jun 2023 14:58:10 +0300 Subject: Simplify a bunch of `len(x) > 0`/`len(x) == 0` style expressions --- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 3 ++- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 4 ++-- extensions-builtin/Lora/extra_networks_lora.py | 4 ++-- extensions-builtin/Lora/lora.py | 4 ++-- .../extra-options-section/scripts/extra_options_section.py | 2 +- modules/api/api.py | 2 +- modules/call_queue.py | 2 +- modules/extra_networks_hypernet.py | 4 ++-- modules/generation_parameters_copypaste.py | 6 ++---- modules/images.py | 6 +++--- modules/img2img.py | 3 +-- modules/models/diffusion/ddpm_edit.py | 4 ++-- modules/processing.py | 3 ++- modules/prompt_parser.py | 6 +++--- modules/script_callbacks.py | 4 ++-- modules/sd_hijack_clip.py | 2 +- modules/sd_hijack_clip_old.py | 2 +- modules/textual_inversion/autocrop.py | 14 +++++++------- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- modules/ui.py | 2 +- modules/ui_extensions.py | 5 +++-- modules/ui_settings.py | 2 +- scripts/prompts_from_file.py | 3 +-- 25 files changed, 47 insertions(+), 48 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 27a86e13..c29d274d 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 631a08ef..04adc5eb 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index b5fea4d2..66ee9c85 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == 
additional]) == 0: + if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index eec14712..af93991c 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): else: raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") - if len(keys_failed_to_match) > 0: + if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): lora.multiplier = multipliers[i] if multipliers else 1.0 loaded_loras.append(lora) - if len(failed_to_load_loras) > 0: + if failed_to_load_loras: sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 17f84184..a05e10d8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -21,7 +21,7 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): for setting_name in shared.opts.extra_options: with FormColumn(): comp = ui_settings.create_setting_component(setting_name) diff --git a/modules/api/api.py b/modules/api/api.py index d34ab422..555eefdb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -280,7 +280,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: diff --git a/modules/call_queue.py b/modules/call_queue.py index 53af6d70..1b5e5273 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -21,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id - if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): id_task = args[0] progress.add_task_to_queue(id_task) else: diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index aa2a14ef..b6a6dc0e 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class 
ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_hypernetwork - if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): hypernet_prompt_text = f"" p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 071bd9ea..237401a1 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -55,7 +55,7 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] if type(filedata) == dict and filedata.get("is_file", False): @@ -437,7 +437,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, vals_pairs = [f"{k}: {v}" for k, v in vals.items()] - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) paste_fields = paste_fields + [(override_settings_component, paste_settings)] @@ -454,5 +454,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) - - diff --git a/modules/images.py b/modules/images.py index a12d252b..7bbfc3e0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -406,7 +406,7 @@ class FilenameGenerator: prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): - if len(style) > 0: + if style: for part in style.split("{prompt}"): prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') @@ -415,7 +415,7 @@ class FilenameGenerator: return sanitize_filename_part(prompt_no_style, replace_spaces=False) def prompt_words(self): - words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0] + words = [x for x in re_nonletters.split(self.prompt or "") if x] if len(words) == 0: words = ["empty"] return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False) @@ -423,7 +423,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format + time_format = args[0] if (args and args[0] != "") else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError: diff --git a/modules/img2img.py b/modules/img2img.py index 4c12c2c5..35c4facc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,7 @@ def process_batch(p, input_dir, output_dir, 
inpaint_mask_dir, args): is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: + is_inpaint_batch = bool(inpaint_masks) print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 3fb76b65..b892d5fc 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..9ebdb549 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -975,7 +975,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: - assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b4aff704..0069d8b0 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -336,11 +336,11 @@ def parse_prompt_attention(text): round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: + elif weight is not None and round_brackets: multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and len(round_brackets) > 0: + elif text == ')' and round_brackets: multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and len(square_brackets) > 0: + elif text == ']' and square_brackets: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: parts = re.split(re_break, text) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index f755283c..77ee55ee 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -287,14 +287,14 @@ def list_unets_callback(): def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' callbacks.append(ScriptCallback(filename, fun)) def remove_current_script_callbacks(): stack = 
[x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' if filename == 'unknown file': return for callback_list in callback_map.values(): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index cc6e8c21..3b5a7666 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0 or len(chunks) == 0: + if chunk.tokens or not chunks: next_chunk(is_last=True) return chunks, token_count diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index a3476e95..c5c6270b 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments - if len(used_custom_terms) > 0: + if used_custom_terms: embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) self.hijack.comments.append(f"Used embeddings: {embedding_names}") diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 8e667a4d..75705459 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -77,27 +77,27 @@ def focal_point(im, settings): pois = [] weight_pref_total = 0 - if len(corner_points) > 0: + if corner_points: weight_pref_total += settings.corner_points_weight - if len(entropy_points) > 0: + if entropy_points: weight_pref_total += settings.entropy_points_weight - if len(face_points) > 0: + if face_points: weight_pref_total += settings.face_points_weight corner_centroid = None - if len(corner_points) > 0: + if corner_points: corner_centroid = centroid(corner_points) corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None - if len(entropy_points) > 0: + if entropy_points: entropy_centroid = centroid(entropy_points) entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total pois.append(entropy_centroid) face_centroid = None - if len(face_points) > 0: + if face_points: face_centroid = centroid(face_points) face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) @@ -187,7 +187,7 @@ def image_face_points(im, settings): except Exception: continue - if len(faces) > 0: + if faces: rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b9621fc9..7ee05061 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -32,7 +32,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + re_word = 
re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None self.placeholder_token = placeholder_token diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a009d8e8..0d4c3f84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption += shared.interrogator.generate_caption(image) if params.process_caption_deepbooru: - if len(caption) > 0: + if caption: caption += ", " caption += deepbooru.model.tag_multi(image) @@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption = caption.strip() - if len(caption) > 0: + if caption: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8da050ca..bb6f211c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,7 +251,7 @@ class EmbeddingDatabase: if self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: + if self.skipped_embeddings: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): diff --git a/modules/ui.py b/modules/ui.py index b7459f08..9a025cca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -398,7 +398,7 @@ def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) dropdown.change( - fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + fn=lambda x: gr.Dropdown.update(visible=bool(x)), inputs=[dropdown], outputs=[dropdown], ) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3140ed64..65173e06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -333,7 +333,8 @@ def install_extension_from_url(dirname, url, branch_name=None): assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' normalized_url = normalize_git_url(url) - assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') tmpdir = os.path.join(paths.data_path, "tmp", dirname) @@ -449,7 +450,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = installed_extension_urls.get(normalize_git_url(url), None) extension_tags = extension_tags + ["installed"] if existing else extension_tags - if len([x for x in extension_tags if x in tags_to_hide]) > 0: + if any(x for x in extension_tags if x in tags_to_hide): hidden += 1 continue diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 7874298e..2688d8c2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -81,7 +81,7 @@ class UiSettings: opts.save(shared.config_filename) 
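These hunks all exploit the same rule: empty sequences, mappings and strings are falsy in Python, so a length comparison adds nothing. A few of the commit's equivalences, restated with toy stand-in values:

```python
# Empty containers are falsy, so the length comparisons this commit removes
# were redundant. Toy values stand in for the real objects.

params_items = ["lora_name", "0.8"]
assert params_items                     # was: assert len(params.items) > 0

missing_keys = []
if missing_keys:                        # was: if len(missing_keys) > 0:
    print(f"Missing Keys: {missing_keys}")

# any() with a generator short-circuits on the first hit instead of building
# a throwaway list just to test its length.
names = ["alpha", "beta"]
already_present = any(x == "beta" for x in names)
# was: len([x for x in names if x == "beta"]) > 0

# Where a real boolean is stored or returned, make the conversion explicit.
inpaint_masks = []
is_inpaint_batch = bool(inpaint_masks)  # was: len(inpaint_masks) > 0

print(already_present, is_inpaint_batch)
```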
except RuntimeError: return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' def run_settings_single(self, value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 83a2f220..50320d55 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -121,8 +121,7 @@ class Script(scripts.Script): return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] + lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True -- cgit v1.2.3 From 333e63c0911c148ea306d7b72580d5c6d2f2c41a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Jun 2023 09:59:56 +0300 Subject: a yet another method to restart webui --- modules/launch_utils.py | 6 ++++++ modules/shared.py | 9 +++++++++ modules/ui_extensions.py | 2 +- webui.bat | 2 ++ webui.sh | 34 ++++++++++++++++++++-------------- 5 files changed, 38 insertions(+), 15 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0c8c4db0..af8d8b37 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -244,6 +244,12 @@ def prepare_environment(): codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") + try: + # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution + os.remove(os.path.join(script_path, "tmp", "restart")) + except OSError: + pass + if not args.skip_python_version_check: check_python_version() diff --git a/modules/shared.py b/modules/shared.py index 7025a754..c4c719ad 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -841,3 +841,12 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) + + +def restart_program(): + """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" + + with open(os.path.join(script_path, "tmp", "restart"), "w"): + pass + + os._exit(0) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3140ed64..5580dfaf 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -49,7 +49,7 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all shared.opts.save(shared.config_filename) - shared.state.request_restart() + shared.restart_program() def save_config_state(name): diff --git a/webui.bat b/webui.bat index 209d972b..961fc7d4 100644 --- a/webui.bat +++ b/webui.bat @@ -51,12 +51,14 @@ if EXIST %ACCELERATE% goto :accelerate_launch :launch %PYTHON% launch.py %* +if EXIST tmp/restart goto :skip_venv pause exit /b :accelerate_launch echo Accelerating %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py +if EXIST tmp/restart goto :skip_venv pause exit /b diff --git a/webui.sh b/webui.sh index 
1e728813..c407b3ef 100755 --- a/webui.sh +++ b/webui.sh @@ -203,17 +203,23 @@ prepare_tcmalloc() { fi } -if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] -then - printf "\n%s\n" "${delimiter}" - printf "Accelerating launch.py..." - printf "\n%s\n" "${delimiter}" - prepare_tcmalloc - exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" -else - printf "\n%s\n" "${delimiter}" - printf "Launching launch.py..." - printf "\n%s\n" "${delimiter}" - prepare_tcmalloc - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" -fi +KEEP_GOING=1 +while [[ "$KEEP_GOING" -eq "1" ]]; do + if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]; then + printf "\n%s\n" "${delimiter}" + printf "Accelerating launch.py..." + printf "\n%s\n" "${delimiter}" + prepare_tcmalloc + accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" + else + printf "\n%s\n" "${delimiter}" + printf "Launching launch.py..." + printf "\n%s\n" "${delimiter}" + prepare_tcmalloc + "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + fi + + if [[ ! -f tmp/restart ]]; then + KEEP_GOING=0 + fi +done -- cgit v1.2.3 From 46a5bd64edece07f521409f0adace6cdd6f30a40 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 5 Jun 2023 20:04:28 +0300 Subject: Restart: only do restart if running via the wrapper script --- modules/restart.py | 23 +++++++++++++++++++++++ modules/shared.py | 9 --------- modules/ui_extensions.py | 11 ++++++++--- webui.bat | 2 +- webui.sh | 1 + 5 files changed, 33 insertions(+), 13 deletions(-) create mode 100644 modules/restart.py (limited to 'modules/ui_extensions.py') diff --git a/modules/restart.py b/modules/restart.py new file mode 100644 index 00000000..18eacaf3 --- /dev/null +++ b/modules/restart.py @@ -0,0 +1,23 @@ +import os +from pathlib import Path + +from modules.paths_internal import script_path + + +def is_restartable() -> bool: + """ + Return True if the webui is restartable (i.e. 
there is something watching to restart it with) + """ + return bool(os.environ.get('SD_WEBUI_RESTART')) + + +def restart_program() -> None: + """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" + + (Path(script_path) / "tmp" / "restart").touch() + + stop_program() + + +def stop_program() -> None: + os._exit(0) diff --git a/modules/shared.py b/modules/shared.py index 2bd7c6ec..c9ee2dd1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -853,12 +853,3 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) - - -def restart_program(): - """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" - - with open(os.path.join(script_path, "tmp", "restart"), "w"): - pass - - os._exit(0) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 5580dfaf..3d216912 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -11,7 +11,7 @@ import html import shutil import errno -from modules import extensions, shared, paths, config_states, errors +from modules import extensions, shared, paths, config_states, errors, restart from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -49,7 +49,11 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all shared.opts.save(shared.config_filename) - shared.restart_program() + + if restart.is_restartable(): + restart.restart_program() + else: + restart.stop_program() def save_config_state(name): @@ -508,7 +512,8 @@ def create_ui(): with gr.TabItem("Installed", id="installed"): with gr.Row(elem_id="extensions_installed_top"): - apply = gr.Button(value="Apply and restart UI", variant="primary") + apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit") + apply = gr.Button(value=apply_label, variant="primary") check = gr.Button(value="Check for updates") extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) diff --git a/webui.bat b/webui.bat index 961fc7d4..42e7d517 100644 --- a/webui.bat +++ b/webui.bat @@ -3,7 +3,7 @@ if not defined PYTHON (set PYTHON=python) if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") - +set SD_WEBUI_RESTART=tmp/restart set ERROR_REPORTING=FALSE mkdir tmp 2>NUL diff --git a/webui.sh b/webui.sh index c407b3ef..6c48e969 100755 --- a/webui.sh +++ b/webui.sh @@ -204,6 +204,7 @@ prepare_tcmalloc() { } KEEP_GOING=1 +export SD_WEBUI_RESTART=tmp/restart while [[ "$KEEP_GOING" -eq "1" ]]; do if [[ ! 
-z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]; then printf "\n%s\n" "${delimiter}" -- cgit v1.2.3 From d5a5f2f29fc1c2b5f7dda0af5b984264744478b5 Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Sun, 25 Jun 2023 01:31:02 +0800 Subject: Strip whitespaces from URL and dirname prior to extension installation This avoid some cryptic errors brought by accidental spaces around urls --- modules/ui_extensions.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..6e5222ba 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -325,6 +325,11 @@ def normalize_git_url(url): def install_extension_from_url(dirname, url, branch_name=None): check_access() + if isinstance(dirname, str): + dirname = dirname.strip() + if isinstance(url, str): + url = url.strip() + assert url, 'No URL specified' if dirname is None or dirname == "": -- cgit v1.2.3 From dd268c48c9099c4cf308eb04590bd201c9b64253 Mon Sep 17 00:00:00 2001 From: "Martín (Netux) Rodríguez" Date: Sun, 25 Jun 2023 00:30:08 -0300 Subject: feat(extensions): add toggle all checkbox to Installed tab Small QoL addition. While there is the option to disable all extensions with the radio buttons at the top, that only acts as an added flag and doesn't really change the state of the extensions in the UI. An use case for this checkbox is to disable all extensions except for a few, which is important for debugging extensions. You could do that before, but you'd have to uncheck and recheck every extension one by one. --- javascript/extensions.js | 18 ++++++++++++++++++ modules/ui_extensions.py | 7 +++++-- 2 files changed, 23 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/javascript/extensions.js b/javascript/extensions.js index efeaf3a5..1f7254c5 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type) } return [confirmed, config_state_name, config_restore_type]; } + +function toggle_all_extensions(event) { + gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) { + checkbox_el.checked = event.target.checked; + }); +} + +function toggle_extension() { + let all_extensions_toggled = true; + for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) { + if (!checkbox_el.checked) { + all_extensions_toggled = false; + break; + } + } + + gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled; +} diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4379a641..50955fab 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -138,7 +138,10 @@ def extension_table(): - + @@ -170,7 +173,7 @@ def extension_table(): code += f""" - + -- cgit v1.2.3 From 9bb1fcfad43103778406ace17e6804c67fad9c17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 08:59:35 +0300 Subject: alternate fix for catch errors when retrieving extension index #11290 --- modules/ui_extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3db76f2..278bf5e4 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -571,9 +571,9 @@ def create_ui(): available_extensions_table = 
gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]), inputs=[available_extensions_index, hide_tags, sort_column], - outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], + outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result], ) install_extension_button.click( -- cgit v1.2.3 From d47324b898d057c0f854b9be891f2483a2b7001f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:25:18 +0900 Subject: add stars --- modules/ui_extensions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 278bf5e4..ac239d64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,6 +424,7 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('stars', 0)), ] @@ -451,6 +452,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") @@ -478,7 +480,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -562,7 +564,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) -- cgit v1.2.3 From 2ccc832b3333fe520961466aa1f05b24aafdd792 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 22:46:59 +0900 Subject: add extensions Update Created dates with sorting --- modules/ui_extensions.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'modules/ui_extensions.py') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ac239d64..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,10 +424,19 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('commit_time', '')), + (True, lambda x: x.get('created_at', '')), (True, lambda x: x.get('stars', 0)), ] +def get_date(info: dict, key): + try: + return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d") + except (ValueError, TypeError): + return '' + + def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): 
extension.name for extension in extensions.extensions} @@ -454,6 +463,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" name = ext.get("name", "noname") stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') + update_time = get_date(ext, 'commit_time') + create_time = get_date(ext, 'created_at') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -480,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -564,7 +576,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) -- cgit v1.2.3
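The restart machinery from commits 333e63c0 and 46a5bd64 above is a small two-party protocol: the launcher exports `SD_WEBUI_RESTART` and loops; the app, when asked to restart, touches `tmp/restart` and exits via `os._exit(0)`; the launcher relaunches for as long as the sentinel file exists, and `is_restartable()` lets the UI fall back to a plain quit when no watcher is present. A minimal sketch with the `webui.sh` loop rewritten in Python; `run_forever` and its `launch.py` target are illustrative stand-ins, not the real scripts:

```python
import os
import subprocess
import sys
from pathlib import Path

SENTINEL = Path("tmp") / "restart"


def run_forever(cmd: list) -> None:
    # The child sees SD_WEBUI_RESTART and knows a watcher will relaunch it.
    env = dict(os.environ, SD_WEBUI_RESTART=str(SENTINEL))
    while True:
        # In the real flow launch_utils removes a stale sentinel at startup;
        # doing it in the launcher keeps this sketch self-contained.
        SENTINEL.unlink(missing_ok=True)
        subprocess.run(cmd, env=env, check=False)
        if not SENTINEL.exists():  # normal exit: the app did not ask to be restarted
            break


if __name__ == "__main__":
    SENTINEL.parent.mkdir(exist_ok=True)
    run_forever([sys.executable, "launch.py"])
```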
[table markup from this and the following two ui_extensions.py hunks was stripped during extraction; only the cell text survived. Recoverable content of the dd268c48 extension_table() hunks: the header row reads Extension | URL | Branch | Version, and the patch replaces the plain Extension header cell with one that also carries an all_extensions_toggle master checkbox wired to toggle_all_extensions(event); each row's enable checkbox, in the cell holding {html.escape(ext.name)} next to {remote}, {ext.branch} and {version_link}, gains the extension_toggle class that the new JavaScript selects on.]
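A hedged reconstruction of the Python side of that hunk, since the exact markup was lost: the class and handler names come from the commit's JavaScript half, while the helper names, cell layout and remaining attributes here are illustrative only:

```python
import html


def render_header_cell(all_enabled: bool) -> str:
    # Master checkbox in the "Extension" header cell; checked only when every
    # extension row is enabled, mirroring the state toggle_extension() recomputes.
    master = 'checked="checked"' if all_enabled else ''
    return (f'<th><input class="all_extensions_toggle" type="checkbox" {master} '
            f'onchange="toggle_all_extensions(event)" />Extension</th>')


def render_name_cell(name: str, enabled: bool) -> str:
    # Per-row checkbox; the extension_toggle class is what the JavaScript
    # selects with "#extensions .extension_toggle". The onchange handler is an
    # inference from toggle_extension() existing in the commit's JS.
    checked = 'checked="checked"' if enabled else ''
    return (f'<td><input class="extension_toggle" type="checkbox" {checked} '
            f'onchange="toggle_extension()" />{html.escape(name)}</td>')


print(render_header_cell(all_enabled=False))
print(render_name_cell("stable-diffusion-webui-images-browser", enabled=True))
```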
[recoverable content of the d47324b8 refresh_available_extensions_from_data() hunk: the {html.escape(name)}, {tags_text} and {install_code} cells are unchanged, while the description cell's info line grows a star count, so "Added: {html.escape(added)}" becomes "Added: {html.escape(added)} stars: {stars:,}".]
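The Order radio is declared with `type="index"`, so the selected position indexes directly into `sort_ordering`'s `(reverse, key)` pairs. A sketch with made-up extension records; the last four entries appear verbatim in the hunks above, while the first two are assumed from the radio labels:

```python
sort_ordering = [
    (True,  lambda x: x.get('added', 'z')),  # "newest first" (assumed; not shown in the hunk)
    (False, lambda x: x.get('added', 'z')),  # "oldest first" (assumed; not shown in the hunk)
    (False, lambda x: x.get('name', 'z')),   # "a-z"
    (True,  lambda x: x.get('name', 'z')),   # "z-a"
    (False, lambda x: 'z'),                  # "internal order": constant key + stable sort = no-op
    (True,  lambda x: x.get('stars', 0)),    # "stars": most-starred first
]

extlist = [
    {"name": "aesthetic-gradients", "stars": 350, "added": "2022-10-23"},
    {"name": "images-browser", "stars": 900, "added": "2022-11-01"},
]

sort_reverse, sort_function = sort_ordering[5]  # index arriving from the "stars" radio choice
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
    print(f'{ext["name"]}: {ext["stars"]:,} stars')
```

Because Python's sort is stable, the constant `'z'` key of "internal order" keeps the list exactly as the index file delivered it.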
[recoverable content of the 2ccc832b hunk: the same info line is extended again, the added line reading "Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)} stars: {stars}" (note the comma-grouping format spec on stars is dropped here), with the {html.escape(name)}, {tags_text}, {html.escape(description)} and {install_code} cells otherwise unchanged.]
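`get_date`, added in this last commit, normalises GitHub-style `%Y-%m-%dT%H:%M:%SZ` timestamps down to day precision and deliberately fails soft: `info.get()` yields `None` for a missing key (which makes `strptime` raise `TypeError`), and a malformed string raises `ValueError`; both collapse to an empty string rather than breaking the whole table render. Usage with a made-up extension record:

```python
from datetime import datetime


def get_date(info: dict, key):
    # Verbatim from the hunk above: parse, reformat, and fail soft.
    try:
        return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
    except (ValueError, TypeError):
        return ''


ext = {"name": "example-extension", "commit_time": "2023-06-28T14:03:11Z"}
print(get_date(ext, 'commit_time'))  # -> 2023-06-28
print(get_date(ext, 'created_at'))   # -> '' (key missing: strptime(None, ...) raises TypeError)
```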