aboutsummaryrefslogtreecommitdiffstats
path: root/modules/api/api.py
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-11-07 15:43:38 +0000
committerMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-11-07 15:43:38 +0000
commitcabd4e3b3bf91e0cb5071398a8efddef495f6311 (patch)
tree55daa888a7e03e2e204daf6729835b94277350a2 /modules/api/api.py
parentbb832d7725187f8a8ab44faa6ee1b38cb5f600aa (diff)
parent804d9fb83d0c63ca3acd36378707ce47b8f12599 (diff)
downloadstable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.tar.gz
stable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.tar.bz2
stable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.zip
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py53
1 files changed, 42 insertions, 11 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index a49f3755..688469ad 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -10,6 +10,7 @@ from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
+from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
@@ -34,9 +35,21 @@ def setUpscalers(req: dict):
def encode_pil_to_base64(image):
- buffer = io.BytesIO()
- image.save(buffer, format="png")
- return base64.b64encode(buffer.getvalue())
+ with io.BytesIO() as output_bytes:
+
+ # Copy any text-only metadata
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ image.save(
+ output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
+ )
+ bytes_data = output_bytes.getvalue()
+ return base64.b64encode(bytes_data)
class Api:
@@ -50,6 +63,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
@@ -201,11 +215,24 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ img = self.__base64_to_image(image_b64)
+
+ # Override object param
+ with self.queue_lock:
+ processed = shared.interrogator.interrogate(img)
+
+ return InterrogateResponse(caption=processed)
+
def interruptapi(self):
shared.state.interrupt()
return {}
-
+
def get_config(self):
options = {}
for key in shared.opts.data.keys():
@@ -214,10 +241,14 @@ class Api:
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
-
+
return options
-
+
def set_config(self, req: OptionsModel):
+ # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
+ # overwrite all options with default values.
+ raise RuntimeError('Setting options via API is not supported')
+
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
@@ -233,13 +264,13 @@ class Api:
def get_upscalers(self):
upscalers = []
-
+
for upscaler in shared.sd_upscalers:
u = upscaler.scaler
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
+
return upscalers
-
+
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
@@ -251,11 +282,11 @@ class Api:
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
-
+
def get_promp_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
- style = shared.prompt_styles.styles[k]
+ style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
return styleList