diff options
Diffstat (limited to 'modules/api/api.py')
-rw-r--r-- | modules/api/api.py | 222 |
1 files changed, 178 insertions, 44 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index eb7b1da5..9ffcbd5f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -3,11 +3,14 @@ import io import time import datetime import uvicorn +import gradio as gr from threading import Lock from io import BytesIO -from gradio.processing_utils import decode_base64_to_file -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response +from fastapi import APIRouter, Depends, FastAPI, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.exceptions import HTTPException +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared @@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -90,6 +93,16 @@ def encode_pil_to_base64(image): return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): + rich_available = True + try: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + except: + import traceback + rich_available = False + @app.middleware("http") async def log_and_time(req: Request, call_next): ts = time.time() @@ -110,6 +123,36 @@ def api_middleware(app: FastAPI): )) return res + def handle_exception(request: Request, e: Exception): + err = { + "error": type(e).__name__, + "detail": vars(e).get('detail', ''), + "body": vars(e).get('body', ''), + "errors": str(e), + } + print(f"API error: {request.method}: {request.url} {err}") + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if rich_available: + console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) + else: + traceback.print_exc() + return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) + + @app.middleware("http") + async def exception_handling(request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + return handle_exception(request, e) + + @app.exception_handler(Exception) + async def fastapi_exception_handler(request: Request, e: Exception): + return handle_exception(request, e) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, e: HTTPException): + return handle_exception(request, e) + class Api: def __init__(self, app: FastAPI, queue_lock: Lock): @@ -150,6 +193,12 @@ class Api: self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) + self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) + + self.default_script_arg_txt2img = [] + self.default_script_arg_img2img = [] def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -163,47 +212,113 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def get_script(self, script_name, script_runner): - if script_name is None: + def get_selectable_script(self, script_name, script_runner): + if script_name is None or script_name == "": return None, None - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx + + def get_scripts_list(self): + t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] + i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] + + return ScriptsList(txt2img = t2ilist, img2img = i2ilist) + + def get_script(self, script_name, script_runner): + if script_name is None or script_name == "": + return None, None + + script_idx = script_name_to_index(script_name, script_runner.scripts) + return script_runner.scripts[script_idx] + + def init_default_script_args(self, script_runner): + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere except position 0 to initialize script args + script_args = [None]*last_arg_index + script_args[0] = 0 + + # get default values + with gr.Blocks(): # will throw errors calling ui function without this + for script in script_runner.scripts: + if script.ui(script.is_img2img): + ui_default_values = [] + for elem in script.ui(script.is_img2img): + ui_default_values.append(elem.value) + script_args[script.args_from:script.args_to] = ui_default_values + return script_args + + def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): + script_args = default_script_args.copy() + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() + if selectable_scripts: + script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args + script_args[0] = selectable_idx + 1 + + # Now check for always on scripts + if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + for alwayson_script_name in request.alwayson_scripts.keys(): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + # always on script with no arg should always run so you don't really need to add them to the requests + if "args" in request.alwayson_scripts[alwayson_script_name]: + # min between arg length in scriptrunner and arg length in the request + for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))): + script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] + return script_args def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) + script_runner = scripts.scripts_txt2img + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(script_runner) + selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) - populate = txt2imgreq.copy(update={ # Override __init__ params + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True - } - ) + "do_not_save_samples": not txt2imgreq.save_images, + "do_not_save_grid": not txt2imgreq.save_images, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them + args.pop('alwayson_scripts', None) + + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples shared.state.begin() - if script is not None: - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) + if selectable_scripts != None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -212,41 +327,55 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) - mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) + selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True, - "mask": mask - } - ) + "do_not_save_samples": not img2imgreq.save_images, + "do_not_save_grid": not img2imgreq.save_images, + "mask": mask, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them + args.pop('alwayson_scripts', None) + + script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) + + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples shared.state.begin() - if script is not None: - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) + if selectable_scripts != None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: img2imgreq.init_images = None @@ -267,16 +396,11 @@ class Api: def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) - def prepareFiles(file): - file = decode_base64_to_file(file.data, file_path=file.name) - file.orig_name = file.name - return file - - reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList'])) - reqDict.pop('imageList') + image_list = reqDict.pop('imageList', []) + image_folder = [decode_base64_to_image(x.data) for x in image_list] with self.queue_lock: - result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) @@ -348,6 +472,16 @@ class Api: return {} + def unloadapi(self): + unload_model_weights() + + return {} + + def reloadapi(self): + reload_model_weights() + + return {} + def skip(self): shared.state.skip() @@ -498,7 +632,7 @@ class Api: if not apply_optimizations: sd_hijack.undo_optimizations() try: - hypernetwork, filename = train_hypernetwork(*args) + hypernetwork, filename = train_hypernetwork(**args) except Exception as e: error = e finally: |