diff options
author | Dynamic <bradje@naver.com> | 2022-10-23 13:36:56 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-23 13:36:56 +0000 |
commit | 660ae690bd7107b78aac6413e1370f8cd72676bc (patch) | |
tree | b666cfd0872687ccd293a41d9d0a90fcdfe1ea0a /modules/api | |
parent | 21364c5c39b269497944b56dd6664792d779333b (diff) | |
parent | 6bd6154a92eb05c80d66df661a38f8b70cc13729 (diff) | |
download | stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.tar.gz stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.tar.bz2 stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.zip |
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules/api')
-rw-r--r-- | modules/api/api.py | 124 | ||||
-rw-r--r-- | modules/api/processing.py | 106 |
2 files changed, 230 insertions, 0 deletions
diff --git a/modules/api/api.py b/modules/api/api.py new file mode 100644 index 00000000..3caa83a4 --- /dev/null +++ b/modules/api/api.py @@ -0,0 +1,124 @@ +from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.sd_samplers import all_samplers +from modules.extras import run_pnginfo +import modules.shared as shared +import uvicorn +from fastapi import Body, APIRouter, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, Json +import json +import io +import base64 +from PIL import Image + +sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + +class Api: + def __init__(self, app, queue_lock): + self.router = APIRouter() + self.app = app + self.queue_lock = queue_lock + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + + def __base64_to_image(self, base64_string): + # if has a comma, deal with prefix + if "," in base64_string: + base64_string = base64_string.split(",")[1] + imgdata = base64.b64decode(base64_string) + # convert base64 to PIL image + return Image.open(io.BytesIO(imgdata)) + + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + sampler_index = sampler_to_index(txt2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + + populate = txt2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": sampler_index[0], + "do_not_save_samples": True, + "do_not_save_grid": True + } + ) + p = StableDiffusionProcessingTxt2Img(**vars(populate)) + # Override object param + with self.queue_lock: + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + + + + def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): + sampler_index = sampler_to_index(img2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + + + init_images = img2imgreq.init_images + if init_images is None: + raise HTTPException(status_code=404, detail="Init image not found") + + mask = img2imgreq.mask + if mask: + mask = self.__base64_to_image(mask) + + + populate = img2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": sampler_index[0], + "do_not_save_samples": True, + "do_not_save_grid": True, + "mask": mask + } + ) + p = StableDiffusionProcessingImg2Img(**vars(populate)) + + imgs = [] + for img in init_images: + img = self.__base64_to_image(img) + imgs = [img] * p.batch_size + + p.init_images = imgs + # Override object param + with self.queue_lock: + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) + + def extrasapi(self): + raise NotImplementedError + + def pnginfoapi(self): + raise NotImplementedError + + def launch(self, server_name, port): + self.app.include_router(self.router) + uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/processing.py b/modules/api/processing.py new file mode 100644 index 00000000..f551fa35 --- /dev/null +++ b/modules/api/processing.py @@ -0,0 +1,106 @@ +from array import array +from inflection import underscore +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field, create_model +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img +import inspect + + +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class PydanticModelGenerator: + """ + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ + """ + + def __init__( + self, + model_name: str = None, + class_instance = None, + additional_fields = None, + ): + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation + + return Optional[field_type] + + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + + self._model_name = model_name + self._class_data = merge_class_params(class_instance) + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v), + field_value=v.default + ) + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED + ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True + return DynamicModel + +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] +).generate_model()
\ No newline at end of file |