diff options
-rw-r--r-- | modules/api/api.py | 115 | ||||
-rw-r--r-- | modules/api/models.py | 27 | ||||
-rw-r--r-- | modules/shared.py | 13 |
3 files changed, 129 insertions, 26 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 49c213ea..9d68ac23 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,12 +1,70 @@ +# import time + +# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +# from modules.sd_samplers import all_samplers +# from modules.extras import run_pnginfo +# import modules.shared as shared +# from modules import devices +# import uvicorn +# from fastapi import Body, APIRouter, HTTPException +# from fastapi.responses import JSONResponse +# from pydantic import BaseModel, Field, Json +# from typing import List +# import json +# import io +# import base64 +# from PIL import Image + +# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + +# class TextToImageResponse(BaseModel): +# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") +# parameters: Json +# info: Json + +# class ImageToImageResponse(BaseModel): +# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") +# parameters: Json +# info: Json + +import time import uvicorn from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, HTTPException import modules.shared as shared +from modules import devices from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_extras +# copy from wrap_gradio_gpu_call of webui.py +# because queue lock will be acquired in api handlers +# and time start needs to be set +# the function has been modified into two parts + +def before_gpu_call(): + devices.torch_gc() + + shared.state.sampling_step = 0 + shared.state.job_count = -1 + shared.state.job_no = 0 + shared.state.job_timestamp = shared.state.get_job_timestamp() + shared.state.current_latent = None + shared.state.current_image = None + shared.state.current_image_sampling_step = 0 + shared.state.skipped = False + shared.state.interrupted = False + shared.state.textinfo = None + shared.state.time_start = time.time() + +def after_gpu_call(): + shared.state.job = "" + shared.state.job_count = 0 + + devices.torch_gc() + def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) @@ -32,15 +90,16 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) - + if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - + raise HTTPException(status_code=404, detail="Sampler not found") + populate = txt2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, + "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, "do_not_save_grid": True @@ -48,34 +107,36 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param + before_gpu_call() with self.queue_lock: processed = process_images(p) - + after_gpu_call() + b64images = list(map(encode_pil_to_base64, processed.images)) - + return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) - + if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") + raise HTTPException(status_code=404, detail="Sampler not found") init_images = img2imgreq.init_images if init_images is None: - raise HTTPException(status_code=404, detail="Init image not found") + raise HTTPException(status_code=404, detail="Init image not found") mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) - + populate = img2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, + "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, - "do_not_save_grid": True, + "do_not_save_grid": True, "mask": mask } ) @@ -88,15 +149,17 @@ class Api: p.init_images = imgs # Override object param + before_gpu_call() with self.queue_lock: processed = process_images(p) - + after_gpu_call() + b64images = list(map(encode_pil_to_base64, processed.images)) if (not img2imgreq.include_init_images): img2imgreq.init_images = None img2imgreq.mask = None - + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) def extras_single_image_api(self, req: ExtrasSingleImageRequest): @@ -124,7 +187,29 @@ class Api: result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - + + def progressapi(self): + # copy from check_progress_call of ui.py + + if shared.state.job_count == 0: + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js()) + + # avoid dividing zero + progress = 0.01 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + + progress = min(progress, 1) + + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js()) + def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index dd122321..c374a627 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -51,17 +51,17 @@ class PydanticModelGenerator: # field_type = str if not overrides.get(k) else overrides[k]["type"] # print(k, v.annotation, v.default) field_type = v.annotation - + return Optional[field_type] - + def merge_class_params(class_): all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) parameters = {} for classes in all_classes: parameters = {**parameters, **inspect.signature(classes.__init__).parameters} return parameters - - + + self._model_name = model_name self._class_data = merge_class_params(class_instance) self._model_def = [ @@ -73,11 +73,11 @@ class PydanticModelGenerator: ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] - + for fields in additional_fields: self._model_def.append(ModelDef( - field=underscore(fields["key"]), - field_alias=fields["key"], + field=underscore(fields["key"]), + field_alias=fields["key"], field_type=fields["type"], field_value=fields["default"], field_exclude=fields["exclude"] if "exclude" in fields else False)) @@ -94,15 +94,15 @@ class PydanticModelGenerator: DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_mutation = True return DynamicModel - + StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingTxt2Img", + "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingImg2Img", + "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] ).generate_model() @@ -148,4 +148,9 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class ProgressResponse(BaseModel): + progress: float + eta_relative: float + state: dict diff --git a/modules/shared.py b/modules/shared.py index fb84afd8..0f4c035d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -147,6 +147,19 @@ class State: def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
+ def js(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.skipped,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return json.dumps(obj)
+
state = State()
|