diff options
author | Maiko Tan <maiko.tan.coding@gmail.com> | 2022-11-19 12:13:07 +0000 |
---|---|---|
committer | Maiko Tan <maiko.tan.coding@gmail.com> | 2022-11-19 12:13:07 +0000 |
commit | 336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f (patch) | |
tree | 6760fd7f15a9049365a1ee2c56de37dd456504bd /modules/api/api.py | |
parent | 8f2ff861d31972d12de278075ea9c0c0deef99de (diff) | |
parent | 84a6f211d407cd748c603edc3a81862488505c24 (diff) | |
download | stable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.tar.gz stable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.tar.bz2 stable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.zip |
Merge branch 'master' into api-authorization
Diffstat (limited to 'modules/api/api.py')
-rw-r--r-- | modules/api/api.py | 42 |
1 files changed, 17 insertions, 25 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 6bb01603..195e8b58 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -28,8 +28,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -77,6 +81,7 @@ class Api: self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) @@ -103,14 +108,9 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -130,12 +130,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -144,10 +138,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -266,6 +259,9 @@ class Api: return {} + def skip(self): + shared.state.skip() + def get_config(self): options = {} for key in shared.opts.data.keys(): @@ -277,14 +273,10 @@ class Api: return options - def set_config(self, req: OptionsModel): - # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will - # overwrite all options with default values. - raise RuntimeError('Setting options via API is not supported') - - reqDict = vars(req) - for o in reqDict: - setattr(shared.opts, o, reqDict[o]) + def set_config(self, req: Dict[str, Any]): + + for o in req: + setattr(shared.opts, o, req[o]) shared.opts.save(shared.config_filename) return @@ -293,7 +285,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] |