diff options
Diffstat (limited to 'modules/api')
-rw-r--r-- | modules/api/api.py | 73 | ||||
-rw-r--r-- | modules/api/models.py | 10 |
2 files changed, 66 insertions, 17 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 1ceba75d..48a70a44 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,12 @@ import base64 import io import time +import datetime import uvicorn from threading import Lock from io import BytesIO from gradio.processing_utils import decode_base64_to_file -from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list +from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -67,6 +68,27 @@ def encode_pil_to_base64(image): bytes_data = output_bytes.getvalue() return base64.b64encode(bytes_data) +def api_middleware(app: FastAPI): + @app.middleware("http") + async def log_and_time(req: Request, call_next): + ts = time.time() + res: Response = await call_next(req) + duration = str(round(time.time() - ts, 4)) + res.headers["X-Process-Time"] = duration + endpoint = req.scope.get('path', 'err') + if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): + print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( + t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code = res.status_code, + ver = req.scope.get('http_version', '0.0'), + cli = req.scope.get('client', ('0:0.0.0', 0))[0], + prot = req.scope.get('scheme', 'err'), + method = req.scope.get('method', 'err'), + endpoint = endpoint, + duration = duration, + )) + return res + class Api: def __init__(self, app: FastAPI, queue_lock: Lock): @@ -79,6 +101,7 @@ class Api: self.router = APIRouter() self.app = app self.queue_lock = queue_lock + api_middleware(self.app) self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) @@ -100,6 +123,7 @@ class Api: self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) @@ -121,7 +145,6 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): populate = txt2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True @@ -129,15 +152,14 @@ class Api: ) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on - p = StableDiffusionProcessingTxt2Img(**vars(populate)) - # Override object param - - shared.state.begin() with self.queue_lock: + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + + shared.state.begin() processed = process_images(p) + shared.state.end() - shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -153,7 +175,6 @@ class Api: mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, @@ -165,16 +186,14 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. - p = StableDiffusionProcessingImg2Img(**args) - - p.init_images = [decode_base64_to_image(x) for x in init_images] - - shared.state.begin() with self.queue_lock: - processed = process_images(p) + p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) + p.init_images = [decode_base64_to_image(x) for x in init_images] - shared.state.end() + shared.state.begin() + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -307,7 +326,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] + return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -332,6 +351,26 @@ class Api: def get_artists(self): return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] + def get_embeddings(self): + db = sd_hijack.model_hijack.embedding_db + + def convert_embedding(embedding): + return { + "step": embedding.step, + "sd_checkpoint": embedding.sd_checkpoint, + "sd_checkpoint_name": embedding.sd_checkpoint_name, + "shape": embedding.shape, + "vectors": embedding.vectors, + } + + def convert_embeddings(embeddings): + return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()} + + return { + "loaded": convert_embeddings(db.word_embeddings), + "skipped": convert_embeddings(db.skipped_embeddings), + } + def refresh_checkpoints(self): shared.refresh_checkpoints() diff --git a/modules/api/models.py b/modules/api/models.py index c446ce7a..4a632c68 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -249,3 +249,13 @@ class ArtistItem(BaseModel): score: float = Field(title="Score") category: str = Field(title="Category") +class EmbeddingItem(BaseModel): + step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") + sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available") + sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead") + shape: int = Field(title="Shape", description="The length of each individual vector in the embedding") + vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") + +class EmbeddingsResponse(BaseModel): + loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") + skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file |