aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--README.md7
-rw-r--r--javascript/edit-attention.js3
-rw-r--r--launch.py11
-rw-r--r--localizations/ko_KR.json9
-rw-r--r--localizations/zh_CN.json2
-rw-r--r--localizations/zh_TW.json54
-rw-r--r--modules/api/api.py89
-rw-r--r--modules/api/models.py77
-rw-r--r--modules/extensions.py7
-rw-r--r--modules/extras.py5
-rw-r--r--modules/processing.py11
-rw-r--r--modules/script_callbacks.py69
-rw-r--r--modules/scripts.py22
-rw-r--r--modules/sd_models.py21
-rw-r--r--modules/shared.py34
-rw-r--r--modules/ui.py11
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/upscaler.py12
-rw-r--r--test/utils_test.py63
-rw-r--r--webui.py26
20 files changed, 421 insertions, 114 deletions
diff --git a/README.md b/README.md
index 55c050d5..33508f31 100644
--- a/README.md
+++ b/README.md
@@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https:
- Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
-- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
-- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
-- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
+- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index c0d29a74..b947cbec 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,7 +1,6 @@
addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0];
- if (!target.hasAttribute("placeholder")) return;
- if (!target.placeholder.toLowerCase().includes("prompt")) return;
+ if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
if (! (event.metaKey || event.ctrlKey)) return;
diff --git a/launch.py b/launch.py
index ff2f74ba..2a51f20e 100644
--- a/launch.py
+++ b/launch.py
@@ -238,12 +238,15 @@ def tests(argv):
proc.kill()
-def start_webui():
- print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
+def start():
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
- webui.webui()
+ if '--nowebui' in sys.argv:
+ webui.api_only()
+ else:
+ webui.webui()
if __name__ == "__main__":
prepare_enviroment()
- start_webui()
+ start()
diff --git a/localizations/ko_KR.json b/localizations/ko_KR.json
index 874771f9..29e10075 100644
--- a/localizations/ko_KR.json
+++ b/localizations/ko_KR.json
@@ -24,6 +24,7 @@
"Add model hash to generation information": "생성 정보에 모델 해시 추가",
"Add model name to generation information": "생성 정보에 모델 이름 추가",
"Add number to filename when saving": "이미지를 저장할 때 파일명에 숫자 추가하기",
+ "Adds a tab to the webui that allows the user to automatically extract keyframes from video, and manually extract 512x512 crops of those frames for use in model training.": "WebUI에 비디오로부터 자동으로 키프레임을 추출하고, 그 키프레임으로부터 모델 훈련에 사용될 512x512 이미지를 잘라낼 수 있는 탭을 추가합니다.",
"Aesthetic Gradients": "스타일 그라디언트",
"Aesthetic Image Scorer": "스타일 이미지 스코어러",
"Aesthetic imgs embedding": "스타일 이미지 임베딩",
@@ -260,6 +261,7 @@
"Keep -1 for seeds": "시드값 -1로 유지",
"keep whatever was there originally": "이미지 원본 유지",
"keyword": "프롬프트",
+ "Krita Plugin.": "Kirta 플러그인입니다.",
"Label": "라벨",
"Lanczos": "Lanczos",
"Last prompt:": "마지막 프롬프트 : ",
@@ -441,8 +443,8 @@
"See": "자세한 설명은",
"Seed": "시드",
"Seed of a different picture to be mixed into the generation.": "결과물에 섞일 다른 그림의 시드",
- "Select activation function of hypernetwork": "하이퍼네트워크 활성화 함수 선택",
- "Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended": "레이어 가중치 초기화 방식 선택 - relu류 : Kaiming 추천, sigmoid류 : Xavier 추천",
+ "Select activation function of hypernetwork. Recommended : Swish / Linear(none)": "하이퍼네트워크 활성화 함수 선택 - 추천 : Swish / Linear(None)",
+ "Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise": "레이어 가중치 초기화 방식 선택 - relu류 : Kaiming 추천, sigmoid류 : Xavier 추천, 그 외 : Normal",
"Select which Real-ESRGAN models to show in the web UI. (Requires restart)": "WebUI에 표시할 Real-ESRGAN 모델을 선택하십시오. (재시작 필요)",
"Send seed when sending prompt or image to other interface": "다른 화면으로 프롬프트나 이미지를 보낼 때 시드도 함께 보내기",
"Send to extras": "부가기능으로 전송",
@@ -458,7 +460,7 @@
"should be 2 or lower.": "이 2 이하여야 합니다.",
"Show generation progress in window title.": "창 타이틀에 생성 진행도 보여주기",
"Show grid in results for web": "웹에서 결과창에 그리드 보여주기",
- "Show image creation progress every N sampling steps. Set 0 to disable.": "N번째 샘플링 스텝마다 이미지 생성 과정 보이기 - 비활성화하려면 0으로 설정",
+ "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.": "N번째 샘플링 스텝마다 이미지 생성 과정 보이기 - 비활성화하려면 0으로 설정, 배치 생성 완료시 보이려면 -1로 설정",
"Show images zoomed in by default in full page image viewer": "전체 페이지 이미지 뷰어에서 기본값으로 이미지 확대해서 보여주기",
"Show previews of all images generated in a batch as a grid": "배치에서 생성된 모든 이미지의 미리보기를 그리드 형식으로 보여주기",
"Show progressbar": "프로그레스 바 보이기",
@@ -520,6 +522,7 @@
"Train Embedding": "임베딩 훈련",
"Train Hypernetwork": "하이퍼네트워크 훈련",
"Training": "훈련",
+ "training-picker": "훈련용 선택기",
"txt2img": "텍스트→이미지",
"txt2img history": "텍스트→이미지 기록",
"uniform": "uniform",
diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json
index 56c8980e..ff785fc0 100644
--- a/localizations/zh_CN.json
+++ b/localizations/zh_CN.json
@@ -109,7 +109,7 @@
"Sigma noise": "Sigma noise",
"Eta": "Eta",
"Clip skip": "Clip 跳过",
- "Denoising": "去噪",
+ "Denoising": "重绘幅度",
"Cond. Image Mask Weight": "图像调节屏蔽度",
"X values": "X轴数值",
"Y type": "Y轴类型",
diff --git a/localizations/zh_TW.json b/localizations/zh_TW.json
index 4e6dac44..04bde864 100644
--- a/localizations/zh_TW.json
+++ b/localizations/zh_TW.json
@@ -7,7 +7,7 @@
"Loading...": "載入中…",
"view": "檢視",
"api": "api",
- "•": "•",
+ "•": " • ",
"built with gradio": "基於 Gradio 構建",
"Stable Diffusion checkpoint": "Stable Diffusion 模型權重存檔點",
"txt2img": "文生圖",
@@ -70,12 +70,12 @@
"Variation strength": "差異強度",
"Resize seed from width": "自寬度縮放隨機種子",
"Resize seed from height": "自高度縮放隨機種子",
- "Open for Clip Aesthetic!": "打開美術風格 Clip!",
+ "Open for Clip Aesthetic!": "打開以調整 Clip 的美術風格!",
"▼": "▼",
"Aesthetic weight": "美術風格權重",
"Aesthetic steps": "美術風格疊代步數",
"Aesthetic learning rate": "美術風格學習率",
- "Slerp interpolation": "Slerp 插值",
+ "Slerp interpolation": "球面線性插值角度",
"Aesthetic imgs embedding": "美術風格圖集 embedding",
"None": "無",
"Aesthetic text for imgs": "該圖集的美術風格描述",
@@ -105,15 +105,15 @@
"Prompt order": "提示詞順序",
"Sampler": "採樣器",
"Checkpoint name": "模型權重存檔點的名稱",
- "Hypernetwork": "超網路",
- "Hypernet str.": "超網路強度",
+ "Hypernetwork": "超網路(Hypernetwork)",
+ "Hypernet str.": "超網路(Hypernetwork)強度",
"Sigma Churn": "Sigma Churn",
"Sigma min": "最小 Sigma",
"Sigma max": "最大 Sigma",
"Sigma noise": "Sigma noise",
"Eta": "Eta",
"Clip skip": "Clip 跳過",
- "Denoising": "去噪",
+ "Denoising": "重繪幅度",
"Cond. Image Mask Weight": "圖像調節屏蔽度",
"X values": "X軸數值",
"Y type": "Y軸類型",
@@ -189,6 +189,7 @@
"Tile overlap": "圖塊重疊的畫素",
"Upscaler": "放大演算法",
"Lanczos": "Lanczos",
+ "Nearest": "最鄰近(整數縮放)",
"LDSR": "LDSR",
"BSRGAN 4x": "BSRGAN 4x",
"ESRGAN_4x": "ESRGAN_4x",
@@ -230,15 +231,15 @@
"for detailed explanation.": "以了解詳細說明",
"Create embedding": "生成 embedding",
"Create aesthetic images embedding": "生成美術風格圖集 embedding",
- "Create hypernetwork": "生成 hypernetwork",
+ "Create hypernetwork": "生成超網路(Hypernetwork)",
"Preprocess images": "圖像預處理",
"Name": "名稱",
"Initialization text": "初始化文字",
"Number of vectors per token": "每個 token 的向量數",
"Overwrite Old Embedding": "覆寫舊的 Embedding",
"Modules": "模組",
- "Enter hypernetwork layer structure": "輸入 hypernetwork 層結構",
- "Select activation function of hypernetwork": "選擇 hypernetwork 的激活函數",
+ "Enter hypernetwork layer structure": "輸入超網路(Hypernetwork)層結構",
+ "Select activation function of hypernetwork": "選擇超網路(Hypernetwork)的激活函數",
"linear": "linear",
"relu": "relu",
"leakyrelu": "leakyrelu",
@@ -276,7 +277,7 @@
"XavierNormal": "Xavier 正態",
"Add layer normalization": "加入層標準化",
"Use dropout": "採用 dropout 防止過擬合",
- "Overwrite Old Hypernetwork": "覆寫舊的 Hypernetwork",
+ "Overwrite Old Hypernetwork": "覆寫舊的超網路(Hypernetwork)",
"Source directory": "來源目錄",
"Destination directory": "目標目錄",
"Existing Caption txt Action": "對已有的TXT說明文字的行為",
@@ -298,11 +299,11 @@
"Create debug image": "生成除錯圖片",
"Preprocess": "預處理",
"Train an embedding; must specify a directory with a set of 1:1 ratio images": "訓練 embedding; 必須指定一組具有 1:1 比例圖像的目錄",
- "Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images": "訓練 embedding 或者 hypernetwork; 必須指定一組具有 1:1 比例圖像的目錄",
- "[wiki]": "[wiki]",
+ "Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images": "訓練 embedding 或者超網路(Hypernetwork); 必須指定一組具有 1:1 比例圖像的目錄",
+ "[wiki]": "[wiki文件]",
"Embedding": "Embedding",
"Embedding Learning rate": "Embedding 學習率",
- "Hypernetwork Learning rate": "Hypernetwork 學習率",
+ "Hypernetwork Learning rate": "超網路(Hypernetwork)學習率",
"Learning rate": "學習率",
"Dataset directory": "資料集目錄",
"Log directory": "日誌目錄",
@@ -312,7 +313,7 @@
"Save a copy of embedding to log directory every N steps, 0 to disable": "每 N 步將 embedding 的副本儲存到日誌目錄,0 表示禁用",
"Save images with embedding in PNG chunks": "儲存圖像,並在 PNG 圖片檔案中嵌入 embedding 檔案",
"Read parameters (prompt, etc...) from txt2img tab when making previews": "進行預覽時,從文生圖頁籤中讀取參數(提示詞等)",
- "Train Hypernetwork": "訓練 Hypernetwork",
+ "Train Hypernetwork": "訓練超網路(Hypernetwork)",
"Train Embedding": "訓練 Embedding",
"Create an aesthetic embedding out of any number of images": "從任意數量的圖像中建立美術風格 embedding",
"Create images embedding": "生成圖集 embedding",
@@ -418,7 +419,7 @@
"Checkpoints to cache in RAM": "快取在內存(RAM)中的模型權重存檔點",
"SD VAE": "模型的VAE",
"auto": "自動",
- "Hypernetwork strength": "Hypernetwork 強度",
+ "Hypernetwork strength": "超網路(Hypernetwork)強度",
"Inpainting conditioning mask strength": "局部重繪時圖像調節的蒙版屏蔽強度",
"Apply color correction to img2img results to match original colors.": "對圖生圖結果套用顏色校正以匹配原始顏色",
"With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising).": "在進行圖生圖的時候,確切地執行滑塊指定的疊代步數(正常情況下更弱的重繪幅度需要更少的疊代步數)",
@@ -488,7 +489,17 @@
"Extension": "擴充",
"URL": "網址",
"Update": "更新",
+ "a1111-sd-webui-tagcomplete": "標記自動補全",
"unknown": "未知",
+ "deforum-for-automatic1111-webui": "Deforum",
+ "sd-dynamic-prompting": "動態提示詞",
+ "stable-diffusion-webui-aesthetic-gradients": "美術風格梯度",
+ "stable-diffusion-webui-aesthetic-image-scorer": "美術風格評等",
+ "stable-diffusion-webui-artists-to-study": "藝術家圖庫",
+ "stable-diffusion-webui-dataset-tag-editor": "資料集標記編輯器",
+ "stable-diffusion-webui-images-browser": "圖庫瀏覽器",
+ "stable-diffusion-webui-inspiration": "靈感",
+ "stable-diffusion-webui-wildcards": "萬用字元",
"Load from:": "載入自",
"Extension index URL": "擴充清單連結",
"URL for extension's git repository": "擴充的 git 倉庫連結",
@@ -527,8 +538,8 @@
"What to put inside the masked area before processing it with Stable Diffusion.": "在使用 Stable Diffusion 處理蒙版區域之前要在蒙版區域內放置什麼",
"fill it with colors of the image": "用圖像的顏色(高強度模糊)填充它",
"keep whatever was there originally": "保留原來的圖像,不進行預處理",
- "fill it with latent space noise": "用潛空間的噪聲填充它",
- "fill it with latent space zeroes": "用潛空間的零填充它",
+ "fill it with latent space noise": "於潛空間填充噪聲",
+ "fill it with latent space zeroes": "於潛空間填零",
"Upscale masked region to target resolution, do inpainting, downscale back and paste into original image": "將蒙版區域(包括預留畫素長度的緩衝區域)放大到目標解析度,進行局部重繪。\n然後縮小並粘貼回原始圖像中",
"Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.": "將圖像大小調整為目標解析度。除非高度和寬度匹配,否則你將獲得不正確的縱橫比",
"Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.": "調整圖像大小,使整個目標解析度都被圖像填充。裁剪多出來的部分",
@@ -560,6 +571,8 @@
"Select which Real-ESRGAN models to show in the web UI. (Requires restart)": "選擇哪些 Real-ESRGAN 模型顯示在網頁使用者介面。(需要重新啟動)",
"Allowed categories for random artists selection when using the Roll button": "使用抽選藝術家按鈕時將會隨機的藝術家類別",
"Append commas": "附加逗號",
+ "latest": "最新",
+ "behind": "落後",
"Roll three": "抽三位出來",
"Generate forever": "無限生成",
"Cancel generate forever": "停止無限生成",
@@ -581,10 +594,9 @@
"Start drawing": "開始繪製",
"Description": "描述",
"Action": "行動",
- "Aesthetic Gradients": "美術風格",
- "aesthetic-gradients": "美術風格",
- "stable-diffusion-webui-wildcards": "萬用字元",
- "Dynamic Prompts": "動態提示",
+ "Aesthetic Gradients": "美術風格梯度",
+ "aesthetic-gradients": "美術風格梯度",
+ "Dynamic Prompts": "動態提示詞",
"images-browser": "圖庫瀏覽器",
"Inspiration": "靈感",
"Deforum": "Deforum",
diff --git a/modules/api/api.py b/modules/api/api.py
index 71c9c160..8a7ab2f5 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
-from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, Depends, HTTPException
+from threading import Lock
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-
+from modules.sd_models import checkpoints_list
+from modules.realesrgan_model import get_realesrgan_models
+from typing import List
def upscaler_to_index(name: str):
try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
class Api:
- def __init__(self, app, queue_lock):
+ def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
@@ -48,6 +51,18 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+ self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+ self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+ self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+ self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+ self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+ self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+ self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
+ self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -190,6 +205,70 @@ class Api:
shared.state.interrupt()
return {}
+
+ def get_config(self):
+ options = {}
+ for key in shared.opts.data.keys():
+ metadata = shared.opts.data_labels.get(key)
+ if(metadata is not None):
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+ else:
+ options.update({key: shared.opts.data.get(key, None)})
+
+ return options
+
+ def set_config(self, req: OptionsModel):
+ # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
+ # overwrite all options with default values.
+ raise RuntimeError('Setting options via API is not supported')
+
+ reqDict = vars(req)
+ for o in reqDict:
+ setattr(shared.opts, o, reqDict[o])
+
+ shared.opts.save(shared.config_filename)
+ return
+
+ def get_cmd_flags(self):
+ return vars(shared.cmd_opts)
+
+ def get_samplers(self):
+ return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
+
+ def get_upscalers(self):
+ upscalers = []
+
+ for upscaler in shared.sd_upscalers:
+ u = upscaler.scaler
+ upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
+
+ return upscalers
+
+ def get_sd_models(self):
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+
+ def get_hypernetworks(self):
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+ def get_face_restorers(self):
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+ def get_realesrgan_models(self):
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+ def get_promp_styles(self):
+ styleList = []
+ for k in shared.prompt_styles.styles:
+ style = shared.prompt_styles.styles[k]
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
+
+ return styleList
+
+ def get_artists_categories(self):
+ return shared.artist_db.cats
+
+ def get_artists(self):
+ return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 68fb45c6..a44c5ddd 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,11 +1,11 @@
import inspect
-from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-from modules.shared import sd_upscalers
+from modules.shared import sd_upscalers, opts, parser
+from typing import List
API_NOT_ALLOWED = [
"self",
@@ -109,12 +109,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -147,10 +147,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
@@ -166,3 +166,68 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+fields = {}
+for key, value in opts.data.items():
+ metadata = opts.data_labels.get(key)
+ optType = opts.typemap.get(type(value), type(value))
+
+ if (metadata is not None):
+ fields.update({key: (Optional[optType], Field(
+ default=metadata.default ,description=metadata.label))})
+ else:
+ fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+ if(_options[key].dest != 'help'):
+ flag = _options[key]
+ _type = str
+ if _options[key].default is not None: _type = type(_options[key].default)
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+ name: str = Field(title="Name")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+ name: str = Field(title="Name")
+ model_name: Optional[str] = Field(title="Model Name")
+ model_path: Optional[str] = Field(title="Path")
+ model_url: Optional[str] = Field(title="URL")
+
+class SDModelItem(BaseModel):
+ title: str = Field(title="Title")
+ model_name: str = Field(title="Model Name")
+ hash: str = Field(title="Hash")
+ filename: str = Field(title="Filename")
+ config: str = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+ name: str = Field(title="Name")
+ cmd_dir: Optional[str] = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+ name: str = Field(title="Name")
+ path: Optional[str] = Field(title="Path")
+ scale: Optional[int] = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+ name: str = Field(title="Name")
+ prompt: Optional[str] = Field(title="Prompt")
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+ name: str = Field(title="Name")
+ score: float = Field(title="Score")
+ category: str = Field(title="Category") \ No newline at end of file
diff --git a/modules/extensions.py b/modules/extensions.py
index 897af96e..8e0977fd 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -34,8 +34,11 @@ class Extension:
if repo is None or repo.bare:
self.remote = None
else:
- self.remote = next(repo.remote().urls, None)
- self.status = 'unknown'
+ try:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+ except Exception:
+ self.remote = None
def list_files(self, subdir, extension):
from modules import scripts
diff --git a/modules/extras.py b/modules/extras.py
index 8e2ab35c..71b93a06 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -136,12 +136,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
+ image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
+ cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
- args_hash=hash((upscale_args, upscale_first)))
+ args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
diff --git a/modules/processing.py b/modules/processing.py
index a46e592d..03c9143d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -501,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ if p.scripts is not None:
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -665,17 +668,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
+
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-
+
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
-
- for i in range(samples.shape[0]):
- save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index c28e220e..74dfb880 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,25 +46,23 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-callbacks_app_started = []
-callbacks_model_loaded = []
-callbacks_ui_tabs = []
-callbacks_ui_settings = []
-callbacks_before_image_saved = []
-callbacks_image_saved = []
-callbacks_cfg_denoiser = []
+callback_map = dict(
+ callbacks_app_started=[],
+ callbacks_model_loaded=[],
+ callbacks_ui_tabs=[],
+ callbacks_ui_settings=[],
+ callbacks_before_image_saved=[],
+ callbacks_image_saved=[],
+ callbacks_cfg_denoiser=[]
+)
def clear_callbacks():
- callbacks_model_loaded.clear()
- callbacks_ui_tabs.clear()
- callbacks_ui_settings.clear()
- callbacks_before_image_saved.clear()
- callbacks_image_saved.clear()
- callbacks_cfg_denoiser.clear()
+ for callback_list in callback_map.values():
+ callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in callbacks_app_started:
+ for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -72,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
- for c in callbacks_model_loaded:
+ for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -82,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
- for c in callbacks_ui_tabs:
+ for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -92,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
- for c in callbacks_ui_settings:
+ for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -100,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_before_image_saved:
+ for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -108,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -116,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in callbacks_cfg_denoiser:
+ for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -129,17 +127,33 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+
+def remove_current_script_callbacks():
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ if filename == 'unknown file':
+ return
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+ callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+ callback_list.remove(callback_to_remove)
+
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
- add_callback(callbacks_app_started, callback)
+ add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- add_callback(callbacks_model_loaded, callback)
+ add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -152,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- add_callback(callbacks_ui_tabs, callback)
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(callbacks_ui_settings, callback)
+ add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -166,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
- add_callback(callbacks_before_image_saved, callback)
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -174,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
- add_callback(callbacks_image_saved, callback)
+ add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -182,5 +196,4 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
- add_callback(callbacks_cfg_denoiser, callback)
-
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index 28ce07f4..366c90d7 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -73,6 +73,19 @@ class Script:
pass
+ def process_batch(self, p, *args, **kwargs):
+ """
+ Same as process(), but called for every batch.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -296,6 +309,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ae427a5c..34c57bfa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,11 +163,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+ if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+ sd_vae.restore_base_vae(model)
+ checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
- checkpoint_key = checkpoint_info
+ vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- if checkpoint_key not in checkpoints_loaded:
+ if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -197,18 +199,15 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.first_stage_model.to(devices.dtype_vae)
- if shared.opts.sd_checkpoint_cache > 0:
- # if PR #4035 were to get merged, restore base VAE first before caching
- checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False) # LRU
-
else:
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
- checkpoints_loaded.move_to_end(checkpoint_key)
- model.load_state_dict(checkpoints_loaded[checkpoint_key])
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
+
+ if shared.opts.sd_checkpoint_cache > 0:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
diff --git a/modules/shared.py b/modules/shared.py
index 6e7a02e0..71587557 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -85,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
cmd_opts = parser.parse_args()
restricted_opts = {
@@ -99,7 +103,7 @@ restricted_opts = {
"outdir_save",
}
-cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
@@ -146,6 +150,9 @@ class State:
self.interrupted = True
def nextjob(self):
+ if opts.show_progress_every_n_steps == -1:
+ self.do_set_current_image()
+
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@@ -186,17 +193,21 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ self.do_set_current_image()
+
+ def do_set_current_image(self):
if not parallel_processing_allowed:
return
+ if self.current_latent is None:
+ return
+
+ if opts.show_progress_grid:
+ self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = sd_samplers.sample_to_image(self.current_latent)
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
- if opts.show_progress_grid:
- self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
- else:
- self.current_image = sd_samplers.sample_to_image(self.current_latent)
-
- self.current_image_sampling_step = self.sampling_step
-
+ self.current_image_sampling_step = self.sampling_step
state = State()
@@ -352,7 +363,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@@ -399,7 +410,8 @@ class Options:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
- comp_args = opts.data_labels[key].component_args
+ info = opts.data_labels.get(key, None)
+ comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
diff --git a/modules/ui.py b/modules/ui.py
index 633b56ef..4c2829af 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -276,7 +276,7 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
- if opts.show_progress_every_n_steps > 0:
+ if opts.show_progress_every_n_steps != 0:
shared.state.set_current_image()
image = shared.state.current_image
@@ -1439,8 +1439,7 @@ def create_ui(wrap_gradio_gpu_call):
changed = 0
for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
- return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
+ assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
@@ -1458,7 +1457,7 @@ def create_ui(wrap_gradio_gpu_call):
opts.save(shared.config_filename)
- return f'{changed} settings changed.', opts.dumpjson()
+ return opts.dumpjson(), f'{changed} settings changed.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
@@ -1622,9 +1621,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
- fn=run_settings,
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
- outputs=[result, text_settings],
+ outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index ab807722..a81de9a7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -86,7 +86,7 @@ def extension_table():
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
- <td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
+ <td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 83fde7ca..c4e6e6bd 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -57,10 +57,18 @@ class Upscaler:
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
+
for i in range(3):
- if img.width > dest_w and img.height > dest_h:
- break
+ shape = (img.width, img.height)
+
img = self.do_upscale(img, selected_model)
+
+ if shape == (img.width, img.height):
+ break
+
+ if img.width >= dest_w and img.height >= dest_h:
+ break
+
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
diff --git a/test/utils_test.py b/test/utils_test.py
new file mode 100644
index 00000000..65d3d177
--- /dev/null
+++ b/test/utils_test.py
@@ -0,0 +1,63 @@
+import unittest
+import requests
+
+class UtilsTests(unittest.TestCase):
+ def setUp(self):
+ self.url_options = "http://localhost:7860/sdapi/v1/options"
+ self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
+ self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
+ self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
+ self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
+ self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
+ self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
+ self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
+ self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
+ self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
+ self.url_artists = "http://localhost:7860/sdapi/v1/artists"
+
+ def test_options_get(self):
+ self.assertEqual(requests.get(self.url_options).status_code, 200)
+
+ def test_options_write(self):
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+
+ pre_value = response.json()["send_seed"]
+
+ self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
+
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json()["send_seed"], not pre_value)
+
+ requests.post(self.url_options, json={"send_seed": pre_value})
+
+ def test_cmd_flags(self):
+ self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
+
+ def test_samplers(self):
+ self.assertEqual(requests.get(self.url_samplers).status_code, 200)
+
+ def test_upscalers(self):
+ self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
+
+ def test_sd_models(self):
+ self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
+
+ def test_hypernetworks(self):
+ self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
+
+ def test_face_restorers(self):
+ self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
+
+ def test_realesrgan_models(self):
+ self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
+
+ def test_prompt_styles(self):
+ self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
+
+ def test_artist_categories(self):
+ self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
+
+ def test_artists(self):
+ self.assertEqual(requests.get(self.url_artists).status_code, 200) \ No newline at end of file
diff --git a/webui.py b/webui.py
index 3b21c071..222dbeee 100644
--- a/webui.py
+++ b/webui.py
@@ -34,7 +34,7 @@ from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock()
-
+server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name
def wrap_queued_call(func):
def f(*args, **kwargs):
@@ -85,6 +85,20 @@ def initialize():
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
+ if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+
+
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
@@ -131,8 +145,10 @@ def webui():
app, local_url, share_url = demo.launch(
share=cmd_opts.share,
- server_name="0.0.0.0" if cmd_opts.listen else None,
+ server_name=server_name,
server_port=cmd_opts.port,
+ ssl_keyfile=cmd_opts.tls_keyfile,
+ ssl_certfile=cmd_opts.tls_certfile,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch,
@@ -141,6 +157,12 @@ def webui():
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
+ # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
+ # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
+ # running web ui and do whatever the attcker wants, including installing an extension and
+ # runnnig its code. We disable this here. Suggested by RyotaK.
+ app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
+
app.add_middleware(GZipMiddleware, minimum_size=1000)
if launch_api: