From 762265eab58cdb8f2d6398769bab43d8b8db0075 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 07:52:45 +0300 Subject: autofixes from ruff --- modules/api/api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 9bb95dfd..d47c39fc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -60,7 +60,7 @@ def decode_base64_to_image(encoding): try: image = Image.open(BytesIO(base64.b64decode(encoding))) return image - except Exception as err: + except Exception: raise HTTPException(status_code=500, detail="Invalid encoded image") def encode_pil_to_base64(image): @@ -264,11 +264,11 @@ class Api: if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) - if alwayson_script == None: + if alwayson_script is None: raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") # Selectable script in always on script param check - if alwayson_script.alwayson == False: - raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script.alwayson is False: + raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params") # always on script with no arg should always run so you don't really need to add them to the requests if "args" in request.alwayson_scripts[alwayson_script_name]: # min between arg length in scriptrunner and arg length in the request @@ -310,7 +310,7 @@ class Api: p.outpath_samples = opts.outdir_txt2img_samples shared.state.begin() - if selectable_scripts != None: + if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: @@ -367,7 +367,7 @@ class Api: p.outpath_samples = opts.outdir_img2img_samples shared.state.begin() - if selectable_scripts != None: + if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: @@ -642,7 +642,7 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError as msg: + except AssertionError: shared.state.end() return TrainResponse(info=f"train embedding error: {error}") -- cgit v1.2.3 From 96d6ca4199e7c5eee8d451618de5161cea317c40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:25:25 +0300 Subject: manual fixes for ruff --- extensions-builtin/LDSR/ldsr_model_arch.py | 2 +- extensions-builtin/LDSR/scripts/ldsr_model.py | 3 +- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 10 +- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 26 ++--- extensions-builtin/ScuNET/scunet_model_arch.py | 9 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- modules/api/api.py | 129 +++++++++++----------- modules/api/models.py | 5 +- modules/codeformer/codeformer_arch.py | 2 +- modules/esrgan_model_arch.py | 2 + modules/extra_networks_hypernet.py | 2 +- modules/images.py | 4 +- modules/img2img.py | 1 - modules/interrogate.py | 1 - modules/modelloader.py | 6 +- modules/models/diffusion/ddpm_edit.py | 26 ++--- modules/models/diffusion/uni_pc/sampler.py | 3 +- modules/processing.py | 2 +- modules/prompt_parser.py | 11 +- modules/textual_inversion/autocrop.py | 2 +- modules/ui.py | 8 +- modules/upscaler.py | 2 +- 22 files changed, 129 insertions(+), 129 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 2339de7f..a5fb8907 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -243,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: + except Exception: pass log["sample"] = x_sample diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index da19cff1..e8dc083c 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks -import sd_hijack_autoencoder, sd_hijack_ddpm_v1 +import sd_hijack_autoencoder +import sd_hijack_ddpm_v1 class UpscalerLDSR(Upscaler): diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index db2231dd..6303fed5 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -249,7 +254,8 @@ class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 5c0488e5..4d3f6c56 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -450,7 +450,7 @@ class LatentDiffusionV1(DDPMV1): self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 43ca8d36..8028918a 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -61,7 +61,9 @@ class WMSA(nn.Module): Returns: output: tensor shape [b h w c] """ - if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + if self.type != 'W': + x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) @@ -85,8 +87,9 @@ class WMSA(nn.Module): output = self.linear(output) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), - dims=(1, 2)) + if self.type != 'W': + output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) + return output def relative_embedding(self): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index e8783bca..d77c3a92 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -45,7 +45,7 @@ class UpscalerSwinIR(Upscaler): img = upscale(img, model) try: torch.cuda.empty_cache() - except: + except Exception: pass return img diff --git a/modules/api/api.py b/modules/api/api.py index d47c39fc..f52d371b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -15,7 +15,8 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing -from modules.api.models import * +from modules.api import models +from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess @@ -25,20 +26,21 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices -from typing import List +from typing import Dict, List, Any import piexif import piexif.helper + def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") + except Exception: + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") def script_name_to_index(name, scripts): try: return [script.title().lower() for script in scripts].index(name.lower()) - except: + except Exception: raise HTTPException(status_code=422, detail=f"Script '{name}' not found") def validate_sampler_name(name): @@ -99,7 +101,7 @@ def api_middleware(app: FastAPI): import starlette # importing just so it can be placed on silent list from rich.console import Console console = Console() - except: + except Exception: import traceback rich_available = False @@ -166,36 +168,36 @@ class Api: self.app = app self.queue_lock = queue_lock api_middleware(self.app) - self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) - self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse) + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse) + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) - self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) - self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) - self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) - self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) - self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) - self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) - self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) - self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) - self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) - self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) - self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) - self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) - self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) - self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) + self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) + self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) + self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) + self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) + self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) + self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) - self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) + self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -224,7 +226,7 @@ class Api: t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] - return ScriptsList(txt2img = t2ilist, img2img = i2ilist) + return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist) def get_script(self, script_name, script_runner): if script_name is None or script_name == "": @@ -276,7 +278,7 @@ class Api: script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -320,9 +322,9 @@ class Api: b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] - return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) + return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) - def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): + def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -381,9 +383,9 @@ class Api: img2imgreq.init_images = None img2imgreq.mask = None - return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) + return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) - def extras_single_image_api(self, req: ExtrasSingleImageRequest): + def extras_single_image_api(self, req: models.ExtrasSingleImageRequest): reqDict = setUpscalers(req) reqDict['image'] = decode_base64_to_image(reqDict['image']) @@ -391,9 +393,9 @@ class Api: with self.queue_lock: result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) - return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) + return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) - def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): + def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest): reqDict = setUpscalers(req) image_list = reqDict.pop('imageList', []) @@ -402,15 +404,15 @@ class Api: with self.queue_lock: result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) + return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def pnginfoapi(self, req: PNGInfoRequest): + def pnginfoapi(self, req: models.PNGInfoRequest): if(not req.image.strip()): - return PNGInfoResponse(info="") + return models.PNGInfoResponse(info="") image = decode_base64_to_image(req.image.strip()) if image is None: - return PNGInfoResponse(info="") + return models.PNGInfoResponse(info="") geninfo, items = images.read_info_from_image(image) if geninfo is None: @@ -418,13 +420,13 @@ class Api: items = {**{'parameters': geninfo}, **items} - return PNGInfoResponse(info=geninfo, items=items) + return models.PNGInfoResponse(info=geninfo, items=items) - def progressapi(self, req: ProgressRequest = Depends()): + def progressapi(self, req: models.ProgressRequest = Depends()): # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) # avoid dividing zero progress = 0.01 @@ -446,9 +448,9 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) - def interrogateapi(self, interrogatereq: InterrogateRequest): + def interrogateapi(self, interrogatereq: models.InterrogateRequest): image_b64 = interrogatereq.image if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") @@ -465,7 +467,7 @@ class Api: else: raise HTTPException(status_code=404, detail="Model not found") - return InterrogateResponse(caption=processed) + return models.InterrogateResponse(caption=processed) def interruptapi(self): shared.state.interrupt() @@ -570,36 +572,36 @@ class Api: filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used shared.state.end() - return CreateResponse(info=f"create embedding filename: {filename}") + return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: shared.state.end() - return TrainResponse(info=f"create embedding error: {e}") + return models.TrainResponse(info=f"create embedding error: {e}") def create_hypernetwork(self, args: dict): try: shared.state.begin() filename = create_hypernetwork(**args) # create empty embedding shared.state.end() - return CreateResponse(info=f"create hypernetwork filename: {filename}") + return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: shared.state.end() - return TrainResponse(info=f"create hypernetwork error: {e}") + return models.TrainResponse(info=f"create hypernetwork error: {e}") def preprocess(self, args: dict): try: shared.state.begin() preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info = 'preprocess complete') except KeyError as e: shared.state.end() - return PreprocessResponse(info=f"preprocess error: invalid token: {e}") + return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") except AssertionError as e: shared.state.end() - return PreprocessResponse(info=f"preprocess error: {e}") + return models.PreprocessResponse(info=f"preprocess error: {e}") except FileNotFoundError as e: shared.state.end() - return PreprocessResponse(info=f'preprocess error: {e}') + return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: @@ -617,10 +619,10 @@ class Api: if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") + return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") except AssertionError as msg: shared.state.end() - return TrainResponse(info=f"train embedding error: {msg}") + return models.TrainResponse(info=f"train embedding error: {msg}") def train_hypernetwork(self, args: dict): try: @@ -641,14 +643,15 @@ class Api: if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") + return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") except AssertionError: shared.state.end() - return TrainResponse(info=f"train embedding error: {error}") + return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: - import os, psutil + import os + import psutil process = psutil.Process(os.getpid()) res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe @@ -675,10 +678,10 @@ class Api: 'events': warnings, } else: - cuda = { 'error': 'unavailable' } + cuda = {'error': 'unavailable'} except Exception as err: - cuda = { 'error': f'{err}' } - return MemoryResponse(ram = ram, cuda = cuda) + cuda = {'error': f'{err}'} + return models.MemoryResponse(ram=ram, cuda=cuda) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 4a70f440..4d291076 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -223,8 +223,9 @@ for key in _options: if(_options[key].dest != 'help'): flag = _options[key] _type = str - if _options[key].default is not None: _type = type(_options[key].default) - flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) + if _options[key].default is not None: + _type = type(_options[key].default) + flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))}) FlagsModel = create_model("Flags", **flags) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 11dcc3ee..f1a7cf09 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -7,7 +7,7 @@ from torch import nn, Tensor import torch.nn.functional as F from typing import Optional, List -from modules.codeformer.vqgan_arch import * +from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py index 6071fea7..7f8bc7c0 100644 --- a/modules/esrgan_model_arch.py +++ b/modules/esrgan_model_arch.py @@ -438,9 +438,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias= padding = padding if pad_type == 'zero' else 0 if convtype=='PartialConv2D': + from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups) elif convtype=='DeformConv2D': + from torchvision.ops import DeformConv2d # not tested c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups) elif convtype=='Conv3D': diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index 04f27c9f..aa2a14ef 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -1,4 +1,4 @@ -from modules import extra_networks, shared, extra_networks +from modules import extra_networks, shared from modules.hypernetworks import hypernetwork diff --git a/modules/images.py b/modules/images.py index 3d5d76cc..5eb6d855 100644 --- a/modules/images.py +++ b/modules/images.py @@ -472,9 +472,9 @@ def get_next_sequence_number(path, basename): prefix_length = len(basename) for p in os.listdir(path): if p.startswith(basename): - l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) + parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) try: - result = max(int(l[0]), result) + result = max(int(parts[0]), result) except ValueError: pass diff --git a/modules/img2img.py b/modules/img2img.py index cdae301a..32b1ecd6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -13,7 +13,6 @@ from modules.shared import opts, state import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html -import modules.images as images import modules.scripts diff --git a/modules/interrogate.py b/modules/interrogate.py index 9f7d657f..22df9216 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,7 +11,6 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -import modules.shared as shared from modules import devices, paths, shared, lowvram, modelloader, errors blip_image_eval_size = 384 diff --git a/modules/modelloader.py b/modules/modelloader.py index cb85ac4f..cf685000 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -108,12 +108,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): print(f"Moving {file} from {src_path} to {dest_path}.") try: shutil.move(fullpath, dest_path) - except: + except Exception: pass if len(os.listdir(src_path)) == 0: print(f"Removing empty folder: {src_path}") shutil.rmtree(src_path, True) - except: + except Exception: pass @@ -141,7 +141,7 @@ def load_upscalers(): full_model = f"modules.{model_name}_model" try: importlib.import_module(full_model) - except: + except Exception: pass datas = [] diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index f880bc3c..611c2b69 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -479,7 +479,7 @@ class LatentDiffusion(DDPM): self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -891,16 +891,6 @@ class LatentDiffusion(DDPM): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -1171,8 +1161,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1219,8 +1211,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1337,7 +1331,7 @@ class LatentDiffusion(DDPM): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index a241c8a7..0a9defa1 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -54,7 +54,8 @@ class UniPCSampler(object): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") diff --git a/modules/processing.py b/modules/processing.py index 1a76e552..6f5233c1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -664,7 +664,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if not shared.opts.dont_fix_second_order_samplers_schedule: try: step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1 - except: + except Exception: pass uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc) c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index e084e948..3a720721 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): """ def collect_steps(steps, tree): - l = [steps] + res = [steps] + class CollectSteps(lark.Visitor): def scheduled(self, tree): tree.children[-1] = float(tree.children[-1]) if tree.children[-1] < 1: tree.children[-1] *= steps tree.children[-1] = min(steps, int(tree.children[-1])) - l.append(tree.children[-1]) + res.append(tree.children[-1]) + def alternate(self, tree): - l.extend(range(1, steps+1)) + res.extend(range(1, steps+1)) + CollectSteps().visit(tree) - return sorted(set(l)) + return sorted(set(res)) def at_step(step, tree): class AtStep(lark.Transformer): diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index ba1bdcd4..d7d8d2e3 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -185,7 +185,7 @@ def image_face_points(im, settings): try: faces = classifier.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: + except Exception: continue if len(faces) > 0: diff --git a/modules/ui.py b/modules/ui.py index 2171f3aa..6beda76f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,15 +1,9 @@ -import html import json -import math import mimetypes import os -import platform -import random import sys -import tempfile -import time import traceback -from functools import partial, reduce +from functools import reduce import warnings import gradio as gr diff --git a/modules/upscaler.py b/modules/upscaler.py index e2eaa730..0ad4fe99 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -45,7 +45,7 @@ class Upscaler: try: import cv2 self.can_tile = True - except: + except Exception: pass @abstractmethod -- cgit v1.2.3 From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 10 +++++----- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 14 +++++++------- extensions-builtin/SwinIR/swinir_model_arch.py | 6 +++++- extensions-builtin/SwinIR/swinir_model_arch_v2.py | 11 +++++++++-- modules/api/api.py | 18 ++++++++++++------ modules/codeformer/codeformer_arch.py | 7 +++++-- modules/codeformer/vqgan_arch.py | 4 ++-- modules/generation_parameters_copypaste.py | 4 ++-- modules/models/diffusion/ddpm_edit.py | 14 ++++++++------ modules/models/diffusion/uni_pc/uni_pc.py | 7 +++++-- modules/safe.py | 2 +- modules/sd_samplers_compvis.py | 2 +- modules/textual_inversion/image_embedding.py | 2 +- modules/textual_inversion/learn_schedule.py | 4 ++-- pyproject.toml | 5 ++++- 15 files changed, 69 insertions(+), 41 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index f457ca93..8cc82d54 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -24,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -62,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -81,11 +81,11 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -270,7 +270,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index d8fc30e3..f16d6504 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) @@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -444,7 +444,7 @@ class LatentDiffusionV1(DDPMV1): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key @@ -1418,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..75f7bedc 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,13 +644,17 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=None, num_heads=None, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): super(SwinIR, self).__init__() + + depths = depths or [6, 6, 6, 6] + num_heads = num_heads or [6, 6, 6, 6] + num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..d4c0b0da 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,9 +74,12 @@ class WindowAttention(nn.Module): """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): + pretrained_window_size=None): super().__init__() + + pretrained_window_size = pretrained_window_size or [0, 0] + self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size @@ -698,13 +701,17 @@ class Swin2SR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=None, num_heads=None, window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): super(Swin2SR, self).__init__() + + depths = depths or [6, 6, 6, 6] + num_heads = num_heads or [6, 6, 6, 6] + num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/modules/api/api.py b/modules/api/api.py index f52d371b..9efb558e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -34,14 +34,16 @@ import piexif.helper def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except Exception: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e + def script_name_to_index(name, scripts): try: return [script.title().lower() for script in scripts].index(name.lower()) - except Exception: - raise HTTPException(status_code=422, detail=f"Script '{name}' not found") + except Exception as e: + raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e + def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) @@ -50,20 +52,23 @@ def validate_sampler_name(name): return name + def setUpscalers(req: dict): reqDict = vars(req) reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict + def decode_base64_to_image(encoding): if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] try: image = Image.open(BytesIO(base64.b64decode(encoding))) return image - except Exception: - raise HTTPException(status_code=500, detail="Invalid encoded image") + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid encoded image") from e + def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: @@ -94,6 +99,7 @@ def encode_pil_to_base64(image): return base64.b64encode(bytes_data) + def api_middleware(app: FastAPI): rich_available = True try: diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 00c407de..ff1c0b4b 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module): class CodeFormer(VQAutoEncoder): def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256, - connect_list=['32', '64', '128', '256'], - fix_modules=['quantize','generator']): + connect_list=None, + fix_modules=None): super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + connect_list = connect_list or ['32', '64', '128', '256'] + fix_modules = fix_modules or ['quantize', 'generator'] + if fix_modules is not None: for module in fix_modules: for param in getattr(self, module).parameters(): diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index 820e6b12..b24a0394 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -326,7 +326,7 @@ class Generator(nn.Module): @ARCH_REGISTRY.register() class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() @@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module): self.embed_dim = emb_dim self.ch_mult = ch_mult self.resolution = img_size - self.attn_resolutions = attn_resolutions + self.attn_resolutions = attn_resolutions or [16] self.quantizer_type = quantizer self.encoder = Encoder( self.in_channels, diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f1c59c46..7fbbe707 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -19,14 +19,14 @@ registered_param_bindings = [] class ParamBinding: - def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]): + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): self.paste_button = paste_button self.tabname = tabname self.source_text_component = source_text_component self.source_image_component = source_image_component self.source_tabname = source_tabname self.override_settings_component = override_settings_component - self.paste_field_names = paste_field_names + self.paste_field_names = paste_field_names or [] def reset(): diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 09432117..af4dea15 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -52,7 +52,7 @@ class DDPM(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -107,7 +107,7 @@ class DDPM(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) # If initialing from EMA-only checkpoint, create EMA model after loading. if self.use_ema and not load_ema: @@ -194,7 +194,9 @@ class DDPM(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): + ignore_keys = ignore_keys or [] + sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] @@ -473,7 +475,7 @@ class LatentDiffusion(DDPM): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key @@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index a4c4ef4e..6f8ad631 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -178,13 +178,13 @@ def model_wrapper( model, noise_schedule, model_type="noise", - model_kwargs={}, + model_kwargs=None, guidance_type="uncond", #condition=None, #unconditional_condition=None, guidance_scale=1., classifier_fn=None, - classifier_kwargs={}, + classifier_kwargs=None, ): """Create a wrapper function for the noise prediction model. @@ -275,6 +275,9 @@ def model_wrapper( A noise prediction model that accepts the noised data and the continuous time as the inputs. """ + model_kwargs = model_kwargs or [] + classifier_kwargs = classifier_kwargs or [] + def get_model_input_time(t_continuous): """ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. diff --git a/modules/safe.py b/modules/safe.py index e6c2f2c0..2d5b972f 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -104,7 +104,7 @@ def check_pt(filename, extra_handler): def load(filename, *args, **kwargs): - return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) + return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) def load_with_extra(filename, extra_handler=None, *args, **kwargs): diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 7427648f..b1ee3be7 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler: def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ee0e850a..d85a4888 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder): class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs) def object_hook(self, d): if 'TORCHTENSOR' in d: diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index f63fc72f..fda58898 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -32,8 +32,8 @@ class LearnScheduleIterator: self.maxit += 1 return assert self.rates - except (ValueError, AssertionError): - raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') + except (ValueError, AssertionError) as e: + raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e def __iter__(self): diff --git a/pyproject.toml b/pyproject.toml index 2f65fd6c..346a0cde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ ignore = [ ] - [tool.ruff.per-file-ignores] "webui.py" = ["E402"] # Module level import not at top of file + +[tool.ruff.flake8-bugbear] +# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. +extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] \ No newline at end of file -- cgit v1.2.3 From d25219b7e889cf34bccae9cb88497708796efda2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:55:09 +0300 Subject: manual fixes for some C408 --- extensions-builtin/LDSR/ldsr_model_arch.py | 4 ++-- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 2 +- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 8 ++++---- modules/api/api.py | 2 +- modules/models/diffusion/ddpm_edit.py | 8 ++++---- modules/models/diffusion/uni_pc/uni_pc.py | 4 ++-- modules/sd_hijack_inpainting.py | 2 +- 7 files changed, 15 insertions(+), 15 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 27e38549..2173de79 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -157,7 +157,7 @@ class LDSR: def get_cond(selected_path): - example = dict() + example = {} up_f = 4 c = selected_path.convert('RGB') c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) @@ -195,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s @torch.no_grad() def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() + log = {} z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8cc82d54..81c5101b 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -237,7 +237,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index f16d6504..57c02d12 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -1247,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): use_ddim = ddim_steps is not None - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1274,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: diff --git a/modules/api/api.py b/modules/api/api.py index 9efb558e..594fa655 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -165,7 +165,7 @@ def api_middleware(app: FastAPI): class Api: def __init__(self, app: FastAPI, queue_lock: Lock): if shared.cmd_opts.api_auth: - self.credentials = dict() + self.credentials = {} for auth in shared.cmd_opts.api_auth.split(","): user, password = auth.split(":") self.credentials[user] = password diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index af4dea15..3fb76b65 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -405,7 +405,7 @@ class DDPM(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -413,7 +413,7 @@ class DDPM(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -1263,7 +1263,7 @@ class LatentDiffusion(DDPM): use_ddim = False - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1291,7 +1291,7 @@ class LatentDiffusion(DDPM): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 6f8ad631..f6c49f87 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -344,7 +344,7 @@ def model_wrapper( t_in = torch.cat([t_continuous] * 2) if isinstance(condition, dict): assert isinstance(unconditional_condition, dict) - c_in = dict() + c_in = {} for k in condition: if isinstance(condition[k], list): c_in[k] = [torch.cat([ @@ -355,7 +355,7 @@ def model_wrapper( unconditional_condition[k], condition[k]]) elif isinstance(condition, list): - c_in = list() + c_in = [] assert isinstance(unconditional_condition, list) for i in range(len(condition)): c_in.append(torch.cat([unconditional_condition[i], condition[i]])) diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 058575b7..c1977b19 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -23,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F if isinstance(c, dict): assert isinstance(unconditional_conditioning, dict) - c_in = dict() + c_in = {} for k in c: if isinstance(c[k], list): c_in[k] = [ -- cgit v1.2.3 From 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:28:15 +0300 Subject: Autofix Ruff W (not W605) (mostly whitespace) --- extensions-builtin/LDSR/ldsr_model_arch.py | 4 +- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 6 +-- extensions-builtin/ScuNET/scunet_model_arch.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- extensions-builtin/SwinIR/swinir_model_arch.py | 2 +- extensions-builtin/SwinIR/swinir_model_arch_v2.py | 52 +++++++++++------------ launch.py | 2 +- modules/api/api.py | 4 +- modules/api/models.py | 2 +- modules/cmd_args.py | 2 +- modules/codeformer/codeformer_arch.py | 14 +++--- modules/codeformer/vqgan_arch.py | 38 ++++++++--------- modules/esrgan_model_arch.py | 4 +- modules/extras.py | 2 +- modules/hypernetworks/hypernetwork.py | 12 +++--- modules/images.py | 2 +- modules/mac_specific.py | 4 +- modules/masking.py | 2 +- modules/ngrok.py | 4 +- modules/processing.py | 2 +- modules/script_callbacks.py | 14 +++--- modules/sd_hijack.py | 12 +++--- modules/sd_hijack_optimizations.py | 32 +++++++------- modules/sd_models.py | 4 +- modules/sd_samplers_kdiffusion.py | 18 ++++---- modules/sub_quadratic_attention.py | 2 +- modules/textual_inversion/dataset.py | 4 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 16 +++---- modules/ui.py | 18 ++++---- modules/ui_extensions.py | 6 +-- modules/xlmr.py | 6 +-- pyproject.toml | 5 ++- scripts/img2imgalt.py | 14 +++--- scripts/loopback.py | 8 ++-- scripts/poor_mans_outpainting.py | 2 +- scripts/prompt_matrix.py | 2 +- scripts/prompts_from_file.py | 4 +- scripts/sd_upscale.py | 2 +- 39 files changed, 167 insertions(+), 166 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 2173de79..7f450086 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -130,11 +130,11 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 57c02d12..631a08ef 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -890,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 8028918a..b51a8806 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -265,4 +265,4 @@ class SCUNet(nn.Module): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) \ No newline at end of file + nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 55dd94ab..0ba50487 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -150,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for w_idx in w_idx_list: if state.interrupted or state.skipped: break - + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 73e37cfa..93b93274 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -805,7 +805,7 @@ class SwinIR(nn.Module): def forward(self, x): H, W = x.shape[2:] x = self.check_image_size(x) - + self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 3ca9be78..dad22cca 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = None self.register_buffer("attn_mask", attn_mask) - + def calculate_mask(self, x_size): # calculate attention mask for SW-MSA H, W = x_size @@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - return attn_mask + return attn_mask def forward(self, x, x_size): H, W = x_size @@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module): attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C else: attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - + # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C @@ -369,7 +369,7 @@ class PatchMerging(nn.Module): H, W = self.input_resolution flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops += H * W * self.dim // 2 - return flops + return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. @@ -447,7 +447,7 @@ class BasicLayer(nn.Module): nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.weight, 0) - + class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: @@ -492,7 +492,7 @@ class PatchEmbed(nn.Module): flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim - return flops + return flops class RSTB(nn.Module): """Residual Swin Transformer Block (RSTB). @@ -531,7 +531,7 @@ class RSTB(nn.Module): num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=norm_layer, @@ -622,7 +622,7 @@ class Upsample(nn.Sequential): else: raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') super(Upsample, self).__init__(*m) - + class Upsample_hf(nn.Sequential): """Upsample module. @@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential): m.append(nn.PixelShuffle(3)) else: raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) + super(Upsample_hf, self).__init__(*m) class UpsampleOneStep(nn.Sequential): @@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential): H, W = self.input_resolution flops = H * W * self.num_feat * 3 * 9 return flops - - + + class Swin2SR(nn.Module): r""" Swin2SR @@ -699,7 +699,7 @@ class Swin2SR(nn.Module): def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, + window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', @@ -764,7 +764,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results norm_layer=norm_layer, @@ -776,7 +776,7 @@ class Swin2SR(nn.Module): ) self.layers.append(layer) - + if self.upsampler == 'pixelshuffle_hf': self.layers_hf = nn.ModuleList() for i_layer in range(self.num_layers): @@ -787,7 +787,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, + qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results norm_layer=norm_layer, @@ -799,7 +799,7 @@ class Swin2SR(nn.Module): ) self.layers_hf.append(layer) - + self.norm = norm_layer(self.num_features) # build the last conv layer in deep feature extraction @@ -829,10 +829,10 @@ class Swin2SR(nn.Module): self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_after_aux = nn.Sequential( nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) + nn.LeakyReLU(inplace=True)) self.upsample = Upsample(upscale, num_feat) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - + elif self.upsampler == 'pixelshuffle_hf': self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) @@ -846,7 +846,7 @@ class Swin2SR(nn.Module): nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - + elif self.upsampler == 'pixelshuffledirect': # for lightweight SR (to save parameters) self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, @@ -905,7 +905,7 @@ class Swin2SR(nn.Module): x = self.patch_unembed(x, x_size) return x - + def forward_features_hf(self, x): x_size = (x.shape[2], x.shape[3]) x = self.patch_embed(x) @@ -919,7 +919,7 @@ class Swin2SR(nn.Module): x = self.norm(x) # B L C x = self.patch_unembed(x, x_size) - return x + return x def forward(self, x): H, W = x.shape[2:] @@ -951,7 +951,7 @@ class Swin2SR(nn.Module): x = self.conv_after_body(self.forward_features(x)) + x x_before = self.conv_before_upsample(x) x_out = self.conv_last(self.upsample(x_before)) - + x_hf = self.conv_first_hf(x_before) x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf x_hf = self.conv_before_upsample_hf(x_hf) @@ -977,15 +977,15 @@ class Swin2SR(nn.Module): x_first = self.conv_first(x) res = self.conv_after_body(self.forward_features(x_first)) + x_first x = x + self.conv_last(res) - + x = x / self.img_range + self.mean if self.upsampler == "pixelshuffle_aux": return x[:, :, :H*self.upscale, :W*self.upscale], aux - + elif self.upsampler == "pixelshuffle_hf": x_out = x_out / self.img_range + self.mean return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - + else: return x[:, :, :H*self.upscale, :W*self.upscale] @@ -1014,4 +1014,4 @@ if __name__ == '__main__': x = torch.randn((1, 3, height, width)) x = model(x) - print(x.shape) \ No newline at end of file + print(x.shape) diff --git a/launch.py b/launch.py index 670af87c..62b33f14 100644 --- a/launch.py +++ b/launch.py @@ -327,7 +327,7 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) - + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) diff --git a/modules/api/api.py b/modules/api/api.py index 594fa655..165985c3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -227,7 +227,7 @@ class Api: script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx - + def get_scripts_list(self): t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] @@ -237,7 +237,7 @@ class Api: def get_script(self, script_name, script_runner): if script_name is None or script_name == "": return None, None - + script_idx = script_name_to_index(script_name, script_runner.scripts) return script_runner.scripts[script_idx] diff --git a/modules/api/models.py b/modules/api/models.py index 4d291076..006ccdb7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -289,4 +289,4 @@ class MemoryResponse(BaseModel): class ScriptsList(BaseModel): txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") - img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file + img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e01ca655..f4a4ab36 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -102,4 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) -parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') \ No newline at end of file +parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 45c70f84..12db6814 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module): tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): - + # self attention tgt2 = self.norm1(tgt) q = k = self.with_pos_embed(tgt2, query_pos) @@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module): @ARCH_REGISTRY.register() class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, + def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256, connect_list=('32', '64', '128', '256'), fix_modules=('quantize', 'generator')): @@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder): self.feat_emb = nn.Linear(256, self.dim_embd) # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) for _ in range(self.n_layers)]) # logits_predict head self.idx_pred_layer = nn.Sequential( nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)) - + self.channels = { '16': 512, '32': 256, @@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder): enc_feat_dict = {} out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] for i, block in enumerate(self.encoder.blocks): - x = block(x) + x = block(x) if i in out_list: enc_feat_dict[str(x.shape[-1])] = x.clone() @@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder): fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] for i, block in enumerate(self.generator.blocks): - x = block(x) + x = block(x) if i in fuse_list: # fuse after i-th block f_size = str(x.shape[-1]) if w>0: x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) out = x # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat \ No newline at end of file + return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index b24a0394..09ee6660 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY def normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - + @torch.jit.script def swish(x): @@ -210,15 +210,15 @@ class AttnBlock(nn.Module): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) + q = q.permute(0, 2, 1) k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) + w_ = torch.bmm(q, k) w_ = w_ * (int(c)**(-0.5)) w_ = F.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) + w_ = w_.permute(0, 2, 1) h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) @@ -270,18 +270,18 @@ class Encoder(nn.Module): def forward(self, x): for block in self.blocks: x = block(x) - + return x class Generator(nn.Module): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): super().__init__() - self.nf = nf - self.ch_mult = ch_mult + self.nf = nf + self.ch_mult = ch_mult self.num_resolutions = len(self.ch_mult) self.num_res_blocks = res_blocks - self.resolution = img_size + self.resolution = img_size self.attn_resolutions = attn_resolutions self.in_channels = emb_dim self.out_channels = 3 @@ -315,24 +315,24 @@ class Generator(nn.Module): blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) self.blocks = nn.ModuleList(blocks) - + def forward(self, x): for block in self.blocks: x = block(x) - + return x - + @ARCH_REGISTRY.register() class VQAutoEncoder(nn.Module): def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks self.codebook_size = codebook_size self.embed_dim = emb_dim self.ch_mult = ch_mult @@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module): self.kl_weight ) self.generator = Generator( - self.nf, + self.nf, self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, + self.ch_mult, + self.n_blocks, + self.resolution, self.attn_resolutions ) @@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module): raise ValueError('Wrong params!') def forward(self, x): - return self.main(x) \ No newline at end of file + return self.main(x) diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py index 4de9dd8d..2b9888ba 100644 --- a/modules/esrgan_model_arch.py +++ b/modules/esrgan_model_arch.py @@ -105,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module): Modified options that can be used: - "Partial Convolution based Padding" arXiv:1811.11718 - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. {Rasoanaivo} """ @@ -170,7 +170,7 @@ class GaussianNoise(nn.Module): scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x sampled_noise = self.noise.repeat(*x.size()).normal_() * scale x = x + sampled_noise - return x + return x def conv1x1(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) diff --git a/modules/extras.py b/modules/extras.py index eb4f0b42..bdf9b3b7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ result_is_inpainting_model = True else: theta_0[key] = theta_func2(a, b, multiplier) - + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38ef074f..570b5603 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -540,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None if clip_grad: clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) @@ -593,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi print(e) scaler = torch.cuda.amp.GradScaler() - + batch_size = ds.batch_size gradient_step = ds.gradient_step # n steps = batch_size * gradient_step * n image processed @@ -636,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi if clip_grad: clip_grad_sched.step(hypernetwork.step) - + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if use_weight: @@ -657,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step += loss.item() scaler.scale(loss).backward() - + # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue loss_logging.append(_loss_step) if clip_grad: clip_grad(weights, clip_grad_sched.learn_rate) - + scaler.step(optimizer) scaler.update() hypernetwork.step += 1 @@ -674,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step = 0 steps_done = hypernetwork.step + 1 - + epoch_num = hypernetwork.step // steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch diff --git a/modules/images.py b/modules/images.py index 3b8b62d9..b2de3662 100644 --- a/modules/images.py +++ b/modules/images.py @@ -367,7 +367,7 @@ class FilenameGenerator: self.seed = seed self.prompt = prompt self.image = image - + def hasprompt(self, *args): lower = self.prompt.lower() if self.p is None or self.prompt is None: diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 5c2f92a1..d74c6b95 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -42,7 +42,7 @@ if has_mps: # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) - # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 @@ -60,4 +60,4 @@ if has_mps: # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 if platform.processor() == 'i386': for funcName in ['torch.argmax', 'torch.Tensor.argmax']: - CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') \ No newline at end of file + CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') diff --git a/modules/masking.py b/modules/masking.py index a5c4d2da..be9f84c7 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps def get_crop_region(mask, pad=0): """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" - + h, w = mask.shape crop_left = 0 diff --git a/modules/ngrok.py b/modules/ngrok.py index 7a7b4b26..67a74e85 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -13,7 +13,7 @@ def connect(token, port, region): config = conf.PyngrokConfig( auth_token=token, region=region ) - + # Guard for existing tunnels existing = ngrok.get_tunnels(pyngrok_config=config) if existing: @@ -24,7 +24,7 @@ def connect(token, port, region): print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n' 'You can use this link after the launch is complete.') return - + try: if account is None: public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url diff --git a/modules/processing.py b/modules/processing.py index c3932d6b..f902b9df 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -164,7 +164,7 @@ class StableDiffusionProcessing: self.all_subseeds = None self.iteration = 0 self.is_hr_pass = False - + @property def sd_model(self): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 17109732..7d9dd736 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -32,22 +32,22 @@ class CFGDenoiserParams: def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x """Latent image representation in the process of being denoised""" - + self.image_cond = image_cond """Conditioning image""" - + self.sigma = sigma """Current sigma noise step value""" - + self.sampling_step = sampling_step """Current Sampling step number""" - + self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - + self.text_cond = text_cond """ Encoder hidden states of text conditioning from prompt""" - + self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" @@ -240,7 +240,7 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) - + def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e374aeb8..7e50f1ab 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -34,7 +34,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th - + optimization_method = None can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp @@ -92,12 +92,12 @@ def fix_checkpoint(): def weighted_loss(sd_model, pred, target, mean=True): #Calculate the weight normally, but ignore the mean loss = sd_model._old_get_loss(pred, target, mean=False) - + #Check if we have weights available weight = getattr(sd_model, '_custom_loss_weight', None) if weight is not None: loss *= weight - + #Return the loss, as mean if specified return loss.mean() if mean else loss @@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): try: #Temporarily append weights to a place accessible during loss calc sd_model._custom_loss_weight = w - + #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set if not hasattr(sd_model, '_old_get_loss'): @@ -120,7 +120,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): del sd_model._custom_loss_weight except AttributeError: pass - + #If we have an old loss function, reset the loss function to the original one if hasattr(sd_model, '_old_get_loss'): sd_model.get_loss = sd_model._old_get_loss @@ -184,7 +184,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: - m.cond_stage_model = m.cond_stage_model.wrapped + m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a174bbe1..f00fe55c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): end = i + 2 s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) s1 *= self.scale - + s2 = s1.softmax(dim=-1) del s1 - + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) del s2 del q, k, v @@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None): with devices.without_autocast(disable=not shared.opts.upcast_attn): k_in = k_in * self.scale - + del context, x - + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) del q_in, k_in, v_in - + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - + mem_free_total = get_available_vram() - + gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() modifier = 3 if q.element_size() == 2 else 2.5 mem_required = tensor_size * modifier steps = 1 - + if mem_required > mem_free_total: steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - + if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - + s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 - + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 - + del q, k, v r1 = r1.to(dtype) @@ -228,7 +228,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): with devices.without_autocast(disable=not shared.opts.upcast_attn): k = k * self.scale - + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) r = einsum_op(q, k, v) r = r.to(dtype) @@ -369,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) - + del q_in, k_in, v_in dtype = q.dtype @@ -451,7 +451,7 @@ def cross_attention_attnblock_forward(self, x): h3 += x return h3 - + def xformers_attnblock_forward(self, x): try: h_ = x diff --git a/modules/sd_models.py b/modules/sd_models.py index d1e946a5..3316d021 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,7 +165,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -372,7 +372,7 @@ def enable_midas_autodownload(): if not os.path.exists(path): if not os.path.exists(midas_path): mkdir(midas_path) - + print(f"Downloading midas model weights for {model_type} to {path}") request.urlretrieve(midas_urls[model_type], path) print(f"{model_type} downloaded") diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 2f733cf5..e9e41818 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -93,10 +93,10 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -316,7 +316,7 @@ class KDiffusionSampler: sigma_sched = sigmas[steps - t_enc - 1:] xi = x + noise * sigma_sched[0] - + extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters @@ -339,9 +339,9 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond } @@ -374,9 +374,9 @@ class KDiffusionSampler: self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond }, disable=False, callback=self.callback_state, **extra_params_kwargs)) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index cc38debd..497568eb 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -179,7 +179,7 @@ def efficient_dot_product_attention( chunk_idx, min(query_chunk_size, q_tokens) ) - + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 41610e03..b9621fc9 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,7 +118,7 @@ class PersonalizedBase(Dataset): weight = torch.ones(latent_sample.shape) else: weight = None - + if latent_sampling_method == "random": entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) else: @@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader): return self def collate_wrapper_random(batch): - return BatchLoaderRandom(batch) \ No newline at end of file + return BatchLoaderRandom(batch) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d0cad09e..a009d8e8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr default=None ) return wh and center_crop(image, *wh) - + def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): width = process_width diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9e1b2b9a..d489ed1e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) def tensorboard_add_scaler(tensorboard_writer, tag, value, step): - tensorboard_writer.add_scalar(tag=tag, + tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step) def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): # Convert a pil image to a torch tensor img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) - img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) img_tensor = img_tensor.permute((2, 0, 1)) - + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): @@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ @@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed - + if shared.opts.training_enable_tensorboard: tensorboard_writer = tensorboard_setup(log_directory) @@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') if embedding.checksum() == optimizer_saved_dict.get('hash', None): optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) - + if optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) print("Loaded existing optimizer from checkpoint") @@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if clip_grad: clip_grad_sched.step(embedding.step) - + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if use_weight: @@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - + if clip_grad: clip_grad(embedding.vec, clip_grad_sched.learn_rate) diff --git a/modules/ui.py b/modules/ui.py index 1efb656a..ff82fff6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1171,7 +1171,7 @@ def create_ui(): process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - + with gr.Column(visible=False) as process_multicrop_col: gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') with gr.Row(): @@ -1183,7 +1183,7 @@ def create_ui(): with gr.Row(): process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1226,7 +1226,7 @@ def create_ui(): with FormRow(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - + with FormRow(): clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) @@ -1565,7 +1565,7 @@ def create_ui(): gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - + def unload_sd_weights(): modules.sd_models.unload_model_weights() @@ -1841,15 +1841,15 @@ def versions_html(): return f""" version: {tag} - •  + • python: {python_version} - •  + • torch: {getattr(torch, '__long_version__',torch.__version__)} - •  + • xformers: {xformers_version} - •  + • gradio: {gr.__version__} - •  + • checkpoint: N/A """ diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ed70abe5..af497733 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -467,7 +467,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" {html.escape(description)}

Added: {html.escape(added)}

{install_code} - + """ for tag in [x for x in extension_tags if x not in tags]: @@ -535,9 +535,9 @@ def create_ui(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") - with gr.Row(): + with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) - + install_result = gr.HTML() available_extensions_table = gr.HTML() diff --git a/modules/xlmr.py b/modules/xlmr.py index e056c3f6..a407a3ca 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): config_class = BertSeriesConfig def __init__(self, config=None, **kargs): - # modify initialization for autoloading + # modify initialization for autoloading if config is None: config = XLMRobertaConfig() config.attention_probs_dropout_prob= 0.1 @@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): text["attention_mask"] = torch.tensor( text['attention_mask']).to(device) features = self(**text) - return features['projection_state'] + return features['projection_state'] def forward( self, @@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig \ No newline at end of file + config_class= RobertaSeriesConfig diff --git a/pyproject.toml b/pyproject.toml index c88907be..d4a1bbf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ extend-select = [ "B", "C", "I", + "W", ] exclude = [ @@ -20,7 +21,7 @@ ignore = [ "I001", # Import block is un-sorted or un-formatted "C901", # Function is too complex "C408", # Rewrite as a literal - + "W605", # invalid escape sequence, messes with some docstrings ] [tool.ruff.per-file-ignores] @@ -28,4 +29,4 @@ ignore = [ [tool.ruff.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. -extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] \ No newline at end of file +extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index bb00fb3f..1e833fa8 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -149,9 +149,9 @@ class Script(scripts.Script): sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) return [ - info, + info, override_sampler, - override_prompt, original_prompt, original_negative_prompt, + override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment, @@ -191,17 +191,17 @@ class Script(scripts.Script): self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) - + combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) - + sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) sigmas = sampler.model_wrap.get_sigmas(p.steps) - + noise_dt = combined_noise - (p.init_latent / sigmas[0]) - + p.seed = p.seed + 1 - + return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) p.sample = sample_extra diff --git a/scripts/loopback.py b/scripts/loopback.py index ad6609be..2d5feaf9 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -14,7 +14,7 @@ class Script(scripts.Script): def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): + def ui(self, is_img2img): loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength")) denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear") @@ -104,7 +104,7 @@ class Script(scripts.Script): p.seed = processed.seed + 1 p.denoising_strength = calculate_denoising_strength(i + 1) - + if state.skipped: break @@ -121,7 +121,7 @@ class Script(scripts.Script): all_images.append(last_image) p.inpainting_fill = original_inpainting_fill - + if state.interrupted: break @@ -132,7 +132,7 @@ class Script(scripts.Script): if opts.return_grid: grids.append(grid) - + all_images = grids + all_images processed = Processed(p, all_images, initial_seed, initial_info) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index c0bbecc1..ea0632b6 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -19,7 +19,7 @@ class Script(scripts.Script): def ui(self, is_img2img): if not is_img2img: return None - + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index fb06beab..88324fe6 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -96,7 +96,7 @@ class Script(scripts.Script): p.prompt_for_display = positive_prompt processed = process_images(p) - grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) + grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) processed.images.insert(0, grid) processed.index_of_first_image = 1 diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 9607077a..2378816f 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -109,7 +109,7 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def ui(self, is_img2img): + def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) @@ -166,7 +166,7 @@ class Script(scripts.Script): proc = process_images(copy_p) images += proc.images - + if checkbox_iterate: p.seed = p.seed + (p.batch_size * p.n_iter) all_prompts += proc.all_prompts diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 0b1d3096..e614c23b 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -16,7 +16,7 @@ class Script(scripts.Script): def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): + def ui(self, is_img2img): info = gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) -- cgit v1.2.3 From f6fc7916c4615c4f5cac97cab5add9b0536f6efa Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 17 May 2023 22:43:24 +0300 Subject: add /sdapi/v1/script-info api --- modules/api/api.py | 13 +++++++++++-- modules/api/models.py | 21 +++++++++++++++++++-- modules/scripts.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 5 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 165985c3..eee99bbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -204,6 +204,7 @@ class Api: self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) + self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -229,11 +230,19 @@ class Api: return script, script_idx def get_scripts_list(self): - t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] - i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] + t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None] + i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None] return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist) + def get_script_info(self): + res = [] + + for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]: + res += [script.api_info for script in script_list if script.api_info is not None] + + return res + def get_script(self, script_name, script_runner): if script_name is None or script_name == "": return None, None diff --git a/modules/api/models.py b/modules/api/models.py index 006ccdb7..1ff2fb33 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -287,6 +287,23 @@ class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") + class ScriptsList(BaseModel): - txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") - img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") + txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)") + img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)") + + +class ScriptArg(BaseModel): + label: str = Field(default=None, title="Label", description="Name of the argument in UI") + value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument") + minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") + maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") + step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") + choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument") + + +class ScriptInfo(BaseModel): + name: str = Field(default=None, title="Name", description="Script name") + is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") + is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") + args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") diff --git a/modules/scripts.py b/modules/scripts.py index 0c12ebd5..e33d8c81 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -17,6 +17,9 @@ class PostprocessImageArgs: class Script: + name = None + """script's internal name derived from title""" + filename = None args_from = None args_to = None @@ -25,8 +28,8 @@ class Script: is_txt2img = False is_img2img = False - """A gr.Group component that has all script's UI inside it""" group = None + """A gr.Group component that has all script's UI inside it""" infotext_fields = None """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when @@ -38,6 +41,9 @@ class Script: various "Send to " buttons when clicked """ + api_info = None + """Generated value of type modules.api.models.ScriptInfo with information about the script for API""" + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -313,6 +319,8 @@ class ScriptRunner: self.selectable_scripts.append(script) def setup_ui(self): + import modules.api.models as api_models + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] inputs = [None] @@ -327,9 +335,28 @@ class ScriptRunner: if controls is None: return + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] + for control in controls: control.custom_script_source = os.path.basename(script.filename) + arg_info = api_models.ScriptArg(label=control.label or "") + + for field in ("value", "minimum", "maximum", "step", "choices"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) + + api_args.append(arg_info) + + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) + if script.infotext_fields is not None: self.infotext_fields += script.infotext_fields -- cgit v1.2.3 From efc98530595443d31c773526c6b760b722019d62 Mon Sep 17 00:00:00 2001 From: Monty Anderson Date: Mon, 22 May 2023 15:52:44 +0100 Subject: `modules/api/api.py`: disable `timeout_keep_alive` --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 9bb95dfd..bfeec385 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -682,4 +682,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) -- cgit v1.2.3 From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 08:54:13 +0300 Subject: Add & use modules.errors.print_error where currently printing exception info by hand --- extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++-- modules/api/api.py | 7 +++-- modules/call_queue.py | 22 ++++++-------- modules/codeformer_model.py | 10 +++---- modules/config_states.py | 12 +++----- modules/errors.py | 16 +++++++++++ modules/extensions.py | 10 +++---- modules/gfpgan_model.py | 6 ++-- modules/hypernetworks/hypernetwork.py | 14 ++++----- modules/images.py | 9 ++---- modules/interrogate.py | 5 ++-- modules/launch_utils.py | 7 +++-- modules/localization.py | 6 ++-- modules/processing.py | 2 +- modules/realesrgan_model.py | 14 ++++----- modules/safe.py | 26 +++++++++-------- modules/script_callbacks.py | 9 +++--- modules/script_loading.py | 7 ++--- modules/scripts.py | 35 ++++++++--------------- modules/sd_hijack_optimizations.py | 6 ++-- modules/textual_inversion/textual_inversion.py | 9 ++---- modules/ui.py | 10 +++---- modules/ui_extensions.py | 9 ++---- scripts/prompts_from_file.py | 6 ++-- 25 files changed, 117 insertions(+), 153 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index c4da79f3..95f1669d 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,9 +1,8 @@ import os -import sys -import traceback from basicsr.utils.download_util import load_file_from_url +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks @@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 45d9297b..dd1b822e 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,6 +1,5 @@ import os.path import sys -import traceback import PIL.Image import numpy as np @@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net + +from modules.errors import print_error from modules.shared import opts @@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index 6a456861..79ce9228 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,6 +16,7 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api import models +from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -108,7 +109,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +139,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + print_error(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb764..dba2a9b4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,10 +1,9 @@ import html -import sys import threading -import traceback import time from modules import shared, progress +from modules.errors import print_error queue_lock = threading.Lock() @@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): try: res = list(func(*args, **kwargs)) except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {args} {kwargs}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) + # When printing out our debug argument list, + # do not print out more than a 100 KB of text + max_debug_str_len = 131072 + message = "Error completing request" + arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] + if len(arg_str) > max_debug_str_len: + arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" + print_error(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 @@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f - diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ececdbae..76143e9f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import cv2 import torch @@ -8,6 +6,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader +from modules.errors import print_error from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,8 +104,8 @@ def setup_model(dirname): restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) + except Exception: + print_error('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -135,7 +134,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print("Error setting up CodeFormer:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index db65bcdb..faeaf28b 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c """ import os -import sys -import traceback import json import time import tqdm @@ -14,6 +12,7 @@ from collections import OrderedDict import git from modules import shared, extensions +from modules.errors import print_error from modules.paths_internal import script_path, config_states_dir @@ -53,8 +52,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -134,8 +132,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -143,8 +140,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index da4694f8..41d8dc93 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -1,7 +1,23 @@ import sys +import textwrap import traceback +def print_error( + message: str, + *, + exc_info: bool = False, +) -> None: + """ + Print an error message to stderr, with optional traceback. + """ + for line in message.splitlines(): + print("***", line, file=sys.stderr) + if exc_info: + print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) + print("---") + + def print_error_explanation(message): lines = message.strip().split("\n") max_len = max([len(x) for x in lines]) diff --git a/modules/extensions.py b/modules/extensions.py index 624832a0..369d2584 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,10 @@ import os -import sys import threading -import traceback import git from modules import shared +from modules.errors import print_error from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -56,8 +55,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception: - print(f"Error reading github repository info from {self.path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -72,8 +70,8 @@ class Extension: self.commit_hash = commit.hexsha self.version = self.commit_hash[:8] - except Exception as ex: - print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) + except Exception: + print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 0131dea4..d2f647fe 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import facexlib import gfpgan import modules.face_restoration from modules import paths, shared, devices, modelloader +from modules.errors import print_error model_dir = "GFPGAN" user_path = None @@ -112,5 +111,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print("Error setting up GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 570b5603..fcc1ef20 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -2,8 +2,6 @@ import datetime import glob import html import os -import sys -import traceback import inspect import modules.textual_inversion.dataset @@ -12,6 +10,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint +from modules.errors import print_error from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -325,17 +324,14 @@ def load_hypernetwork(name): if path is None: return None - hypernetwork = Hypernetwork() - try: + hypernetwork = Hypernetwork() hypernetwork.load(path) + return hypernetwork except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading hypernetwork {path}", exc_info=True) return None - return hypernetwork - def load_hypernetworks(names, multipliers=None): already_loaded = {} @@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print(traceback.format_exc(), file=sys.stderr) + print_error("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index e21e554c..69151bec 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,6 +1,4 @@ import datetime -import sys -import traceback import pytz import io @@ -18,6 +16,7 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors +from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -464,8 +463,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print(f"Error adding [{pattern}] to filename", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -697,8 +695,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print("Error parsing NovelAI image generation parameters:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index 111b1322..d36e1a5a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,6 +1,5 @@ import os import sys -import traceback from collections import namedtuple from pathlib import Path import re @@ -12,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,8 +216,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print("Error interrogating", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310..22edc106 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -8,6 +8,7 @@ import json from functools import lru_cache from modules import cmd_args +from modules.errors import print_error from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -188,7 +189,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print(e, file=sys.stderr) + print_error(str(e)) def list_extensions(settings_file): @@ -198,8 +199,8 @@ def list_extensions(settings_file): if os.path.isfile(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) - except Exception as e: - print(e, file=sys.stderr) + except Exception: + print_error("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index ee9c65e7..9a1df343 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,8 +1,7 @@ import json import os -import sys -import traceback +from modules.errors import print_error localizations = {} @@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print(f"Error loading localization from {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/processing.py b/modules/processing.py index b75f2515..5c9bcce8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,5 @@ import json +import logging import math import os import sys @@ -23,7 +24,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -import logging from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 99983678..c8d0c64f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import numpy as np from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader @@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) return info - except Exception as e: - print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + except Exception: + print_error("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print("Error making Real-ESRGAN models list:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index e8f50774..b596f565 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -2,8 +2,6 @@ import pickle import collections -import sys -import traceback import torch import numpy @@ -11,6 +9,8 @@ import _codecs import zipfile import re +from modules.errors import print_error + # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage @@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + "-----> !!!! The file is most likely corrupted !!!! <-----\n" + "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + exc_info=True, + ) return None - except Exception: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + f"The file may be malicious, so the program is not going to read it.\n" + f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", + exc_info=True, + ) return None return unsafe_torch_load(filename, *args, **kwargs) @@ -190,4 +193,3 @@ with safe.Extra(handler): unsafe_torch_load = torch.load torch.load = load global_extra_handler = None - diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d2728e12..6aa9c3b6 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,16 +1,15 @@ -import sys -import traceback -from collections import namedtuple import inspect +from collections import namedtuple from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules.errors import print_error + def report_exception(c, job): - print(f"Error executing callback {job} for {c.script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 57b15862..26efffcb 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,8 +1,8 @@ import os -import sys -import traceback import importlib.util +from modules.errors import print_error + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) @@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print(f"Error running preload() for {preload_script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..a7168fd1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,12 +1,12 @@ import os import re import sys -import traceback from collections import namedtuple import gradio as gr from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules.errors import print_error AlwaysVisible = object() @@ -264,8 +264,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -280,11 +279,9 @@ def load_scripts(): def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) except Exception: - print(f"Error calling: {filename}/{funcname}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -450,8 +447,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running process: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -459,8 +455,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -468,8 +463,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -477,8 +471,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print(f"Error running postprocess: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -486,8 +479,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -495,24 +487,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print(f"Error running before_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print(f"Error running after_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 2ec0b049..fd186fa2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,5 @@ from __future__ import annotations import math -import sys -import traceback import psutil import torch @@ -11,6 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention +from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print("Cannot import xformers", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e..a040a988 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,4 @@ import os -import sys -import traceback from collections import namedtuple import torch @@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset +from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -207,8 +206,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print(traceback.format_exc(), file=sys.stderr) - pass + print_error("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index 001b9792..1ad94f02 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,7 +2,6 @@ import json import mimetypes import os import sys -import traceback from functools import reduce import warnings @@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules.errors import print_error from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: res = all_seeds[index if 0 <= index < len(all_seeds) else 0] except json.decoder.JSONDecodeError: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) + if gen_info_string: + print_error(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1753,8 +1752,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..cadf56be 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,10 +1,8 @@ import json import os.path -import sys import threading import time from datetime import datetime -import traceback import git @@ -14,6 +12,7 @@ import shutil import errno from modules import extensions, shared, paths, config_states +from modules.errors import print_error from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print(f"Error getting updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -113,8 +111,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print(f"Error checking updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b918a764..4dc24615 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,13 +1,12 @@ import copy import random -import sys -import traceback import shlex import modules.scripts as scripts import gradio as gr from modules import sd_samplers +from modules.errors import print_error from modules.processing import Processed, process_images from modules.shared import state @@ -136,8 +135,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line {line} as commandline:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 42e020c1c1c31b29ee7fc2e493a60613459642a1 Mon Sep 17 00:00:00 2001 From: James Date: Mon, 29 May 2023 22:25:43 +0100 Subject: Added VAE listing to web API. --- modules/api/api.py | 5 +++++ modules/api/models.py | 4 ++++ 2 files changed, 9 insertions(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..31b9d90c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -189,6 +190,7 @@ class Api: self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) @@ -541,6 +543,9 @@ class Api: def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + def get_sd_vaes(self): + return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index 1ff2fb33..47fdede2 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -249,6 +249,10 @@ class SDModelItem(BaseModel): filename: str = Field(title="Filename") config: Optional[str] = Field(title="Config file") +class SDVaeItem(BaseModel): + model_name: str = Field(title="Model Name") + filename: str = Field(title="Filename") + class HypernetworkItem(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") -- cgit v1.2.3 From 05933840f0676dd1a90a7e2ad3f2a0672624b2cd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 May 2023 19:56:37 +0300 Subject: rename print_error to report, use it with together with package name --- extensions-builtin/LDSR/scripts/ldsr_model.py | 5 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 5 ++--- modules/api/api.py | 5 ++--- modules/call_queue.py | 5 ++--- modules/codeformer_model.py | 7 +++---- modules/config_states.py | 9 ++++----- modules/errors.py | 8 ++------ modules/extensions.py | 7 +++---- modules/gfpgan_model.py | 5 ++--- modules/hypernetworks/hypernetwork.py | 7 +++---- modules/images.py | 5 ++--- modules/interrogate.py | 3 +-- modules/launch_utils.py | 7 +++---- modules/localization.py | 4 ++-- modules/realesrgan_model.py | 10 +++++----- modules/safe.py | 7 ++++--- modules/script_callbacks.py | 4 ++-- modules/script_loading.py | 4 ++-- modules/scripts.py | 23 +++++++++++------------ modules/sd_hijack_optimizations.py | 3 +-- modules/textual_inversion/textual_inversion.py | 7 +++---- modules/ui.py | 7 +++---- modules/ui_extensions.py | 7 +++---- scripts/prompts_from_file.py | 5 ++--- 24 files changed, 69 insertions(+), 90 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 95f1669d..dbd6d331 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -2,10 +2,9 @@ import os from basicsr.utils.download_util import load_file_from_url -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR -from modules import shared, script_callbacks +from modules import shared, script_callbacks, errors import sd_hijack_autoencoder # noqa: F401 import sd_hijack_ddpm_v1 # noqa: F401 @@ -51,7 +50,7 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) except Exception: - print_error("Error importing LDSR", exc_info=True) + errors.report("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index dd1b822e..85b4505f 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -9,10 +9,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader, script_callbacks +from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net -from modules.errors import print_error from modules.shared import opts @@ -39,7 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print_error(f"Error loading ScuNET model: {file}", exc_info=True) + errors.report(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index fbd616a3..d34ab422 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,9 +14,8 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules.api import models -from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -145,7 +144,7 @@ def api_middleware(app: FastAPI): print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - print_error(message, exc_info=True) + errors.report(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index dba2a9b4..53af6d70 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -2,8 +2,7 @@ import html import threading import time -from modules import shared, progress -from modules.errors import print_error +from modules import shared, progress, errors queue_lock = threading.Lock() @@ -62,7 +61,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] if len(arg_str) > max_debug_str_len: arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" - print_error(f"{message}\n{arg_str}", exc_info=True) + errors.report(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 76143e9f..4260b016 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,8 +5,7 @@ import torch import modules.face_restoration import modules.shared -from modules import shared, devices, modelloader -from modules.errors import print_error +from modules import shared, devices, modelloader, errors from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,7 +104,7 @@ def setup_model(dirname): del output torch.cuda.empty_cache() except Exception: - print_error('Failed inference for CodeFormer', exc_info=True) + errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -134,6 +133,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print_error("Error setting up CodeFormer", exc_info=True) + errors.report("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index faeaf28b..6f1ab53f 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -11,8 +11,7 @@ from datetime import datetime from collections import OrderedDict import git -from modules import shared, extensions -from modules.errors import print_error +from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir @@ -52,7 +51,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -132,7 +131,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print_error(f"Error reading webui git info from {script_path}", exc_info=True) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -140,7 +139,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print_error(f"Error restoring webui to commit{webui_commit_hash}") + errors.report(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index 41d8dc93..e408f500 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -3,11 +3,7 @@ import textwrap import traceback -def print_error( - message: str, - *, - exc_info: bool = False, -) -> None: +def report(message: str, *, exc_info: bool = False) -> None: """ Print an error message to stderr, with optional traceback. """ @@ -15,7 +11,7 @@ def print_error( print("***", line, file=sys.stderr) if exc_info: print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) - print("---") + print("---", file=sys.stderr) def print_error_explanation(message): diff --git a/modules/extensions.py b/modules/extensions.py index 92f93ad9..8608584b 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,8 +1,7 @@ import os import threading -from modules import shared -from modules.errors import print_error +from modules import shared, errors from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -54,7 +53,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = Repo(self.path) except Exception: - print_error(f"Error reading github repository info from {self.path}", exc_info=True) + errors.report(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -70,7 +69,7 @@ class Extension: self.version = self.commit_hash[:8] except Exception: - print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) + errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index d2f647fe..e239a09d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -4,8 +4,7 @@ import facexlib import gfpgan import modules.face_restoration -from modules import paths, shared, devices, modelloader -from modules.errors import print_error +from modules import paths, shared, devices, modelloader, errors model_dir = "GFPGAN" user_path = None @@ -111,4 +110,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print_error("Error setting up GFPGAN", exc_info=True) + errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fcc1ef20..5d12b449 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -9,8 +9,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint -from modules.errors import print_error +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -329,7 +328,7 @@ def load_hypernetwork(name): hypernetwork.load(path) return hypernetwork except Exception: - print_error(f"Error loading hypernetwork {path}", exc_info=True) + errors.report(f"Error loading hypernetwork {path}", exc_info=True) return None @@ -766,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print_error("Exception in training hypernetwork", exc_info=True) + errors.report("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index 09f728df..30e9ffc5 100644 --- a/modules/images.py +++ b/modules/images.py @@ -16,7 +16,6 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors -from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -463,7 +462,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print_error(f"Error adding [{pattern}] to filename", exc_info=True) + errors.report(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -698,7 +697,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print_error("Error parsing NovelAI image generation parameters", exc_info=True) + errors.report("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index d36e1a5a..9b2c5b60 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,7 +11,6 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,7 +215,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print_error("Error interrogating", exc_info=True) + errors.report("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0bf4cb7e..6e9bb770 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -7,8 +7,7 @@ import platform import json from functools import lru_cache -from modules import cmd_args -from modules.errors import print_error +from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -189,7 +188,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print_error(str(e)) + errors.report(str(e)) def list_extensions(settings_file): @@ -200,7 +199,7 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - print_error("Could not load settings", exc_info=True) + errors.report("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index 9a1df343..e8f585da 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules.errors import print_error +from modules import errors localizations = {} @@ -30,6 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print_error(f"Error loading localization from {fn}", exc_info=True) + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c8d0c64f..2d27b321 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -5,10 +5,10 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer -from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts -from modules import modelloader +from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): @@ -35,7 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print_error("Error importing Real-ESRGAN", exc_info=True) + errors.report("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -75,7 +75,7 @@ class UpscalerRealESRGAN(Upscaler): return info except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -132,4 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print_error("Error making Real-ESRGAN models list", exc_info=True) + errors.report("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index b596f565..b1d08a79 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -9,9 +9,10 @@ import _codecs import zipfile import re -from modules.errors import print_error # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +from modules import errors + TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage def encode(*args): @@ -136,7 +137,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" "-----> !!!! The file is most likely corrupted !!!! <-----\n" "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", @@ -144,7 +145,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): ) return None except Exception: - print_error( + errors.report( f"Error verifying pickled file from {filename}\n" f"The file may be malicious, so the program is not going to read it.\n" f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6aa9c3b6..ec1469d0 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -5,11 +5,11 @@ from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks -from modules.errors import print_error +from modules import errors def report_exception(c, job): - print_error(f"Error executing callback {job} for {c.script}", exc_info=True) + errors.report(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 26efffcb..306a1f35 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,7 +1,7 @@ import os import importlib.util -from modules.errors import print_error +from modules import errors def load_module(path): @@ -27,4 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print_error(f"Error running preload() for {preload_script}", exc_info=True) + errors.report(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index a7168fd1..0970f38e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -5,8 +5,7 @@ from collections import namedtuple import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing -from modules.errors import print_error +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors AlwaysVisible = object() @@ -264,7 +263,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) + errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -281,7 +280,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: return func(*args, **kwargs) except Exception: - print_error(f"Error calling: {filename}/{funcname}", exc_info=True) + errors.report(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -447,7 +446,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print_error(f"Error running process: {script.filename}", exc_info=True) + errors.report(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -455,7 +454,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -463,7 +462,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print_error(f"Error running process_batch: {script.filename}", exc_info=True) + errors.report(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -471,7 +470,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print_error(f"Error running postprocess: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -479,7 +478,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -487,21 +486,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print_error(f"Error running before_component: {script.filename}", exc_info=True) + errors.report(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print_error(f"Error running after_component: {script.filename}", exc_info=True) + errors.report(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index fd186fa2..5f0ff513 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,6 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention -from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -139,7 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print_error("Cannot import xformers", exc_info=True) + errors.report("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b3dcb140..8da050ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,9 +12,8 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors import modules.textual_inversion.dataset -from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -219,7 +218,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print_error(f"Error loading embedding {fn}", exc_info=True) + errors.report(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -643,7 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print_error("Error training embedding", exc_info=True) + errors.report("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index fb6b2498..f361264c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,8 +12,7 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave -from modules.errors import print_error +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -232,7 +231,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: except json.decoder.JSONDecodeError: if gen_info_string: - print_error(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1752,7 +1751,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print_error("Error loading/saving model file", exc_info=True) + errors.report("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index e2ee9d72..3140ed64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -11,8 +11,7 @@ import html import shutil import errno -from modules import extensions, shared, paths, config_states -from modules.errors import print_error +from modules import extensions, shared, paths, config_states, errors from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -45,7 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print_error(f"Error getting updates for {ext.name}", exc_info=True) + errors.report(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -111,7 +110,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print_error(f"Error checking updates for {ext.name}", exc_info=True) + errors.report(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 4dc24615..83a2f220 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -5,8 +5,7 @@ import shlex import modules.scripts as scripts import gradio as gr -from modules import sd_samplers -from modules.errors import print_error +from modules import sd_samplers, errors from modules.processing import Processed, process_images from modules.shared import state @@ -135,7 +134,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print_error(f"Error parsing line {line} as commandline", exc_info=True) + errors.report(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 51864790fd72386fbbbb015d24a43ce501ecaa4b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 2 Jun 2023 14:58:10 +0300 Subject: Simplify a bunch of `len(x) > 0`/`len(x) == 0` style expressions --- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 3 ++- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 4 ++-- extensions-builtin/Lora/extra_networks_lora.py | 4 ++-- extensions-builtin/Lora/lora.py | 4 ++-- .../extra-options-section/scripts/extra_options_section.py | 2 +- modules/api/api.py | 2 +- modules/call_queue.py | 2 +- modules/extra_networks_hypernet.py | 4 ++-- modules/generation_parameters_copypaste.py | 6 ++---- modules/images.py | 6 +++--- modules/img2img.py | 3 +-- modules/models/diffusion/ddpm_edit.py | 4 ++-- modules/processing.py | 3 ++- modules/prompt_parser.py | 6 +++--- modules/script_callbacks.py | 4 ++-- modules/sd_hijack_clip.py | 2 +- modules/sd_hijack_clip_old.py | 2 +- modules/textual_inversion/autocrop.py | 14 +++++++------- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- modules/ui.py | 2 +- modules/ui_extensions.py | 5 +++-- modules/ui_settings.py | 2 +- scripts/prompts_from_file.py | 3 +-- 25 files changed, 47 insertions(+), 48 deletions(-) (limited to 'modules/api/api.py') diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 27a86e13..c29d274d 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 631a08ef..04adc5eb 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index b5fea4d2..66ee9c85 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index eec14712..af93991c 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): else: raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") - if len(keys_failed_to_match) > 0: + if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): lora.multiplier = multipliers[i] if multipliers else 1.0 loaded_loras.append(lora) - if len(failed_to_load_loras) > 0: + if failed_to_load_loras: sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 17f84184..a05e10d8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -21,7 +21,7 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): for setting_name in shared.opts.extra_options: with FormColumn(): comp = ui_settings.create_setting_component(setting_name) diff --git a/modules/api/api.py b/modules/api/api.py index d34ab422..555eefdb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -280,7 +280,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: diff --git a/modules/call_queue.py b/modules/call_queue.py index 53af6d70..1b5e5273 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -21,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id - if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): id_task = args[0] progress.add_task_to_queue(id_task) else: diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index aa2a14ef..b6a6dc0e 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_hypernetwork - if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): hypernet_prompt_text = f"" p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 071bd9ea..237401a1 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -55,7 +55,7 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] if type(filedata) == dict and filedata.get("is_file", False): @@ -437,7 +437,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, vals_pairs = [f"{k}: {v}" for k, v in vals.items()] - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) paste_fields = paste_fields + [(override_settings_component, paste_settings)] @@ -454,5 +454,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) - - diff --git a/modules/images.py b/modules/images.py index a12d252b..7bbfc3e0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -406,7 +406,7 @@ class FilenameGenerator: prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): - if len(style) > 0: + if style: for part in style.split("{prompt}"): prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') @@ -415,7 +415,7 @@ class FilenameGenerator: return sanitize_filename_part(prompt_no_style, replace_spaces=False) def prompt_words(self): - words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0] + words = [x for x in re_nonletters.split(self.prompt or "") if x] if len(words) == 0: words = ["empty"] return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False) @@ -423,7 +423,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format + time_format = args[0] if (args and args[0] != "") else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError: diff --git a/modules/img2img.py b/modules/img2img.py index 4c12c2c5..35c4facc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: + is_inpaint_batch = bool(inpaint_masks) print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 3fb76b65..b892d5fc 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..9ebdb549 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -975,7 +975,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: - assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b4aff704..0069d8b0 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -336,11 +336,11 @@ def parse_prompt_attention(text): round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: + elif weight is not None and round_brackets: multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and len(round_brackets) > 0: + elif text == ')' and round_brackets: multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and len(square_brackets) > 0: + elif text == ']' and square_brackets: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: parts = re.split(re_break, text) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index f755283c..77ee55ee 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -287,14 +287,14 @@ def list_unets_callback(): def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' callbacks.append(ScriptCallback(filename, fun)) def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' if filename == 'unknown file': return for callback_list in callback_map.values(): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index cc6e8c21..3b5a7666 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0 or len(chunks) == 0: + if chunk.tokens or not chunks: next_chunk(is_last=True) return chunks, token_count diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index a3476e95..c5c6270b 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments - if len(used_custom_terms) > 0: + if used_custom_terms: embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) self.hijack.comments.append(f"Used embeddings: {embedding_names}") diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 8e667a4d..75705459 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -77,27 +77,27 @@ def focal_point(im, settings): pois = [] weight_pref_total = 0 - if len(corner_points) > 0: + if corner_points: weight_pref_total += settings.corner_points_weight - if len(entropy_points) > 0: + if entropy_points: weight_pref_total += settings.entropy_points_weight - if len(face_points) > 0: + if face_points: weight_pref_total += settings.face_points_weight corner_centroid = None - if len(corner_points) > 0: + if corner_points: corner_centroid = centroid(corner_points) corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None - if len(entropy_points) > 0: + if entropy_points: entropy_centroid = centroid(entropy_points) entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total pois.append(entropy_centroid) face_centroid = None - if len(face_points) > 0: + if face_points: face_centroid = centroid(face_points) face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) @@ -187,7 +187,7 @@ def image_face_points(im, settings): except Exception: continue - if len(faces) > 0: + if faces: rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b9621fc9..7ee05061 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -32,7 +32,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None self.placeholder_token = placeholder_token diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a009d8e8..0d4c3f84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption += shared.interrogator.generate_caption(image) if params.process_caption_deepbooru: - if len(caption) > 0: + if caption: caption += ", " caption += deepbooru.model.tag_multi(image) @@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption = caption.strip() - if len(caption) > 0: + if caption: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8da050ca..bb6f211c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,7 +251,7 @@ class EmbeddingDatabase: if self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: + if self.skipped_embeddings: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): diff --git a/modules/ui.py b/modules/ui.py index b7459f08..9a025cca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -398,7 +398,7 @@ def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) dropdown.change( - fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + fn=lambda x: gr.Dropdown.update(visible=bool(x)), inputs=[dropdown], outputs=[dropdown], ) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3140ed64..65173e06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -333,7 +333,8 @@ def install_extension_from_url(dirname, url, branch_name=None): assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' normalized_url = normalize_git_url(url) - assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') tmpdir = os.path.join(paths.data_path, "tmp", dirname) @@ -449,7 +450,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = installed_extension_urls.get(normalize_git_url(url), None) extension_tags = extension_tags + ["installed"] if existing else extension_tags - if len([x for x in extension_tags if x in tags_to_hide]) > 0: + if any(x for x in extension_tags if x in tags_to_hide): hidden += 1 continue diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 7874298e..2688d8c2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -81,7 +81,7 @@ class UiSettings: opts.save(shared.config_filename) except RuntimeError: return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' def run_settings_single(self, value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 83a2f220..50320d55 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -121,8 +121,7 @@ class Script(scripts.Script): return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] + lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True -- cgit v1.2.3 From 4faaf3e723ddaec022af85d6e66c3b1bac449584 Mon Sep 17 00:00:00 2001 From: ramyma Date: Sun, 4 Jun 2023 16:59:23 +0300 Subject: Add endpoint to get latent_upscale_modes for hires fix --- modules/api/api.py | 9 +++++++++ modules/api/models.py | 3 +++ 2 files changed, 12 insertions(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index d34ab422..61a47005 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -189,6 +189,7 @@ class Api: self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) @@ -540,6 +541,14 @@ class Api: for upscaler in shared.sd_upscalers ] + def get_latent_upscale_modes(self): + return [ + { + "name": upscale_mode, + } + for upscale_mode in [*(shared.latent_upscale_modes or {})] + ] + def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] diff --git a/modules/api/models.py b/modules/api/models.py index 47fdede2..b3a745f0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -241,6 +241,9 @@ class UpscalerItem(BaseModel): model_url: Optional[str] = Field(title="URL") scale: Optional[float] = Field(title="Scale") +class LatentUpscalerModeItem(BaseModel): + name: str = Field(title="Name") + class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") -- cgit v1.2.3 From ba70a220e3176153ba2a559acb9e5aa692dce7ca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 5 Jun 2023 22:20:29 +0300 Subject: Remove a bunch of unused/vestigial code As found by Vulture and some eyes --- modules/api/api.py | 7 ------- modules/api/models.py | 4 ---- modules/codeformer_model.py | 4 ---- modules/devices.py | 7 ------- modules/generation_parameters_copypaste.py | 29 ----------------------------- modules/hypernetworks/hypernetwork.py | 24 ------------------------ modules/paths.py | 14 -------------- 7 files changed, 89 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..41cd7eca 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -32,13 +32,6 @@ import piexif import piexif.helper -def upscaler_to_index(name: str): - try: - return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e - - def script_name_to_index(name, scripts): try: return [script.title().lower() for script in scripts].index(name.lower()) diff --git a/modules/api/models.py b/modules/api/models.py index b3a745f0..b5683071 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel): prompt: Optional[str] = Field(title="Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt") -class ArtistItem(BaseModel): - name: str = Field(title="Name") - score: float = Field(title="Score") - category: str = Field(title="Category") class EmbeddingItem(BaseModel): step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 4260b016..a01fe63d 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -15,7 +15,6 @@ model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' -have_codeformer = False codeformer = None @@ -125,9 +124,6 @@ def setup_model(dirname): return restored_img - global have_codeformer - have_codeformer = True - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) diff --git a/modules/devices.py b/modules/devices.py index 1ed6ffdc..620ed1a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,13 +15,6 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - - return None - def get_cuda_device_string(): from modules import shared diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 1d02ffae..699b1a81 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -174,31 +174,6 @@ def send_image_and_dimensions(x): return img, w, h - -def find_hypernetwork_key(hypernet_name, hypernet_hash=None): - """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. - - Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config - parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to. - - If the infotext has no hash, then a hypernet with the same name will be selected instead. - """ - hypernet_name = hypernet_name.lower() - if hypernet_hash is not None: - # Try to match the hash in the name - for hypernet_key in shared.hypernetworks.keys(): - result = re_hypernet_hash.search(hypernet_key) - if result is not None and result[1] == hypernet_hash: - return hypernet_key - else: - # Fall back to a hypernet with the same name - for hypernet_key in shared.hypernetworks.keys(): - if hypernet_key.lower().startswith(hypernet_name): - return hypernet_key - - return None - - def restore_old_hires_fix_params(res): """for infotexts that specify old First pass size parameter, convert it into width, height, and hr scale""" @@ -329,10 +304,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -settings_map = {} - - - infotext_to_setting_name_mapping = [ ('Clip skip', 'CLIP_stop_at_last_layers', ), ('Conditional mask weight', 'inpainting_mask_weight'), diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5d12b449..51941c11 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None): shared.loaded_hypernetworks.append(hypernetwork) -def find_closest_hypernetwork_name(search: str): - if not search: - return None - search = search.lower() - applicable = [name for name in shared.hypernetworks if search in name.lower()] - if not applicable: - return None - applicable = sorted(applicable, key=lambda name: len(name)) - return applicable[0] - - def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) @@ -446,18 +435,6 @@ def statistics(data): return total_information, recent_information -def report_statistics(loss_info:dict): - keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) - for key in keys: - try: - print("Loss statistics for file " + key) - info, recent = statistics(list(loss_info[key])) - print(info) - print(recent) - except Exception as e: - print(e) - - def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) @@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False pbar.close() hypernetwork.eval() - #report_statistics(loss_dict) sd_hijack_checkpoint.remove() diff --git a/modules/paths.py b/modules/paths.py index 5171df4f..bada804e 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d - - -class Prioritize: - def __init__(self, name): - self.name = name - self.path = None - - def __enter__(self): - self.path = sys.path.copy() - sys.path = [paths[self.name]] + sys.path - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.path = self.path - self.path = None -- cgit v1.2.3 From 8ca34ad6d8cc2502403b3b96bb811366bc13c076 Mon Sep 17 00:00:00 2001 From: Su Wei Date: Fri, 9 Jun 2023 13:14:20 +0800 Subject: add model exists status check to modeuls/api/api.py , /sdapi/v1/options [POST] --- modules/api/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..56b7858d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights,checkpoint_alisases from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -515,6 +515,11 @@ class Api: def set_config(self, req: Dict[str, Any]): for k, v in req.items(): + if k == "sd_model_checkpoint": + checkpoint_info = checkpoint_alisases.get(v, None) + if checkpoint_info is None: + print(f"model [{v}] not founded, skip config saving process") + return shared.opts.set(k, v) shared.opts.save(shared.config_filename) -- cgit v1.2.3 From 9142be0a0d8cea37cf1ae86c17fc7dcb161d9a42 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 10 Jun 2023 23:36:34 +0900 Subject: quit restart --- modules/api/api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..317c809f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -208,6 +208,8 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -715,3 +717,10 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) + + def quit_webui(self): + restart.stop_program() + + def restart_webui(self): + if restart.is_restartable(): + restart.restart_program() -- cgit v1.2.3 From 7e2d39a2d158d1426321686b05d31a3ea694a99e Mon Sep 17 00:00:00 2001 From: Su Wei Date: Mon, 12 Jun 2023 15:22:49 +0800 Subject: update model checkpoint switch code --- modules/api/api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 56b7858d..7d7dfe9a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -514,12 +514,11 @@ class Api: return options def set_config(self, req: Dict[str, Any]): + checkpoint_key="sd_model_checkpoint" + if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases: + raise RuntimeError(f"model {v!r} not found") + for k, v in req.items(): - if k == "sd_model_checkpoint": - checkpoint_info = checkpoint_alisases.get(v, None) - if checkpoint_info is None: - print(f"model [{v}] not founded, skip config saving process") - return shared.opts.set(k, v) shared.opts.save(shared.config_filename) -- cgit v1.2.3 From b9664ab6154818680ee25920e229b808a3cdec68 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 12 Jun 2023 18:15:27 +0900 Subject: move _stop route to api --- modules/api/api.py | 11 +++++++++-- webui.py | 7 ------- 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 317c809f..cb1cde78 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -208,8 +208,11 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) + + if shared.cmd_opts.add_stop_route: + self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) + self.add_api_route("/_stop", self.stop_route, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -724,3 +727,7 @@ class Api: def restart_webui(self): if restart.is_restartable(): restart.restart_program() + + def stop_route(request): + shared.state.server_command = "stop" + return Response("Stopping.") diff --git a/webui.py b/webui.py index 136d036d..ae6dc9fb 100644 --- a/webui.py +++ b/webui.py @@ -362,11 +362,6 @@ def api_only(): api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -def stop_route(request): - shared.state.server_command = "stop" - return Response("Stopping.") - - def webui(): launch_api = cmd_opts.api initialize() @@ -404,8 +399,6 @@ def webui(): "redoc_url": "/redoc", }, ) - if cmd_opts.add_stop_route: - app.add_route("/_stop", stop_route, methods=["POST"]) # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False -- cgit v1.2.3 From 5be6c026f55760039b3ebb284cc2ce85586be4ac Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 18:51:47 +0900 Subject: rename routes --- modules/api/api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index cb1cde78..5ea1d21c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -210,9 +210,9 @@ class Api: self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) if shared.cmd_opts.add_stop_route: - self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) - self.add_api_route("/_stop", self.stop_route, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -721,13 +721,13 @@ class Api: self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) - def quit_webui(self): + def kill_webui(self): restart.stop_program() def restart_webui(self): if restart.is_restartable(): restart.restart_program() - def stop_route(request): + def terminate_webui(request): shared.state.server_command = "stop" return Response("Stopping.") -- cgit v1.2.3 From 49fb2a337661d1b9a80de8ff35a640083fa98d2f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 19:52:12 +0900 Subject: response 501 if not a able to restart --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 5ea1d21c..4dc48a03 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -727,6 +727,7 @@ class Api: def restart_webui(self): if restart.is_restartable(): restart.restart_program() + return Response(status_code=501) def terminate_webui(request): shared.state.server_command = "stop" -- cgit v1.2.3 From 6091c4e4aa32b674f8ec755e6bd58989f09b08c5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 19:53:08 +0900 Subject: terminate -> stop --- .github/workflows/run_tests.yaml | 2 +- modules/api/api.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/api/api.py') diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 394811ac..96546011 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -50,7 +50,7 @@ jobs: python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() - run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-terminate && sleep 10 + run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 - name: Show coverage run: | python -m coverage combine .coverage* diff --git a/modules/api/api.py b/modules/api/api.py index 4dc48a03..80d45ca7 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -212,7 +212,7 @@ class Api: if shared.cmd_opts.add_stop_route: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -729,6 +729,6 @@ class Api: restart.restart_program() return Response(status_code=501) - def terminate_webui(request): + def stop_webui(request): shared.state.server_command = "stop" return Response("Stopping.") -- cgit v1.2.3 From d06af4e517865277d0521642c2c5513af9afd76f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:26:18 +0300 Subject: fix and rework #11113 --- modules/api/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index f96056b6..1d4aeff5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -522,9 +522,9 @@ class Api: return options def set_config(self, req: Dict[str, Any]): - checkpoint_key="sd_model_checkpoint" - if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases: - raise RuntimeError(f"model {v!r} not found") + checkpoint_name = req.get("sd_model_checkpoint", None) + if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): shared.opts.set(k, v) -- cgit v1.2.3 From cc9c1719786de00d4a5bfcf83be4bf2808cf0cb5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 14:21:28 +0900 Subject: rename --add-stop-route to --api-server-stop --- .github/workflows/run_tests.yaml | 2 +- modules/api/api.py | 2 +- modules/cmd_args.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/api/api.py') diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 96546011..178c026a 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -42,7 +42,7 @@ jobs: --no-half --disable-opt-split-attention --use-cpu all - --add-stop-route + --api-server-stop 2>&1 | tee output.txt & - name: Run tests run: | diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..adc633db 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -202,7 +202,7 @@ class Api: self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - if shared.cmd_opts.add_stop_route: + if shared.cmd_opts.api_server_stop: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 624dcb4f..278a605e 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api') +parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') -- cgit v1.2.3 From 74d001bc68c2106aa963e3474eee70327b8f3760 Mon Sep 17 00:00:00 2001 From: ramyma Date: Sun, 2 Jul 2023 04:59:59 +0300 Subject: Hotfix: call processing close to cleanup API generation calls --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..f10e3fe3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -392,6 +393,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.3 From f44feb6a10aacc6a5ff4c9275fba2546b2858935 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:31 +0300 Subject: Add job argument to State.begin() --- modules/api/api.py | 14 +++++++------- modules/call_queue.py | 2 +- modules/extras.py | 3 +-- modules/interrogate.py | 3 +-- modules/postprocessing.py | 3 +-- modules/shared.py | 4 ++-- 6 files changed, 13 insertions(+), 16 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..3ea099ad 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -327,7 +327,7 @@ class Api: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - shared.state.begin() + shared.state.begin(job="scripts_txt2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here @@ -384,7 +384,7 @@ class Api: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - shared.state.begin() + shared.state.begin(job="scripts_img2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here @@ -599,7 +599,7 @@ class Api: def create_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used shared.state.end() @@ -610,7 +610,7 @@ class Api: def create_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") @@ -620,7 +620,7 @@ class Api: def preprocess(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() return models.PreprocessResponse(info = 'preprocess complete') @@ -636,7 +636,7 @@ class Api: def train_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_embedding") apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -657,7 +657,7 @@ class Api: def train_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_hypernetwork") shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None diff --git a/modules/call_queue.py b/modules/call_queue.py index 69bf63d2..3b94f8a4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): id_task = None with queue_lock: - shared.state.begin() + shared.state.begin(job=id_task) progress.start_task(id_task) try: diff --git a/modules/extras.py b/modules/extras.py index 830b53aa..e9c0263e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -73,8 +73,7 @@ def to_half(tensor, enable): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): - shared.state.begin() - shared.state.job = 'model-merge' + shared.state.begin(job="model-merge") def fail(message): shared.state.textinfo = message diff --git a/modules/interrogate.py b/modules/interrogate.py index 9b2c5b60..a3ae1dd5 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -184,8 +184,7 @@ class InterrogateModels: def interrogate(self, pil_image): res = "" - shared.state.begin() - shared.state.job = 'interrogate' + shared.state.begin(job="interrogate") try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..544b2f72 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -9,8 +9,7 @@ from modules.shared import opts def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() - shared.state.begin() - shared.state.job = 'extras' + shared.state.begin(job="extras") image_data = [] image_names = [] diff --git a/modules/shared.py b/modules/shared.py index 203ee1b9..7df2879c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,7 +173,7 @@ class State: return obj - def begin(self): + def begin(self, job: str = "(unknown)"): self.sampling_step = 0 self.job_count = -1 self.processing_has_refined_job_count = False @@ -187,7 +187,7 @@ class State: self.interrupted = False self.textinfo = None self.time_start = time.time() - + self.job = job devices.torch_gc() def end(self): -- cgit v1.2.3 From e4303443477b9ac3c90ec4dd58a4810f7ac1eabe Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:49 +0300 Subject: API: use finally: for state.end() --- modules/api/api.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 3ea099ad..8b79495d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -602,37 +602,35 @@ class Api: shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used - shared.state.end() return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create embedding error: {e}") + finally: + shared.state.end() + def create_hypernetwork(self, args: dict): try: shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding - shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create hypernetwork error: {e}") + finally: + shared.state.end() def preprocess(self, args: dict): try: shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return models.PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info='preprocess complete') except KeyError as e: - shared.state.end() return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except AssertionError as e: - shared.state.end() + except Exception as e: return models.PreprocessResponse(info=f"preprocess error: {e}") - except FileNotFoundError as e: + finally: shared.state.end() - return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: @@ -649,11 +647,11 @@ class Api: finally: if not apply_optimizations: sd_hijack.apply_optimizations() - shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError as msg: - shared.state.end() + except Exception as msg: return models.TrainResponse(info=f"train embedding error: {msg}") + finally: + shared.state.end() def train_hypernetwork(self, args: dict): try: @@ -675,9 +673,10 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError: + except Exception as exc: + return models.TrainResponse(info=f"train embedding error: {exc}") + finally: shared.state.end() - return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: -- cgit v1.2.3 From 32788873176e9d79e1fffd6f89f94b6d0ec8bb91 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:02:30 +0300 Subject: Handle cleanup in case there's an exception thrown --- modules/api/api.py | 57 +++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index f10e3fe3..d9278e9e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -323,19 +323,21 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - p.scripts = script_runner - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + try: + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -380,20 +382,23 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - p.init_images = [decode_base64_to_image(x) for x in init_images] - p.scripts = script_runner - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples + try: + p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.3 From c1c04928596f69ddb39b8841a8435ecefb0594e9 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:17:47 +0300 Subject: Use contextlib for closing the generation process --- modules/api/api.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index d9278e9e..e92c2938 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -30,6 +30,7 @@ from modules import devices from typing import Dict, List, Any import piexif import piexif.helper +from contextlib import closing def script_name_to_index(name, scripts): @@ -322,8 +323,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -336,8 +336,6 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -381,8 +379,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids @@ -397,8 +394,6 @@ class Api: processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.3 From 259967b7c60cbd2aeb091e691b5f49d9fb64b872 Mon Sep 17 00:00:00 2001 From: jovijovi Date: Thu, 6 Jul 2023 18:43:17 +0800 Subject: fix(api): convert to "RGB" if image mode is "RGBA" --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..6507f641 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -84,6 +84,8 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + if image.mode == "RGBA": + image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } -- cgit v1.2.3