aboutsummaryrefslogtreecommitdiffstats
path: root/modules/api/api.py
diff options
context:
space:
mode:
authorMaiko Tan <maiko.tan.coding@gmail.com>2022-11-19 12:13:07 +0000
committerMaiko Tan <maiko.tan.coding@gmail.com>2022-11-19 12:13:07 +0000
commit336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f (patch)
tree6760fd7f15a9049365a1ee2c56de37dd456504bd /modules/api/api.py
parent8f2ff861d31972d12de278075ea9c0c0deef99de (diff)
parent84a6f211d407cd748c603edc3a81862488505c24 (diff)
downloadstable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.tar.gz
stable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.tar.bz2
stable-diffusion-webui-gfx803-336c341a7c3fe81cdf0fc45616ed0c16c79a2c6f.zip
Merge branch 'master' into api-authorization
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py42
1 files changed, 17 insertions, 25 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 6bb01603..195e8b58 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
+from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+def validate_sampler_name(name):
+ config = sd_samplers.all_samplers_map.get(name, None)
+ if config is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+ return name
def setUpscalers(req: dict):
reqDict = vars(req)
@@ -77,6 +81,7 @@ class Api:
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@@ -103,14 +108,9 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
- if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_index": sampler_index[0],
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
@@ -130,12 +130,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
- sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
- if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
-
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -144,10 +138,9 @@ class Api:
if mask:
mask = decode_base64_to_image(mask)
-
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_index": sampler_index[0],
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
@@ -266,6 +259,9 @@ class Api:
return {}
+ def skip(self):
+ shared.state.skip()
+
def get_config(self):
options = {}
for key in shared.opts.data.keys():
@@ -277,14 +273,10 @@ class Api:
return options
- def set_config(self, req: OptionsModel):
- # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
- # overwrite all options with default values.
- raise RuntimeError('Setting options via API is not supported')
-
- reqDict = vars(req)
- for o in reqDict:
- setattr(shared.opts, o, reqDict[o])
+ def set_config(self, req: Dict[str, Any]):
+
+ for o in req:
+ setattr(shared.opts, o, req[o])
shared.opts.save(shared.config_filename)
return
@@ -293,7 +285,7 @@ class Api:
return vars(shared.cmd_opts)
def get_samplers(self):
- return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
+ return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
upscalers = []