aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--launch.py2
-rw-r--r--modules/api/api.py20
-rw-r--r--modules/api/models.py13
3 files changed, 33 insertions, 2 deletions
diff --git a/launch.py b/launch.py
index 958336f2..fe9cef3c 100644
--- a/launch.py
+++ b/launch.py
@@ -220,7 +220,7 @@ def tests(argv):
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
import webui
- webui.webui()
+ webui.webui_or_api()
if __name__ == "__main__":
diff --git a/modules/api/api.py b/modules/api/api.py
index 6c06d449..c510a833 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -40,6 +40,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -116,6 +117,8 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+ def extrasapi(self):
+ raise NotImplementedError
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
@@ -176,6 +179,23 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ populate = interrogatereq.copy(update={ # Override __init__ params
+ }
+ )
+
+ img = self.__base64_to_image(image_b64)
+
+ # Override object param
+ with self.queue_lock:
+ processed = shared.interrogator.interrogate(img)
+
+ return InterrogateResponse(caption=processed)
+
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/api/models.py b/modules/api/models.py
index 9ee42a17..035a7179 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -64,7 +64,12 @@ class PydanticModelGenerator:
self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
+
+ if class_instance is not None:
+ self._class_data = merge_class_params(class_instance)
+ else:
+ self._class_data = {}
+
self._model_def = [
ModelDef(
field=underscore(k),
@@ -165,3 +170,9 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+class InterrogateRequest(BaseModel):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class InterrogateResponse(BaseModel):
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")