diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-11-06 08:27:54 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-06 08:27:54 +0000 |
commit | 07d1bd426722b4c53b38ff682c5aab53177d8530 (patch) | |
tree | 4fdad803a4536cec2bd3e622c5f4cfb980f04550 /modules/api/models.py | |
parent | 3f3d14afd5abd07d3843370dc1c28be299dbdbab (diff) | |
parent | 6e4de5b4422dfc0d45063b2c8c78b19f00321615 (diff) | |
download | stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.tar.gz stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.tar.bz2 stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.zip |
Merge branch 'master' into roy.add_simple_interrogate_api
Diffstat (limited to 'modules/api/models.py')
-rw-r--r-- | modules/api/models.py | 79 |
1 files changed, 73 insertions, 6 deletions
diff --git a/modules/api/models.py b/modules/api/models.py index 82ab29b8..34dbfa16 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,11 +1,11 @@ import inspect -from click import prompt from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -from modules.shared import sd_upscalers +from modules.shared import sd_upscalers, opts, parser +from typing import Dict, List API_NOT_ALLOWED = [ "self", @@ -110,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( ).generate_model() class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str @@ -132,6 +132,7 @@ class ExtrasBaseRequest(BaseModel): upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") + upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?") class ExtraBaseResponse(BaseModel): html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") @@ -147,10 +148,10 @@ class FileData(BaseModel): name: str = Field(title="File name") class ExtrasBatchImagesRequest(ExtrasBaseRequest): - imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") + images: List[str] = Field(title="Images", description="The generated images in base64 format.") class PNGInfoRequest(BaseModel): image: str = Field(title="Image", description="The base64 encoded PNG image") @@ -172,3 +173,69 @@ class InterrogateRequest(BaseModel): class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + +fields = {} +for key, value in opts.data.items(): + metadata = opts.data_labels.get(key) + optType = opts.typemap.get(type(value), type(value)) + + if (metadata is not None): + fields.update({key: (Optional[optType], Field( + default=metadata.default ,description=metadata.label))}) + else: + fields.update({key: (Optional[optType], Field())}) + +OptionsModel = create_model("Options", **fields) + +flags = {} +_options = vars(parser)['_option_string_actions'] +for key in _options: + if(_options[key].dest != 'help'): + flag = _options[key] + _type = str + if _options[key].default is not None: _type = type(_options[key].default) + flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) + +FlagsModel = create_model("Flags", **flags) + +class SamplerItem(BaseModel): + name: str = Field(title="Name") + aliases: List[str] = Field(title="Aliases") + options: Dict[str, str] = Field(title="Options") + +class UpscalerItem(BaseModel): + name: str = Field(title="Name") + model_name: Optional[str] = Field(title="Model Name") + model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") + +class SDModelItem(BaseModel): + title: str = Field(title="Title") + model_name: str = Field(title="Model Name") + hash: str = Field(title="Hash") + filename: str = Field(title="Filename") + config: str = Field(title="Config file") + +class HypernetworkItem(BaseModel): + name: str = Field(title="Name") + path: Optional[str] = Field(title="Path") + +class FaceRestorerItem(BaseModel): + name: str = Field(title="Name") + cmd_dir: Optional[str] = Field(title="Path") + +class RealesrganItem(BaseModel): + name: str = Field(title="Name") + path: Optional[str] = Field(title="Path") + scale: Optional[int] = Field(title="Scale") + +class PromptStyleItem(BaseModel): + name: str = Field(title="Name") + prompt: Optional[str] = Field(title="Prompt") + negative_prompt: Optional[str] = Field(title="Negative Prompt") + +class ArtistItem(BaseModel): + name: str = Field(title="Name") + score: float = Field(title="Score") + category: str = Field(title="Category") + |