diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-08 12:10:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-08 12:10:10 +0000 |
commit | ec9bbda3da846b4aa2f03b1e0f0952ffbc61f4f6 (patch) | |
tree | 5ae676b6ac7242ad082ab7421cf8ae9d50d52818 /modules/api/api.py | |
parent | c4c63dd5e4760c56405cef2e71abc5c3604c4578 (diff) | |
parent | 18256c5f0174126cb103afece2b39b6b831e034a (diff) | |
download | stable-diffusion-webui-gfx803-ec9bbda3da846b4aa2f03b1e0f0952ffbc61f4f6.tar.gz stable-diffusion-webui-gfx803-ec9bbda3da846b4aa2f03b1e0f0952ffbc61f4f6.tar.bz2 stable-diffusion-webui-gfx803-ec9bbda3da846b4aa2f03b1e0f0952ffbc61f4f6.zip |
Merge branch 'dev' into img2img-batch-png-info
Diffstat (limited to 'modules/api/api.py')
-rw-r--r-- | modules/api/api.py | 152 |
1 files changed, 91 insertions, 61 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..224bbfc6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -22,20 +22,15 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases +from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import Dict, List, Any import piexif import piexif.helper - - -def upscaler_to_index(name: str): - try: - return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e +from contextlib import closing def script_name_to_index(name, scripts): @@ -83,6 +78,8 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + if image.mode == "RGBA": + image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } @@ -108,7 +105,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +135,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + errors.report(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") @@ -188,7 +185,9 @@ class Api: self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) @@ -206,6 +205,11 @@ class Api: self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + if shared.cmd_opts.api_server_stop: + self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) + self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -278,7 +282,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: @@ -321,19 +325,19 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - p.scripts = script_runner - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() + with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + + shared.state.begin(job="scripts_txt2img") + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -377,20 +381,20 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - p.init_images = [decode_base64_to_image(x) for x in init_images] - p.scripts = script_runner - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() + with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: + p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + + shared.state.begin(job="scripts_img2img") + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -514,6 +518,10 @@ class Api: return options def set_config(self, req: Dict[str, Any]): + checkpoint_name = req.get("sd_model_checkpoint", None) + if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + raise RuntimeError(f"model {checkpoint_name!r} not found") + for k, v in req.items(): shared.opts.set(k, v) @@ -538,9 +546,20 @@ class Api: for upscaler in shared.sd_upscalers ] + def get_latent_upscale_modes(self): + return [ + { + "name": upscale_mode, + } + for upscale_mode in [*(shared.latent_upscale_modes or {})] + ] + def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + def get_sd_vaes(self): + return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -583,44 +602,42 @@ class Api: def create_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used - shared.state.end() return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create embedding error: {e}") + finally: + shared.state.end() + def create_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding - shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create hypernetwork error: {e}") + finally: + shared.state.end() def preprocess(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return models.PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info='preprocess complete') except KeyError as e: - shared.state.end() return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except AssertionError as e: - shared.state.end() + except Exception as e: return models.PreprocessResponse(info=f"preprocess error: {e}") - except FileNotFoundError as e: + finally: shared.state.end() - return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_embedding") apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -633,15 +650,15 @@ class Api: finally: if not apply_optimizations: sd_hijack.apply_optimizations() - shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError as msg: - shared.state.end() + except Exception as msg: return models.TrainResponse(info=f"train embedding error: {msg}") + finally: + shared.state.end() def train_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_hypernetwork") shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None @@ -659,9 +676,10 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError: + except Exception as exc: + return models.TrainResponse(info=f"train embedding error: {exc}") + finally: shared.state.end() - return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: @@ -700,4 +718,16 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) + + def kill_webui(self): + restart.stop_program() + + def restart_webui(self): + if restart.is_restartable(): + restart.restart_program() + return Response(status_code=501) + + def stop_webui(request): + shared.state.server_command = "stop" + return Response("Stopping.") |