diff options
Diffstat (limited to 'modules/api')
-rw-r--r-- | modules/api/api.py | 48 | ||||
-rw-r--r-- | modules/api/models.py | 4 |
2 files changed, 44 insertions, 8 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 1c121ff0..6c564ad8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras @@ -28,8 +28,13 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") +def script_name_to_index(name, scripts): + try: + return [script.title().lower() for script in scripts].index(name.lower()) + except: + raise HTTPException(status_code=422, detail=f"Script '{name}' not found") def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) @@ -144,7 +149,21 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) + def get_script(self, script_name, script_runner): + if script_name is None: + return None, None + + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + + script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) + script = script_runner.selectable_scripts[script_idx] + return script, script_idx + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, @@ -154,14 +173,22 @@ class Api: if populate.sampler_name: populate.sampler_index = None # prevent a warning later on + args = vars(populate) + args.pop('script_name', None) + with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - processed = process_images(p) + if script is not None: + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) + else: + processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -171,6 +198,8 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) + mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) @@ -187,13 +216,20 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + args.pop('script_name', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - processed = process_images(p) + if script is not None: + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) + else: + processed = process_images(p) shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index 49bf1e7a..880edde6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,13 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): |