diff options
Diffstat (limited to 'modules/api/api.py')
-rw-r--r-- | modules/api/api.py | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 7a567be3..89935a70 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -3,7 +3,8 @@ import io import time import uvicorn from threading import Lock -from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image +from io import BytesIO +from gradio.processing_utils import decode_base64_to_file from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo -from PIL import PngImagePlugin +from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List @@ -40,6 +41,10 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + return Image.open(BytesIO(base64.b64decode(encoding))) def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: @@ -107,11 +112,13 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), + "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } ) + if populate.sampler_name: + populate.sampler_index = None # prevent a warning later on p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param @@ -137,20 +144,20 @@ class Api: populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_name": validate_sampler_name(img2imgreq.sampler_index), + "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask } ) - p = StableDiffusionProcessingImg2Img(**vars(populate)) + if populate.sampler_name: + populate.sampler_index = None # prevent a warning later on - imgs = [] - for img in init_images: - img = decode_base64_to_image(img) - imgs = [img] * p.batch_size + args = vars(populate) + args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + p = StableDiffusionProcessingImg2Img(**args) - p.init_images = imgs + p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() @@ -161,7 +168,7 @@ class Api: b64images = list(map(encode_pil_to_base64, processed.images)) - if (not img2imgreq.include_init_images): + if not img2imgreq.include_init_images: img2imgreq.init_images = None img2imgreq.mask = None @@ -305,7 +312,7 @@ class Api: styleList = [] for k in shared.prompt_styles.styles: style = shared.prompt_styles.styles[k] - styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) + styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]}) return styleList |