From 0ebf66b575f008a027097946eb2f6845feffd010 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 18:58:19 -0300 Subject: Fix set config endpoint --- modules/api/api.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 112000b8..a924c83a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -230,14 +230,10 @@ class Api: return options - def set_config(self, req: OptionsModel): - # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will - # overwrite all options with default values. - raise RuntimeError('Setting options via API is not supported') - - reqDict = vars(req) - for o in reqDict: - setattr(shared.opts, o, reqDict[o]) + def set_config(self, req: Dict[str, Any]): + + for o in req: + setattr(shared.opts, o, req[o]) shared.opts.save(shared.config_filename) return -- cgit v1.2.3 From 3c72055c22425dcde0739b5246e3501f4a3ec794 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 19:05:15 -0300 Subject: Add skip endpoint --- modules/api/api.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index a924c83a..c7ceb787 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -64,6 +64,7 @@ class Api: self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) @@ -219,6 +220,11 @@ class Api: return {} + def skip(self): + shared.state.skip() + + return + def get_config(self): options = {} for key in shared.opts.data.keys(): -- cgit v1.2.3 From 7f63980e479c7ffaec907fb659b5024e96eb72e7 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 19:09:13 -0300 Subject: Remove unnecesary return --- modules/api/api.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index c7ceb787..33e6c6dc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -223,8 +223,6 @@ class Api: def skip(self): shared.state.skip() - return - def get_config(self): options = {} for key in shared.opts.data.keys(): -- cgit v1.2.3 From 67c8e11be74180be19341aebbd6a246c37a79fbb Mon Sep 17 00:00:00 2001 From: snowmeow2 Date: Mon, 7 Nov 2022 02:32:06 +0800 Subject: Adding DeepDanbooru to the interrogation API --- modules/api/api.py | 16 ++++++++++++++-- modules/api/models.py | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 688469ad..596a6616 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List +if shared.cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags + def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) @@ -220,11 +223,20 @@ class Api: if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") - img = self.__base64_to_image(image_b64) + img = decode_base64_to_image(image_b64) + img = img.convert('RGB') # Override object param with self.queue_lock: - processed = shared.interrogator.interrogate(img) + if interrogatereq.model == "clip": + processed = shared.interrogator.interrogate(img) + elif interrogatereq.model == "deepdanbooru": + if shared.cmd_opts.deepdanbooru: + processed = get_deepbooru_tags(img) + else: + raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.") + else: + raise HTTPException(status_code=404, detail="Model not found") return InterrogateResponse(caption=processed) diff --git a/modules/api/models.py b/modules/api/models.py index 34dbfa16..f9cd929e 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -170,6 +170,7 @@ class ProgressResponse(BaseModel): class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + model: str = Field(default="clip", title="Model", description="The interrogate model used.") class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") -- cgit v1.2.3 From 8f2ff861d31972d12de278075ea9c0c0deef99de Mon Sep 17 00:00:00 2001 From: Maiko Sinkyaet Tan Date: Tue, 15 Nov 2022 16:12:34 +0800 Subject: feat: add http basic authentication for api --- modules/api/api.py | 61 ++++++++++++++++++++++++++++++++++++------------------ modules/shared.py | 1 + 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..6bb01603 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,6 +5,9 @@ import uvicorn from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from secrets import compare_digest + import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -57,29 +60,47 @@ def encode_pil_to_base64(image): class Api: def __init__(self, app: FastAPI, queue_lock: Lock): + if shared.cmd_opts.api_auth: + self.credenticals = dict() + for auth in shared.cmd_opts.api_auth.split(","): + user, password = auth.split(":") + self.credenticals[user] = password + self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) - self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) - self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) - self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) - self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) - self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) - self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) - self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) - self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) - self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) - self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) + self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) + self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) + self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + + def add_api_route(self, path: str, endpoint, **kwargs): + if shared.cmd_opts.api_auth: + return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) + return self.app.add_api_route(path, endpoint, **kwargs) + + def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): + if credenticals.username in self.credenticals: + if compare_digest(credenticals.password, self.credenticals[credenticals.username]): + return True + + raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..62d526fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,6 +81,7 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -- cgit v1.2.3 From cdc8020d13c5eef099c609b0a911ccf3568afc0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:01:51 +0300 Subject: change StableDiffusionProcessing to internally use sampler name instead of sampler index --- modules/api/api.py | 26 ++++++++--------------- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/images.py | 2 +- modules/img2img.py | 4 ++-- modules/processing.py | 29 +++++++++++--------------- modules/sd_samplers.py | 13 +++++++++--- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/txt2img.py | 3 ++- modules/ui.py | 2 +- scripts/img2imgalt.py | 4 ++-- scripts/xy_grid.py | 12 +++++------ 11 files changed, 49 insertions(+), 54 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..0eccccbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,9 +6,9 @@ from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -25,8 +25,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -82,14 +86,9 @@ class Api: self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -109,12 +108,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -123,10 +116,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -272,7 +264,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712..fbb87dd1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, processing, sd_models, shared, sd_samplers from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_index = preview_sampler_index + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index ae705cbd..26d5b7a9 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width, 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), - 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), + 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime