aboutsummaryrefslogtreecommitdiffstats
path: root/modules/api/api.py
diff options
context:
space:
mode:
authorevshiron <evshiron@gmail.com>2022-10-29 19:45:29 +0000
committerevshiron <evshiron@gmail.com>2022-10-29 19:45:29 +0000
commit6b719c49b193a7dfeb64aacfdc8437037e07a2d1 (patch)
tree2ec883e44f3f959688ba65a4c5f194470bc7d467 /modules/api/api.py
parentfddb4883f4a408b3464076465e1b0949ebe0fc30 (diff)
parent35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff)
downloadstable-diffusion-webui-gfx803-6b719c49b193a7dfeb64aacfdc8437037e07a2d1.tar.gz
stable-diffusion-webui-gfx803-6b719c49b193a7dfeb64aacfdc8437037e07a2d1.tar.bz2
stable-diffusion-webui-gfx803-6b719c49b193a7dfeb64aacfdc8437037e07a2d1.zip
Merge branch 'master' into feat/progress-api
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py150
1 files changed, 89 insertions, 61 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index c038f674..9d68ac23 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,37 +1,43 @@
-import time
+# import time
+
+# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+# from modules.sd_samplers import all_samplers
+# from modules.extras import run_pnginfo
+# import modules.shared as shared
+# from modules import devices
+# import uvicorn
+# from fastapi import Body, APIRouter, HTTPException
+# from fastapi.responses import JSONResponse
+# from pydantic import BaseModel, Field, Json
+# from typing import List
+# import json
+# import io
+# import base64
+# from PIL import Image
+
+# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
+# class TextToImageResponse(BaseModel):
+# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+# parameters: Json
+# info: Json
+
+# class ImageToImageResponse(BaseModel):
+# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+# parameters: Json
+# info: Json
-from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
-from modules.extras import run_pnginfo
+import time
+import uvicorn
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from fastapi import APIRouter, HTTPException
import modules.shared as shared
from modules import devices
-import uvicorn
-from fastapi import Body, APIRouter, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field, Json
-from typing import List
-import json
-import io
-import base64
-from PIL import Image
-
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-
-class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
-
-class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
-
-class ProgressResponse(BaseModel):
- progress: float
- eta_relative: float
- state: Json
+from modules.api.models import *
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.sd_samplers import all_samplers
+from modules.extras import run_extras
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
@@ -53,30 +59,39 @@ def before_gpu_call():
shared.state.textinfo = None
shared.state.time_start = time.time()
-
def after_gpu_call():
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
+def upscaler_to_index(name: str):
+ try:
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+ except:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
+sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
+def setUpscalers(req: dict):
+ reqDict = vars(req)
+ reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
+ reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+ return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
- self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
- self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
- def __base64_to_image(self, base64_string):
- # if has a comma, deal with prefix
- if "," in base64_string:
- base64_string = base64_string.split(",")[1]
- imgdata = base64.b64decode(base64_string)
- # convert base64 to PIL image
- return Image.open(io.BytesIO(imgdata))
-
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -97,15 +112,9 @@ class Api:
processed = process_images(p)
after_gpu_call()
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
-
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
-
+ b64images = list(map(encode_pil_to_base64, processed.images))
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -120,7 +129,7 @@ class Api:
mask = img2imgreq.mask
if mask:
- mask = self.__base64_to_image(mask)
+ mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -135,7 +144,7 @@ class Api:
imgs = []
for img in init_images:
- img = self.__base64_to_image(img)
+ img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
@@ -145,17 +154,39 @@ class Api:
processed = process_images(p)
after_gpu_call()
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
+ b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
- return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ reqDict = setUpscalers(req)
+
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
+
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ reqDict = setUpscalers(req)
+
+ def prepareFiles(file):
+ file = decode_base64_to_file(file.data, file_path=file.name)
+ file.orig_name = file.name
+ return file
+
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
+ reqDict.pop('imageList')
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
+
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def progressapi(self):
# copy from check_progress_call of ui.py
@@ -179,9 +210,6 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
- def extrasapi(self):
- raise NotImplementedError
-
def pnginfoapi(self):
raise NotImplementedError