diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 12 | ||||
-rw-r--r-- | modules/api/models.py | 9 | ||||
-rw-r--r-- | modules/processing.py | 12 | ||||
-rw-r--r-- | modules/scripts.py | 24 |
4 files changed, 47 insertions, 10 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 49c213ea..d0f488ca 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,7 +5,7 @@ import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers -from modules.extras import run_extras +from modules.extras import run_extras, run_pnginfo def upscaler_to_index(name: str): try: @@ -32,6 +32,7 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -125,8 +126,13 @@ class Api: return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def pnginfoapi(self): - raise NotImplementedError + def pnginfoapi(self, req: PNGInfoRequest): + if(not req.image.strip()): + return PNGInfoResponse(info="") + + result = run_pnginfo(decode_base64_to_image(req.image.strip())) + + return PNGInfoResponse(info=result[1]) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index dd122321..58e8e58b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,4 +1,5 @@ import inspect +from click import prompt from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal @@ -148,4 +149,10 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class PNGInfoRequest(BaseModel): + image: str = Field(title="Image", description="The base64 encoded PNG image") + +class PNGInfoResponse(BaseModel): + info: str = Field(title="Image info", description="A string with all the info the image had")
\ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index 548eec29..50343846 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/scripts.py b/modules/scripts.py index a7f36012..96e44bfd 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -64,7 +64,16 @@ class Script: def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
- scripts. You can modify the processing object (p) here, inject hooks, etc.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
"""
pass
@@ -289,13 +298,22 @@ class ScriptRunner: return processed
- def run_alwayson_scripts(self, p):
+ def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
|