diff options
Diffstat (limited to 'modules/api')
-rw-r--r-- | modules/api/api.py | 47 |
1 files changed, 29 insertions, 18 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..b1dd14cc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -23,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List +import piexif +import piexif.helper def upscaler_to_index(name: str): try: @@ -45,10 +46,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -59,18 +58,30 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): @@ -244,7 +255,7 @@ class Api: reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +271,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) |