diff options
Diffstat (limited to 'modules/api/models.py')
-rw-r--r-- | modules/api/models.py | 86 |
1 files changed, 80 insertions, 6 deletions
diff --git a/modules/api/models.py b/modules/api/models.py index 68fb45c6..f77951fc 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,11 +1,11 @@ import inspect -from click import prompt from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -from modules.shared import sd_upscalers +from modules.shared import sd_upscalers, opts, parser +from typing import Dict, List API_NOT_ALLOWED = [ "self", @@ -65,6 +65,7 @@ class PydanticModelGenerator: self._model_name = model_name self._class_data = merge_class_params(class_instance) + self._model_def = [ ModelDef( field=underscore(k), @@ -109,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( ).generate_model() class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str @@ -147,10 +148,10 @@ class FileData(BaseModel): name: str = Field(title="File name") class ExtrasBatchImagesRequest(ExtrasBaseRequest): - imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") + images: List[str] = Field(title="Images", description="The generated images in base64 format.") class PNGInfoRequest(BaseModel): image: str = Field(title="Image", description="The base64 encoded PNG image") @@ -166,3 +167,76 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + +class InterrogateRequest(BaseModel): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + model: str = Field(default="clip", title="Model", description="The interrogate model used.") + +class InterrogateResponse(BaseModel): + caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + +fields = {} +for key, metadata in opts.data_labels.items(): + value = opts.data.get(key) + optType = opts.typemap.get(type(metadata.default), type(value)) + + if (metadata is not None): + fields.update({key: (Optional[optType], Field( + default=metadata.default ,description=metadata.label))}) + else: + fields.update({key: (Optional[optType], Field())}) + +OptionsModel = create_model("Options", **fields) + +flags = {} +_options = vars(parser)['_option_string_actions'] +for key in _options: + if(_options[key].dest != 'help'): + flag = _options[key] + _type = str + if _options[key].default is not None: _type = type(_options[key].default) + flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) + +FlagsModel = create_model("Flags", **flags) + +class SamplerItem(BaseModel): + name: str = Field(title="Name") + aliases: List[str] = Field(title="Aliases") + options: Dict[str, str] = Field(title="Options") + +class UpscalerItem(BaseModel): + name: str = Field(title="Name") + model_name: Optional[str] = Field(title="Model Name") + model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") + +class SDModelItem(BaseModel): + title: str = Field(title="Title") + model_name: str = Field(title="Model Name") + hash: str = Field(title="Hash") + filename: str = Field(title="Filename") + config: str = Field(title="Config file") + +class HypernetworkItem(BaseModel): + name: str = Field(title="Name") + path: Optional[str] = Field(title="Path") + +class FaceRestorerItem(BaseModel): + name: str = Field(title="Name") + cmd_dir: Optional[str] = Field(title="Path") + +class RealesrganItem(BaseModel): + name: str = Field(title="Name") + path: Optional[str] = Field(title="Path") + scale: Optional[int] = Field(title="Scale") + +class PromptStyleItem(BaseModel): + name: str = Field(title="Name") + prompt: Optional[str] = Field(title="Prompt") + negative_prompt: Optional[str] = Field(title="Negative Prompt") + +class ArtistItem(BaseModel): + name: str = Field(title="Name") + score: float = Field(title="Score") + category: str = Field(title="Category") + |