From efc98530595443d31c773526c6b760b722019d62 Mon Sep 17 00:00:00 2001
From: Monty Anderson
Date: Mon, 22 May 2023 15:52:44 +0100
Subject: `modules/api/api.py`: disable `timeout_keep_alive`
---
modules/api/api.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 9bb95dfd..bfeec385 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -682,4 +682,4 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
--
cgit v1.2.3
From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Mon, 29 May 2023 08:54:13 +0300
Subject: Add & use modules.errors.print_error where currently printing
exception info by hand
---
extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++---
extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++--
modules/api/api.py | 7 +++--
modules/call_queue.py | 22 ++++++--------
modules/codeformer_model.py | 10 +++----
modules/config_states.py | 12 +++-----
modules/errors.py | 16 +++++++++++
modules/extensions.py | 10 +++----
modules/gfpgan_model.py | 6 ++--
modules/hypernetworks/hypernetwork.py | 14 ++++-----
modules/images.py | 9 ++----
modules/interrogate.py | 5 ++--
modules/launch_utils.py | 7 +++--
modules/localization.py | 6 ++--
modules/processing.py | 2 +-
modules/realesrgan_model.py | 14 ++++-----
modules/safe.py | 26 +++++++++--------
modules/script_callbacks.py | 9 +++---
modules/script_loading.py | 7 ++---
modules/scripts.py | 35 ++++++++---------------
modules/sd_hijack_optimizations.py | 6 ++--
modules/textual_inversion/textual_inversion.py | 9 ++----
modules/ui.py | 10 +++----
modules/ui_extensions.py | 9 ++----
scripts/prompts_from_file.py | 6 ++--
25 files changed, 117 insertions(+), 153 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index c4da79f3..95f1669d 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,9 +1,8 @@
import os
-import sys
-import traceback
from basicsr.utils.download_util import load_file_from_url
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
@@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler):
try:
return LDSR(model, yaml)
-
except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 45d9297b..dd1b822e 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,6 +1,5 @@
import os.path
import sys
-import traceback
import PIL.Image
import numpy as np
@@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
+
+from modules.errors import print_error
from modules.shared import opts
@@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
diff --git a/modules/api/api.py b/modules/api/api.py
index 6a456861..79ce9228 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,6 +16,7 @@ from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api import models
+from modules.errors import print_error
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
@@ -108,7 +109,6 @@ def api_middleware(app: FastAPI):
from rich.console import Console
console = Console()
except Exception:
- import traceback
rich_available = False
@app.middleware("http")
@@ -139,11 +139,12 @@ def api_middleware(app: FastAPI):
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
- print(f"API error: {request.method}: {request.url} {err}")
+ message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
+ print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- traceback.print_exc()
+ print_error(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..dba2a9b4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,9 @@
import html
-import sys
import threading
-import traceback
import time
from modules import shared, progress
+from modules.errors import print_error
queue_lock = threading.Lock()
@@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try:
res = list(func(*args, **kwargs))
except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {args} {kwargs}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
+ # When printing out our debug argument list,
+ # do not print out more than a 100 KB of text
+ max_debug_str_len = 131072
+ message = "Error completing request"
+ arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+ if len(arg_str) > max_debug_str_len:
+ arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+ print_error(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
@@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f
-
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ececdbae..76143e9f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import cv2
import torch
@@ -8,6 +6,7 @@ import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
+from modules.errors import print_error
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,8 +104,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ except Exception:
+ print_error('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -135,7 +134,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index db65bcdb..faeaf28b 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""
import os
-import sys
-import traceback
import json
import time
import tqdm
@@ -14,6 +12,7 @@ from collections import OrderedDict
import git
from modules import shared, extensions
+from modules.errors import print_error
from modules.paths_internal import script_path, config_states_dir
@@ -53,8 +52,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -134,8 +132,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -143,8 +140,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/errors.py b/modules/errors.py
index da4694f8..41d8dc93 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,7 +1,23 @@
import sys
+import textwrap
import traceback
+def print_error(
+ message: str,
+ *,
+ exc_info: bool = False,
+) -> None:
+ """
+ Print an error message to stderr, with optional traceback.
+ """
+ for line in message.splitlines():
+ print("***", line, file=sys.stderr)
+ if exc_info:
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
+ print("---")
+
+
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
diff --git a/modules/extensions.py b/modules/extensions.py
index 624832a0..369d2584 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,11 +1,10 @@
import os
-import sys
import threading
-import traceback
import git
from modules import shared
+from modules.errors import print_error
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -56,8 +55,7 @@ class Extension:
if os.path.exists(os.path.join(self.path, ".git")):
repo = git.Repo(self.path)
except Exception:
- print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -72,8 +70,8 @@ class Extension:
self.commit_hash = commit.hexsha
self.version = self.commit_hash[:8]
- except Exception as ex:
- print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+ except Exception:
+ print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 0131dea4..d2f647fe 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import facexlib
import gfpgan
import modules.face_restoration
from modules import paths, shared, devices, modelloader
+from modules.errors import print_error
model_dir = "GFPGAN"
user_path = None
@@ -112,5 +111,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 570b5603..fcc1ef20 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -2,8 +2,6 @@ import datetime
import glob
import html
import os
-import sys
-import traceback
import inspect
import modules.textual_inversion.dataset
@@ -12,6 +10,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules.errors import print_error
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -325,17 +324,14 @@ def load_hypernetwork(name):
if path is None:
return None
- hypernetwork = Hypernetwork()
-
try:
+ hypernetwork = Hypernetwork()
hypernetwork.load(path)
+ return hypernetwork
except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading hypernetwork {path}", exc_info=True)
return None
- return hypernetwork
-
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
@@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index e21e554c..69151bec 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,4 @@
import datetime
-import sys
-import traceback
import pytz
import io
@@ -18,6 +16,7 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
+from modules.errors import print_error
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
@@ -464,8 +463,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -697,8 +695,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 111b1322..d36e1a5a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,6 +1,5 @@
import os
import sys
-import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -12,6 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.errors import print_error
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -216,8 +216,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error interrogating", exc_info=True)
res += ""
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 35a52310..22edc106 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -8,6 +8,7 @@ import json
from functools import lru_cache
from modules import cmd_args
+from modules.errors import print_error
from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args()
@@ -188,7 +189,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
- print(e, file=sys.stderr)
+ print_error(str(e))
def list_extensions(settings_file):
@@ -198,8 +199,8 @@ def list_extensions(settings_file):
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
- except Exception as e:
- print(e, file=sys.stderr)
+ except Exception:
+ print_error("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
diff --git a/modules/localization.py b/modules/localization.py
index ee9c65e7..9a1df343 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,8 +1,7 @@
import json
import os
-import sys
-import traceback
+from modules.errors import print_error
localizations = {}
@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/processing.py b/modules/processing.py
index b75f2515..5c9bcce8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,5 @@
import json
+import logging
import math
import os
import sys
@@ -23,7 +24,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 99983678..c8d0c64f 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader
@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ except Exception:
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index e8f50774..b596f565 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -2,8 +2,6 @@
import pickle
import collections
-import sys
-import traceback
import torch
import numpy
@@ -11,6 +9,8 @@ import _codecs
import zipfile
import re
+from modules.errors import print_error
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
@@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+ exc_info=True,
+ )
return None
-
except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ f"The file may be malicious, so the program is not going to read it.\n"
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+ exc_info=True,
+ )
return None
return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +193,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
-
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index d2728e12..6aa9c3b6 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,16 +1,15 @@
-import sys
-import traceback
-from collections import namedtuple
import inspect
+from collections import namedtuple
from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
+from modules.errors import print_error
+
def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 57b15862..26efffcb 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,8 +1,8 @@
import os
-import sys
-import traceback
import importlib.util
+from modules.errors import print_error
+
def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index c902804b..a7168fd1 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,12 @@
import os
import re
import sys
-import traceback
from collections import namedtuple
import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules.errors import print_error
AlwaysVisible = object()
@@ -264,8 +264,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -280,11 +279,9 @@ def load_scripts():
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -450,8 +447,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -459,8 +455,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -468,8 +463,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -477,8 +471,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -486,8 +479,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -495,24 +487,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 2ec0b049..fd186fa2 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,5 @@
from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -11,6 +9,7 @@ from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices, sub_quadratic_attention
+from modules.errors import print_error
from modules.hypernetworks import hypernetwork
import ldm.modules.attention
@@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d489ed1e..a040a988 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
from collections import namedtuple
import torch
@@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
+from modules.errors import print_error
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
@@ -207,8 +206,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ print_error("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index 001b9792..1ad94f02 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -2,7 +2,6 @@ import json
import mimetypes
import os
import sys
-import traceback
from functools import reduce
import warnings
@@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.errors import print_error
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
+ if gen_info_string:
+ print_error(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -1753,8 +1752,7 @@ def create_ui():
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 515ec262..cadf56be 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,10 +1,8 @@
import json
import os.path
-import sys
import threading
import time
from datetime import datetime
-import traceback
import git
@@ -14,6 +12,7 @@ import shutil
import errno
from modules import extensions, shared, paths, config_states
+from modules.errors import print_error
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -113,8 +111,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index b918a764..4dc24615 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,13 +1,12 @@
import copy
import random
-import sys
-import traceback
import shlex
import modules.scripts as scripts
import gradio as gr
from modules import sd_samplers
+from modules.errors import print_error
from modules.processing import Processed, process_images
from modules.shared import state
@@ -136,8 +135,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print(f"Error parsing line {line} as commandline:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
--
cgit v1.2.3
From 42e020c1c1c31b29ee7fc2e493a60613459642a1 Mon Sep 17 00:00:00 2001
From: James
Date: Mon, 29 May 2023 22:25:43 +0100
Subject: Added VAE listing to web API.
---
modules/api/api.py | 5 +++++
modules/api/models.py | 4 ++++
2 files changed, 9 insertions(+)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..31b9d90c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -189,6 +190,7 @@ class Api:
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
@@ -541,6 +543,9 @@ class Api:
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ def get_sd_vaes(self):
+ return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/api/models.py b/modules/api/models.py
index 1ff2fb33..47fdede2 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,10 @@ class SDModelItem(BaseModel):
filename: str = Field(title="Filename")
config: Optional[str] = Field(title="Config file")
+class SDVaeItem(BaseModel):
+ model_name: str = Field(title="Model Name")
+ filename: str = Field(title="Filename")
+
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
--
cgit v1.2.3
From 05933840f0676dd1a90a7e2ad3f2a0672624b2cd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 31 May 2023 19:56:37 +0300
Subject: rename print_error to report, use it with together with package name
---
extensions-builtin/LDSR/scripts/ldsr_model.py | 5 ++---
extensions-builtin/ScuNET/scripts/scunet_model.py | 5 ++---
modules/api/api.py | 5 ++---
modules/call_queue.py | 5 ++---
modules/codeformer_model.py | 7 +++----
modules/config_states.py | 9 ++++-----
modules/errors.py | 8 ++------
modules/extensions.py | 7 +++----
modules/gfpgan_model.py | 5 ++---
modules/hypernetworks/hypernetwork.py | 7 +++----
modules/images.py | 5 ++---
modules/interrogate.py | 3 +--
modules/launch_utils.py | 7 +++----
modules/localization.py | 4 ++--
modules/realesrgan_model.py | 10 +++++-----
modules/safe.py | 7 ++++---
modules/script_callbacks.py | 4 ++--
modules/script_loading.py | 4 ++--
modules/scripts.py | 23 +++++++++++------------
modules/sd_hijack_optimizations.py | 3 +--
modules/textual_inversion/textual_inversion.py | 7 +++----
modules/ui.py | 7 +++----
modules/ui_extensions.py | 7 +++----
scripts/prompts_from_file.py | 5 ++---
24 files changed, 69 insertions(+), 90 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index 95f1669d..dbd6d331 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -2,10 +2,9 @@ import os
from basicsr.utils.download_util import load_file_from_url
-from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
-from modules import shared, script_callbacks
+from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
@@ -51,7 +50,7 @@ class UpscalerLDSR(Upscaler):
try:
return LDSR(model, yaml)
except Exception:
- print_error("Error importing LDSR", exc_info=True)
+ errors.report("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index dd1b822e..85b4505f 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -9,10 +9,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import devices, modelloader, script_callbacks
+from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
-from modules.errors import print_error
from modules.shared import opts
@@ -39,7 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print_error(f"Error loading ScuNET model: {file}", exc_info=True)
+ errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
diff --git a/modules/api/api.py b/modules/api/api.py
index fbd616a3..d34ab422 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,9 +14,8 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
from modules.api import models
-from modules.errors import print_error
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
@@ -145,7 +144,7 @@ def api_middleware(app: FastAPI):
print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- print_error(message, exc_info=True)
+ errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index dba2a9b4..53af6d70 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -2,8 +2,7 @@ import html
import threading
import time
-from modules import shared, progress
-from modules.errors import print_error
+from modules import shared, progress, errors
queue_lock = threading.Lock()
@@ -62,7 +61,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
if len(arg_str) > max_debug_str_len:
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
- print_error(f"{message}\n{arg_str}", exc_info=True)
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 76143e9f..4260b016 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -5,8 +5,7 @@ import torch
import modules.face_restoration
import modules.shared
-from modules import shared, devices, modelloader
-from modules.errors import print_error
+from modules import shared, devices, modelloader, errors
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,7 +104,7 @@ def setup_model(dirname):
del output
torch.cuda.empty_cache()
except Exception:
- print_error('Failed inference for CodeFormer', exc_info=True)
+ errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -134,6 +133,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print_error("Error setting up CodeFormer", exc_info=True)
+ errors.report("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index faeaf28b..6f1ab53f 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -11,8 +11,7 @@ from datetime import datetime
from collections import OrderedDict
import git
-from modules import shared, extensions
-from modules.errors import print_error
+from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
@@ -52,7 +51,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print_error(f"Error reading webui git info from {script_path}", exc_info=True)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -132,7 +131,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print_error(f"Error reading webui git info from {script_path}", exc_info=True)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -140,7 +139,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print_error(f"Error restoring webui to commit{webui_commit_hash}")
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/errors.py b/modules/errors.py
index 41d8dc93..e408f500 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -3,11 +3,7 @@ import textwrap
import traceback
-def print_error(
- message: str,
- *,
- exc_info: bool = False,
-) -> None:
+def report(message: str, *, exc_info: bool = False) -> None:
"""
Print an error message to stderr, with optional traceback.
"""
@@ -15,7 +11,7 @@ def print_error(
print("***", line, file=sys.stderr)
if exc_info:
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
- print("---")
+ print("---", file=sys.stderr)
def print_error_explanation(message):
diff --git a/modules/extensions.py b/modules/extensions.py
index 92f93ad9..8608584b 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,8 +1,7 @@
import os
import threading
-from modules import shared
-from modules.errors import print_error
+from modules import shared, errors
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -54,7 +53,7 @@ class Extension:
if os.path.exists(os.path.join(self.path, ".git")):
repo = Repo(self.path)
except Exception:
- print_error(f"Error reading github repository info from {self.path}", exc_info=True)
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -70,7 +69,7 @@ class Extension:
self.version = self.commit_hash[:8]
except Exception:
- print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index d2f647fe..e239a09d 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -4,8 +4,7 @@ import facexlib
import gfpgan
import modules.face_restoration
-from modules import paths, shared, devices, modelloader
-from modules.errors import print_error
+from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
@@ -111,4 +110,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print_error("Error setting up GFPGAN", exc_info=True)
+ errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index fcc1ef20..5d12b449 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -9,8 +9,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
-from modules.errors import print_error
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -329,7 +328,7 @@ def load_hypernetwork(name):
hypernetwork.load(path)
return hypernetwork
except Exception:
- print_error(f"Error loading hypernetwork {path}", exc_info=True)
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
return None
@@ -766,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
except Exception:
- print_error("Exception in training hypernetwork", exc_info=True)
+ errors.report("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index 09f728df..30e9ffc5 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,7 +16,6 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.errors import print_error
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
@@ -463,7 +462,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print_error(f"Error adding [{pattern}] to filename", exc_info=True)
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -698,7 +697,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print_error("Error parsing NovelAI image generation parameters", exc_info=True)
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/interrogate.py b/modules/interrogate.py
index d36e1a5a..9b2c5b60 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,7 +11,6 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
-from modules.errors import print_error
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -216,7 +215,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print_error("Error interrogating", exc_info=True)
+ errors.report("Error interrogating", exc_info=True)
res += ""
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 0bf4cb7e..6e9bb770 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -7,8 +7,7 @@ import platform
import json
from functools import lru_cache
-from modules import cmd_args
-from modules.errors import print_error
+from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args()
@@ -189,7 +188,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
- print_error(str(e))
+ errors.report(str(e))
def list_extensions(settings_file):
@@ -200,7 +199,7 @@ def list_extensions(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except Exception:
- print_error("Could not load settings", exc_info=True)
+ errors.report("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
diff --git a/modules/localization.py b/modules/localization.py
index 9a1df343..e8f585da 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,7 +1,7 @@
import json
import os
-from modules.errors import print_error
+from modules import errors
localizations = {}
@@ -30,6 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print_error(f"Error loading localization from {fn}", exc_info=True)
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index c8d0c64f..2d27b321 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -5,10 +5,10 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
-from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-from modules import modelloader
+from modules import modelloader, errors
+
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -35,7 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print_error("Error importing Real-ESRGAN", exc_info=True)
+ errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -75,7 +75,7 @@ class UpscalerRealESRGAN(Upscaler):
return info
except Exception:
- print_error("Error making Real-ESRGAN models list", exc_info=True)
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -132,4 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print_error("Error making Real-ESRGAN models list", exc_info=True)
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index b596f565..b1d08a79 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -9,9 +9,10 @@ import _codecs
import zipfile
import re
-from modules.errors import print_error
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+from modules import errors
+
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@@ -136,7 +137,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print_error(
+ errors.report(
f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
@@ -144,7 +145,7 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
)
return None
except Exception:
- print_error(
+ errors.report(
f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6aa9c3b6..ec1469d0 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -5,11 +5,11 @@ from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
-from modules.errors import print_error
+from modules import errors
def report_exception(c, job):
- print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
+ errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 26efffcb..306a1f35 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,7 +1,7 @@
import os
import importlib.util
-from modules.errors import print_error
+from modules import errors
def load_module(path):
@@ -27,4 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print_error(f"Error running preload() for {preload_script}", exc_info=True)
+ errors.report(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index a7168fd1..0970f38e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -5,8 +5,7 @@ from collections import namedtuple
import gradio as gr
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
-from modules.errors import print_error
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors
AlwaysVisible = object()
@@ -264,7 +263,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -281,7 +280,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
- print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -447,7 +446,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print_error(f"Error running process: {script.filename}", exc_info=True)
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -455,7 +454,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -463,7 +462,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print_error(f"Error running process_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -471,7 +470,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print_error(f"Error running postprocess: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -479,7 +478,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -487,21 +486,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print_error(f"Error running before_component: {script.filename}", exc_info=True)
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print_error(f"Error running after_component: {script.filename}", exc_info=True)
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index fd186fa2..5f0ff513 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,6 @@ from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices, sub_quadratic_attention
-from modules.errors import print_error
from modules.hypernetworks import hypernetwork
import ldm.modules.attention
@@ -139,7 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print_error("Cannot import xformers", exc_info=True)
+ errors.report("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b3dcb140..8da050ca 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,9 +12,8 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
-from modules.errors import print_error
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
@@ -219,7 +218,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print_error(f"Error loading embedding {fn}", exc_info=True)
+ errors.report(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -643,7 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print_error("Error training embedding", exc_info=True)
+ errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index fb6b2498..f361264c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,8 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
-from modules.errors import print_error
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -232,7 +231,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
except json.decoder.JSONDecodeError:
if gen_info_string:
- print_error(f"Error parsing JSON generation info: {gen_info_string}")
+ errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -1752,7 +1751,7 @@ def create_ui():
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print_error("Error loading/saving model file", exc_info=True)
+ errors.report("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e2ee9d72..3140ed64 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -11,8 +11,7 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths, config_states
-from modules.errors import print_error
+from modules import extensions, shared, paths, config_states, errors
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -45,7 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print_error(f"Error getting updates for {ext.name}", exc_info=True)
+ errors.report(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -111,7 +110,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print_error(f"Error checking updates for {ext.name}", exc_info=True)
+ errors.report(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 4dc24615..83a2f220 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -5,8 +5,7 @@ import shlex
import modules.scripts as scripts
import gradio as gr
-from modules import sd_samplers
-from modules.errors import print_error
+from modules import sd_samplers, errors
from modules.processing import Processed, process_images
from modules.shared import state
@@ -135,7 +134,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print_error(f"Error parsing line {line} as commandline", exc_info=True)
+ errors.report(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
--
cgit v1.2.3
From 51864790fd72386fbbbb015d24a43ce501ecaa4b Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 2 Jun 2023 14:58:10 +0300
Subject: Simplify a bunch of `len(x) > 0`/`len(x) == 0` style expressions
---
extensions-builtin/LDSR/sd_hijack_autoencoder.py | 3 ++-
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 4 ++--
extensions-builtin/Lora/extra_networks_lora.py | 4 ++--
extensions-builtin/Lora/lora.py | 4 ++--
.../extra-options-section/scripts/extra_options_section.py | 2 +-
modules/api/api.py | 2 +-
modules/call_queue.py | 2 +-
modules/extra_networks_hypernet.py | 4 ++--
modules/generation_parameters_copypaste.py | 6 ++----
modules/images.py | 6 +++---
modules/img2img.py | 3 +--
modules/models/diffusion/ddpm_edit.py | 4 ++--
modules/processing.py | 3 ++-
modules/prompt_parser.py | 6 +++---
modules/script_callbacks.py | 4 ++--
modules/sd_hijack_clip.py | 2 +-
modules/sd_hijack_clip_old.py | 2 +-
modules/textual_inversion/autocrop.py | 14 +++++++-------
modules/textual_inversion/dataset.py | 2 +-
modules/textual_inversion/preprocess.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 2 +-
modules/ui.py | 2 +-
modules/ui_extensions.py | 5 +++--
modules/ui_settings.py | 2 +-
scripts/prompts_from_file.py | 3 +--
25 files changed, 47 insertions(+), 48 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 27a86e13..c29d274d 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -91,8 +91,9 @@ class VQModel(pl.LightningModule):
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
index 631a08ef..04adc5eb 100644
--- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index b5fea4d2..66ee9c85 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
multipliers = []
for params in params_list:
- assert len(params.items) > 0
+ assert params.items
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index eec14712..af93991c 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk):
else:
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
- if len(keys_failed_to_match) > 0:
+ if keys_failed_to_match:
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
@@ -267,7 +267,7 @@ def load_loras(names, multipliers=None):
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
- if len(failed_to_load_loras) > 0:
+ if failed_to_load_loras:
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
index 17f84184..a05e10d8 100644
--- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -21,7 +21,7 @@ class ExtraOptionsSection(scripts.Script):
self.setting_names = []
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row():
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
diff --git a/modules/api/api.py b/modules/api/api.py
index d34ab422..555eefdb 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -280,7 +280,7 @@ class Api:
script_args[0] = selectable_idx + 1
# Now check for always on scripts
- if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ if request.alwayson_scripts:
for alwayson_script_name in request.alwayson_scripts.keys():
alwayson_script = self.get_script(alwayson_script_name, script_runner)
if alwayson_script is None:
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 53af6d70..1b5e5273 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -21,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
- if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
id_task = args[0]
progress.add_task_to_queue(id_task)
else:
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index aa2a14ef..b6a6dc0e 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
hypernet_prompt_text = f""
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
names = []
multipliers = []
for params in params_list:
- assert len(params.items) > 0
+ assert params.items
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 071bd9ea..237401a1 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -55,7 +55,7 @@ def image_from_url_text(filedata):
if filedata is None:
return None
- if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
@@ -437,7 +437,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
- return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
@@ -454,5 +454,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
outputs=[],
show_progress=False,
)
-
-
diff --git a/modules/images.py b/modules/images.py
index a12d252b..7bbfc3e0 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -406,7 +406,7 @@ class FilenameGenerator:
prompt_no_style = self.prompt
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
- if len(style) > 0:
+ if style:
for part in style.split("{prompt}"):
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
@@ -415,7 +415,7 @@ class FilenameGenerator:
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
def prompt_words(self):
- words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ words = [x for x in re_nonletters.split(self.prompt or "") if x]
if len(words) == 0:
words = ["empty"]
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
@@ -423,7 +423,7 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ time_format = args[0] if (args and args[0] != "") else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError:
diff --git a/modules/img2img.py b/modules/img2img.py
index 4c12c2c5..35c4facc 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -21,8 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
is_inpaint_batch = False
if inpaint_mask_dir:
inpaint_masks = shared.listfiles(inpaint_mask_dir)
- is_inpaint_batch = len(inpaint_masks) > 0
- if is_inpaint_batch:
+ is_inpaint_batch = bool(inpaint_masks)
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index 3fb76b65..b892d5fc 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -230,9 +230,9 @@ class DDPM(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
diff --git a/modules/processing.py b/modules/processing.py
index 362ab4c2..9ebdb549 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -975,7 +975,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
- assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index b4aff704..0069d8b0 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -336,11 +336,11 @@ def parse_prompt_attention(text):
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
- elif weight is not None and len(round_brackets) > 0:
+ elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
- elif text == ')' and len(round_brackets) > 0:
+ elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
- elif text == ']' and len(square_brackets) > 0:
+ elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f755283c..77ee55ee 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -287,14 +287,14 @@ def list_unets_callback():
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ filename = stack[0].filename if stack else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index cc6e8c21..3b5a7666 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens
- if len(chunk.tokens) > 0 or len(chunks) == 0:
+ if chunk.tokens or not chunks:
next_chunk(is_last=True)
return chunks, token_count
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
index a3476e95..c5c6270b 100644
--- a/modules/sd_hijack_clip_old.py
+++ b/modules/sd_hijack_clip_old.py
@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments
- if len(used_custom_terms) > 0:
+ if used_custom_terms:
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 8e667a4d..75705459 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -77,27 +77,27 @@ def focal_point(im, settings):
pois = []
weight_pref_total = 0
- if len(corner_points) > 0:
+ if corner_points:
weight_pref_total += settings.corner_points_weight
- if len(entropy_points) > 0:
+ if entropy_points:
weight_pref_total += settings.entropy_points_weight
- if len(face_points) > 0:
+ if face_points:
weight_pref_total += settings.face_points_weight
corner_centroid = None
- if len(corner_points) > 0:
+ if corner_points:
corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
- if len(entropy_points) > 0:
+ if entropy_points:
entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid)
face_centroid = None
- if len(face_points) > 0:
+ if face_points:
face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
@@ -187,7 +187,7 @@ def image_face_points(im, settings):
except Exception:
continue
- if len(faces) > 0:
+ if faces:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index b9621fc9..7ee05061 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -32,7 +32,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
self.placeholder_token = placeholder_token
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index a009d8e8..0d4c3f84 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption += shared.interrogator.generate_caption(image)
if params.process_caption_deepbooru:
- if len(caption) > 0:
+ if caption:
caption += ", "
caption += deepbooru.model.tag_multi(image)
@@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption = caption.strip()
- if len(caption) > 0:
+ if caption:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8da050ca..bb6f211c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -251,7 +251,7 @@ class EmbeddingDatabase:
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
- if len(self.skipped_embeddings) > 0:
+ if self.skipped_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
diff --git a/modules/ui.py b/modules/ui.py
index b7459f08..9a025cca 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -398,7 +398,7 @@ def create_override_settings_dropdown(tabname, row):
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
dropdown.change(
- fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
+ fn=lambda x: gr.Dropdown.update(visible=bool(x)),
inputs=[dropdown],
outputs=[dropdown],
)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 3140ed64..65173e06 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -333,7 +333,8 @@ def install_extension_from_url(dirname, url, branch_name=None):
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
normalized_url = normalize_git_url(url)
- assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+ if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
+ raise Exception(f'Extension with this URL is already installed: {url}')
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
@@ -449,7 +450,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
existing = installed_extension_urls.get(normalize_git_url(url), None)
extension_tags = extension_tags + ["installed"] if existing else extension_tags
- if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ if any(x for x in extension_tags if x in tags_to_hide):
hidden += 1
continue
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 7874298e..2688d8c2 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -81,7 +81,7 @@ class UiSettings:
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
def run_settings_single(self, value, key):
if not opts.same_type(value, opts.data_labels[key].default):
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 83a2f220..50320d55 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -121,8 +121,7 @@ class Script(scripts.Script):
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
- lines = [x.strip() for x in prompt_txt.splitlines()]
- lines = [x for x in lines if len(x) > 0]
+ lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
--
cgit v1.2.3
From 4faaf3e723ddaec022af85d6e66c3b1bac449584 Mon Sep 17 00:00:00 2001
From: ramyma
Date: Sun, 4 Jun 2023 16:59:23 +0300
Subject: Add endpoint to get latent_upscale_modes for hires fix
---
modules/api/api.py | 9 +++++++++
modules/api/models.py | 3 +++
2 files changed, 12 insertions(+)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index d34ab422..61a47005 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -189,6 +189,7 @@ class Api:
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
@@ -540,6 +541,14 @@ class Api:
for upscaler in shared.sd_upscalers
]
+ def get_latent_upscale_modes(self):
+ return [
+ {
+ "name": upscale_mode,
+ }
+ for upscale_mode in [*(shared.latent_upscale_modes or {})]
+ ]
+
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
diff --git a/modules/api/models.py b/modules/api/models.py
index 47fdede2..b3a745f0 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -241,6 +241,9 @@ class UpscalerItem(BaseModel):
model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
+class LatentUpscalerModeItem(BaseModel):
+ name: str = Field(title="Name")
+
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
--
cgit v1.2.3
From ba70a220e3176153ba2a559acb9e5aa692dce7ca Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Mon, 5 Jun 2023 22:20:29 +0300
Subject: Remove a bunch of unused/vestigial code
As found by Vulture and some eyes
---
modules/api/api.py | 7 -------
modules/api/models.py | 4 ----
modules/codeformer_model.py | 4 ----
modules/devices.py | 7 -------
modules/generation_parameters_copypaste.py | 29 -----------------------------
modules/hypernetworks/hypernetwork.py | 24 ------------------------
modules/paths.py | 14 --------------
7 files changed, 89 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..41cd7eca 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -32,13 +32,6 @@ import piexif
import piexif.helper
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
-
-
def script_name_to_index(name, scripts):
try:
return [script.title().lower() for script in scripts].index(name.lower())
diff --git a/modules/api/models.py b/modules/api/models.py
index b3a745f0..b5683071 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 4260b016..a01fe63d 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -15,7 +15,6 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
@@ -125,9 +124,6 @@ def setup_model(dirname):
return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
diff --git a/modules/devices.py b/modules/devices.py
index 1ed6ffdc..620ed1a6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,13 +15,6 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
def get_cuda_device_string():
from modules import shared
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 1d02ffae..699b1a81 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -329,10 +304,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-settings_map = {}
-
-
-
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5d12b449..51941c11 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -446,18 +435,6 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/paths.py b/modules/paths.py
index 5171df4f..bada804e 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
--
cgit v1.2.3
From 8ca34ad6d8cc2502403b3b96bb811366bc13c076 Mon Sep 17 00:00:00 2001
From: Su Wei
Date: Fri, 9 Jun 2023 13:14:20 +0800
Subject: add model exists status check to modeuls/api/api.py ,
/sdapi/v1/options [POST]
---
modules/api/api.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..56b7858d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights,checkpoint_alisases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -515,6 +515,11 @@ class Api:
def set_config(self, req: Dict[str, Any]):
for k, v in req.items():
+ if k == "sd_model_checkpoint":
+ checkpoint_info = checkpoint_alisases.get(v, None)
+ if checkpoint_info is None:
+ print(f"model [{v}] not founded, skip config saving process")
+ return
shared.opts.set(k, v)
shared.opts.save(shared.config_filename)
--
cgit v1.2.3
From 9142be0a0d8cea37cf1ae86c17fc7dcb161d9a42 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 10 Jun 2023 23:36:34 +0900
Subject: quit restart
---
modules/api/api.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..317c809f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -208,6 +208,8 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -715,3 +717,10 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+ def quit_webui(self):
+ restart.stop_program()
+
+ def restart_webui(self):
+ if restart.is_restartable():
+ restart.restart_program()
--
cgit v1.2.3
From 7e2d39a2d158d1426321686b05d31a3ea694a99e Mon Sep 17 00:00:00 2001
From: Su Wei
Date: Mon, 12 Jun 2023 15:22:49 +0800
Subject: update model checkpoint switch code
---
modules/api/api.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 56b7858d..7d7dfe9a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -514,12 +514,11 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
+ checkpoint_key="sd_model_checkpoint"
+ if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases:
+ raise RuntimeError(f"model {v!r} not found")
+
for k, v in req.items():
- if k == "sd_model_checkpoint":
- checkpoint_info = checkpoint_alisases.get(v, None)
- if checkpoint_info is None:
- print(f"model [{v}] not founded, skip config saving process")
- return
shared.opts.set(k, v)
shared.opts.save(shared.config_filename)
--
cgit v1.2.3
From b9664ab6154818680ee25920e229b808a3cdec68 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 12 Jun 2023 18:15:27 +0900
Subject: move _stop route to api
---
modules/api/api.py | 11 +++++++++--
webui.py | 7 -------
2 files changed, 9 insertions(+), 9 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 317c809f..cb1cde78 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -208,8 +208,11 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
- self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
- self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
+
+ if shared.cmd_opts.add_stop_route:
+ self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
+ self.add_api_route("/_stop", self.stop_route, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -724,3 +727,7 @@ class Api:
def restart_webui(self):
if restart.is_restartable():
restart.restart_program()
+
+ def stop_route(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
diff --git a/webui.py b/webui.py
index 136d036d..ae6dc9fb 100644
--- a/webui.py
+++ b/webui.py
@@ -362,11 +362,6 @@ def api_only():
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-def stop_route(request):
- shared.state.server_command = "stop"
- return Response("Stopping.")
-
-
def webui():
launch_api = cmd_opts.api
initialize()
@@ -404,8 +399,6 @@ def webui():
"redoc_url": "/redoc",
},
)
- if cmd_opts.add_stop_route:
- app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
--
cgit v1.2.3
From 5be6c026f55760039b3ebb284cc2ce85586be4ac Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 14 Jun 2023 18:51:47 +0900
Subject: rename routes
---
modules/api/api.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index cb1cde78..5ea1d21c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -210,9 +210,9 @@ class Api:
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
if shared.cmd_opts.add_stop_route:
- self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
- self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
- self.add_api_route("/_stop", self.stop_route, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -721,13 +721,13 @@ class Api:
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
- def quit_webui(self):
+ def kill_webui(self):
restart.stop_program()
def restart_webui(self):
if restart.is_restartable():
restart.restart_program()
- def stop_route(request):
+ def terminate_webui(request):
shared.state.server_command = "stop"
return Response("Stopping.")
--
cgit v1.2.3
From 49fb2a337661d1b9a80de8ff35a640083fa98d2f Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 14 Jun 2023 19:52:12 +0900
Subject: response 501 if not a able to restart
---
modules/api/api.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5ea1d21c..4dc48a03 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -727,6 +727,7 @@ class Api:
def restart_webui(self):
if restart.is_restartable():
restart.restart_program()
+ return Response(status_code=501)
def terminate_webui(request):
shared.state.server_command = "stop"
--
cgit v1.2.3
From 6091c4e4aa32b674f8ec755e6bd58989f09b08c5 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 14 Jun 2023 19:53:08 +0900
Subject: terminate -> stop
---
.github/workflows/run_tests.yaml | 2 +-
modules/api/api.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 394811ac..96546011 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -50,7 +50,7 @@ jobs:
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
- run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-terminate && sleep 10
+ run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage
run: |
python -m coverage combine .coverage*
diff --git a/modules/api/api.py b/modules/api/api.py
index 4dc48a03..80d45ca7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -212,7 +212,7 @@ class Api:
if shared.cmd_opts.add_stop_route:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
- self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -729,6 +729,6 @@ class Api:
restart.restart_program()
return Response(status_code=501)
- def terminate_webui(request):
+ def stop_webui(request):
shared.state.server_command = "stop"
return Response("Stopping.")
--
cgit v1.2.3
From d06af4e517865277d0521642c2c5513af9afd76f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 27 Jun 2023 09:26:18 +0300
Subject: fix and rework #11113
---
modules/api/api.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index f96056b6..1d4aeff5 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -522,9 +522,9 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
- checkpoint_key="sd_model_checkpoint"
- if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases:
- raise RuntimeError(f"model {v!r} not found")
+ checkpoint_name = req.get("sd_model_checkpoint", None)
+ if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
shared.opts.set(k, v)
--
cgit v1.2.3
From cc9c1719786de00d4a5bfcf83be4bf2808cf0cb5 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Thu, 29 Jun 2023 14:21:28 +0900
Subject: rename --add-stop-route to --api-server-stop
---
.github/workflows/run_tests.yaml | 2 +-
modules/api/api.py | 2 +-
modules/cmd_args.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 96546011..178c026a 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -42,7 +42,7 @@ jobs:
--no-half
--disable-opt-split-attention
--use-cpu all
- --add-stop-route
+ --api-server-stop
2>&1 | tee output.txt &
- name: Run tests
run: |
diff --git a/modules/api/api.py b/modules/api/api.py
index 279c384a..adc633db 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -202,7 +202,7 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
- if shared.cmd_opts.add_stop_route:
+ if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 624dcb4f..278a605e 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
--
cgit v1.2.3
From 74d001bc68c2106aa963e3474eee70327b8f3760 Mon Sep 17 00:00:00 2001
From: ramyma
Date: Sun, 2 Jul 2023 04:59:59 +0300
Subject: Hotfix: call processing close to cleanup API generation calls
---
modules/api/api.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 279c384a..f10e3fe3 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -335,6 +335,7 @@ class Api:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
+ p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -392,6 +393,7 @@ class Api:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
+ p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
--
cgit v1.2.3
From f44feb6a10aacc6a5ff4c9275fba2546b2858935 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 30 Jun 2023 13:11:31 +0300
Subject: Add job argument to State.begin()
---
modules/api/api.py | 14 +++++++-------
modules/call_queue.py | 2 +-
modules/extras.py | 3 +--
modules/interrogate.py | 3 +--
modules/postprocessing.py | 3 +--
modules/shared.py | 4 ++--
6 files changed, 13 insertions(+), 16 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 279c384a..3ea099ad 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -327,7 +327,7 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_txt2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
@@ -384,7 +384,7 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_img2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
@@ -599,7 +599,7 @@ class Api:
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
@@ -610,7 +610,7 @@ class Api:
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
@@ -620,7 +620,7 @@ class Api:
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete')
@@ -636,7 +636,7 @@ class Api:
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -657,7 +657,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 69bf63d2..3b94f8a4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
diff --git a/modules/extras.py b/modules/extras.py
index 830b53aa..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9b2c5b60..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 736315e2..544b2f72 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
diff --git a/modules/shared.py b/modules/shared.py
index 203ee1b9..7df2879c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -173,7 +173,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,7 +187,7 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
def end(self):
--
cgit v1.2.3
From e4303443477b9ac3c90ec4dd58a4810f7ac1eabe Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 30 Jun 2023 13:11:49 +0300
Subject: API: use finally: for state.end()
---
modules/api/api.py | 29 ++++++++++++++---------------
1 file changed, 14 insertions(+), 15 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3ea099ad..8b79495d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -602,37 +602,35 @@ class Api:
shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}")
+ finally:
+ shared.state.end()
+
def create_hypernetwork(self, args: dict):
try:
shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}")
+ finally:
+ shared.state.end()
def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return models.PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
- shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except AssertionError as e:
- shared.state.end()
+ except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
- except FileNotFoundError as e:
+ finally:
shared.state.end()
- return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
@@ -649,11 +647,11 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
- shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
- shared.state.end()
+ except Exception as msg:
return models.TrainResponse(info=f"train embedding error: {msg}")
+ finally:
+ shared.state.end()
def train_hypernetwork(self, args: dict):
try:
@@ -675,9 +673,10 @@ class Api:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError:
+ except Exception as exc:
+ return models.TrainResponse(info=f"train embedding error: {exc}")
+ finally:
shared.state.end()
- return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
--
cgit v1.2.3
From 32788873176e9d79e1fffd6f89f94b6d0ec8bb91 Mon Sep 17 00:00:00 2001
From: ramyma
Date: Mon, 3 Jul 2023 20:02:30 +0300
Subject: Handle cleanup in case there's an exception thrown
---
modules/api/api.py | 57 +++++++++++++++++++++++++++++-------------------------
1 file changed, 31 insertions(+), 26 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index f10e3fe3..d9278e9e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -323,19 +323,21 @@ class Api:
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
- p.close()
+ try:
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin()
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
+ finally:
+ p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -380,20 +382,23 @@ class Api:
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
+ try:
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin()
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
- p.close()
+ finally:
+ p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
--
cgit v1.2.3
From c1c04928596f69ddb39b8841a8435ecefb0594e9 Mon Sep 17 00:00:00 2001
From: ramyma
Date: Mon, 3 Jul 2023 20:17:47 +0300
Subject: Use contextlib for closing the generation process
---
modules/api/api.py | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index d9278e9e..e92c2938 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -30,6 +30,7 @@ from modules import devices
from typing import Dict, List, Any
import piexif
import piexif.helper
+from contextlib import closing
def script_name_to_index(name, scripts):
@@ -322,8 +323,7 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- try:
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
@@ -336,8 +336,6 @@ class Api:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- finally:
- p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -381,8 +379,7 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- try:
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
@@ -397,8 +394,6 @@ class Api:
processed = process_images(p)
shared.state.end()
- finally:
- p.close()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
--
cgit v1.2.3
From 259967b7c60cbd2aeb091e691b5f49d9fb64b872 Mon Sep 17 00:00:00 2001
From: jovijovi
Date: Thu, 6 Jul 2023 18:43:17 +0800
Subject: fix(api): convert to "RGB" if image mode is "RGBA"
---
modules/api/api.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules/api/api.py')
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..6507f641 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -84,6 +84,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
--
cgit v1.2.3