aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/warns_merge_master.yml19
-rw-r--r--extensions-builtin/Lora/lora.py11
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py200
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py50
-rw-r--r--html/extra-networks-card.html10
-rw-r--r--html/image-update.svg7
-rw-r--r--javascript/extraNetworks.js64
-rw-r--r--javascript/hints.js3
-rw-r--r--modules/api/api.py33
-rw-r--r--modules/cache.py97
-rw-r--r--modules/call_queue.py18
-rw-r--r--modules/extensions.py26
-rw-r--r--modules/hashes.py38
-rw-r--r--modules/images.py5
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/launch_utils.py10
-rw-r--r--modules/processing.py7
-rw-r--r--modules/sd_hijack.py5
-rw-r--r--modules/sd_hijack_clip.py15
-rw-r--r--modules/shared.py7
-rw-r--r--modules/textual_inversion/textual_inversion.py11
-rw-r--r--modules/txt2img.py2
-rw-r--r--modules/ui.py3
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_extensions.py11
-rw-r--r--modules/ui_extra_networks.py109
-rw-r--r--modules/ui_extra_networks_checkpoints.py32
-rw-r--r--modules/ui_extra_networks_hypernets.py33
-rw-r--r--modules/ui_extra_networks_textual_inversion.py32
-rw-r--r--modules/ui_extra_networks_user_metadata.py190
-rw-r--r--style.css175
31 files changed, 946 insertions, 288 deletions
diff --git a/.github/workflows/warns_merge_master.yml b/.github/workflows/warns_merge_master.yml
new file mode 100644
index 00000000..ae2aab6b
--- /dev/null
+++ b/.github/workflows/warns_merge_master.yml
@@ -0,0 +1,19 @@
+name: Pull requests can't target master branch
+
+"on":
+ pull_request:
+ types:
+ - opened
+ - synchronize
+ - reopened
+ branches:
+ - master
+
+jobs:
+ check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Warning marge into master
+ run: |
+ echo -e "::warning::This pull request directly merge into \"master\" branch, normally development happens on \"dev\" branch."
+ exit 1
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 302490fb..9cdff6ed 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,7 @@ import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -86,9 +86,15 @@ class LoraOnDisk:
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+ def read_metadata():
+ metadata = sd_models.read_metadata_from_safetensors(filename)
+ metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
+
+ return metadata
+
if self.is_safetensors:
try:
- self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
@@ -99,7 +105,6 @@ class LoraOnDisk:
self.metadata = m
- self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
new file mode 100644
index 00000000..354a1d68
--- /dev/null
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -0,0 +1,200 @@
+import html
+import random
+
+import gradio as gr
+import re
+
+from modules import ui_extra_networks_user_metadata
+
+
+def is_non_comma_tagset(tags):
+ average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
+
+ return average_tag_length >= 16
+
+
+re_word = re.compile(r"[-_\w']+")
+re_comma = re.compile(r" *, *")
+
+
+def build_tags(metadata):
+ tags = {}
+
+ for _, tags_dict in metadata.get("ss_tag_frequency", {}).items():
+ for tag, tag_count in tags_dict.items():
+ tag = tag.strip()
+ tags[tag] = tags.get(tag, 0) + int(tag_count)
+
+ if tags and is_non_comma_tagset(tags):
+ new_tags = {}
+
+ for text, text_count in tags.items():
+ for word in re.findall(re_word, text):
+ if len(word) < 3:
+ continue
+
+ new_tags[word] = new_tags.get(word, 0) + text_count
+
+ tags = new_tags
+
+ ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
+
+ return [(tag, tags[tag]) for tag in ordered_tags]
+
+
+class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+ def __init__(self, ui, tabname, page):
+ super().__init__(ui, tabname, page)
+
+ self.taginfo = None
+ self.edit_activation_text = None
+ self.slider_preferred_weight = None
+ self.edit_notes = None
+
+ def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["activation text"] = activation_text
+ user_metadata["preferred weight"] = preferred_weight
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def get_metadata_table(self, name):
+ table = super().get_metadata_table(name)
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ keys = {
+ 'ss_sd_model_name': "Model:",
+ 'ss_clip_skip': "Clip skip:",
+ }
+
+ for key, label in keys.items():
+ value = metadata.get(key, None)
+ if value is not None and str(value) != "None":
+ table.append((label, html.escape(value)))
+
+ ss_bucket_info = metadata.get("ss_bucket_info")
+ if ss_bucket_info and "buckets" in ss_bucket_info:
+ resolutions = {}
+ for _, bucket in ss_bucket_info["buckets"].items():
+ resolution = bucket["resolution"]
+ resolution = f'{resolution[1]}x{resolution[0]}'
+
+ resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"])
+
+ resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)
+ resolutions_text = html.escape(", ".join(resolutions_list[0:4]))
+ if len(resolutions) > 4:
+ resolutions_text += ", ..."
+ resolutions_text = f"<span title='{html.escape(', '.join(resolutions_list))}'>{resolutions_text}</span>"
+
+ table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))
+
+ image_count = 0
+ for _, params in metadata.get("ss_dataset_dirs", {}).items():
+ image_count += int(params.get("img_count", 0))
+
+ if image_count:
+ table.append(("Dataset size:", image_count))
+
+ return table
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+ values = super().put_values_into_components(name)
+
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ tags = build_tags(metadata)
+ gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
+
+ return [
+ *values[0:4],
+ gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
+ user_metadata.get('activation text', ''),
+ float(user_metadata.get('preferred weight', 0.0)),
+ user_metadata.get('notes', ''),
+ gr.update(visible=True if tags else False),
+ gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
+ ]
+
+ def generate_random_prompt(self, name):
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+ tags = build_tags(metadata)
+
+ return self.generate_random_prompt_from_tags(tags)
+
+ def generate_random_prompt_from_tags(self, tags):
+ max_count = None
+ res = []
+ for tag, count in tags:
+ if not max_count:
+ max_count = count
+
+ v = random.random() * max_count
+ if count > v:
+ res.append(tag)
+
+ return ", ".join(sorted(res))
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.taginfo = gr.HighlightedText(label="Tags")
+ self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
+ self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+
+ with gr.Row() as row_random_prompt:
+ with gr.Column(scale=8):
+ random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
+
+ with gr.Column(scale=1, min_width=120):
+ generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)
+
+ def select_tag(activation_text, evt: gr.SelectData):
+ tag = evt.value[0]
+
+ words = re.split(re_comma, activation_text)
+ if tag in words:
+ words = [x for x in words if x != tag and x.strip()]
+ return ", ".join(words)
+
+ return activation_text + ", " + tag if activation_text else tag
+
+ self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)
+
+ self.create_default_buttons()
+
+ viewed_components = [
+ self.edit_name,
+ self.edit_description,
+ self.html_filedata,
+ self.html_preview,
+ self.taginfo,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ self.edit_notes,
+ row_random_prompt,
+ random_prompt,
+ ]
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ edited_components = [
+ self.edit_description,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ self.edit_notes,
+ ]
+
+ self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index da49790b..b2bc1810 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -1,8 +1,9 @@
-import json
import os
import lora
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
+from ui_edit_user_metadata import LoraUserMetadataEditor
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
@@ -12,25 +13,42 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
lora.list_available_loras()
- def list_items(self):
- for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
- path, ext = os.path.splitext(lora_on_disk.filename)
+ def create_item(self, name, index=None):
+ lora_on_disk = lora.available_loras.get(name)
+
+ path, ext = os.path.splitext(lora_on_disk.filename)
+
+ alias = lora_on_disk.get_alias()
- alias = lora_on_disk.get_alias()
+ # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string
+ item = {
+ "name": name,
+ "filename": lora_on_disk.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(lora_on_disk.filename),
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": lora_on_disk.metadata,
+ "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ }
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename),
- "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
- "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ self.read_user_metadata(item)
+ activation_text = item["user_metadata"].get("activation text")
+ preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
+ item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
- }
+ if activation_text:
+ item["prompt"] += " + " + quote_js(" " + activation_text)
+
+ return item
+
+ def list_items(self):
+ for index, name in enumerate(lora.available_loras):
+ item = self.create_item(name, index)
+ yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
+ def create_user_metadata_editor(self, ui, tabname):
+ return LoraUserMetadataEditor(ui, tabname, self)
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 68a84c3a..eb8b1a67 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,11 +1,11 @@
-<div class='card' style={style} onclick={card_clicked} {sort_keys}>
+<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
{background_image}
- {metadata_button}
+ <div class="button-row">
+ {edit_button}
+ {metadata_button}
+ </div>
<div class='actions'>
<div class='additional'>
- <ul>
- <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
- </ul>
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
</div>
<span class='name'>{name}</span>
diff --git a/html/image-update.svg b/html/image-update.svg
deleted file mode 100644
index 3abf12df..00000000
--- a/html/image-update.svg
+++ /dev/null
@@ -1,7 +0,0 @@
-<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
- <filter id='shadow' color-interpolation-filters="sRGB">
- <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
- <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
- </filter>
- <path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
-</svg>
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index b87bca3e..e453094a 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -113,7 +113,7 @@ function setupExtraNetworks() {
onUiLoaded(setupExtraNetworks);
-var re_extranet = /<([^:]+:[^:]+):[\d.]+>/;
+var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
@@ -121,15 +121,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var replaced = false;
var newTextareaText;
if (m) {
+ var extraTextAfterNet = m[2];
var partToSearch = m[1];
- newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) {
+ var foundAtPosition = -1;
+ newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
m = found.match(re_extranet);
if (m[1] == partToSearch) {
replaced = true;
+ foundAtPosition = pos;
return "";
}
return found;
});
+
+ if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ }
} else {
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
if (found == text) {
@@ -182,19 +189,20 @@ function extraNetworksSearchButton(tabs_id, event) {
var globalPopup = null;
var globalPopupInner = null;
+function closePopup() {
+ if (!globalPopup) return;
+
+ globalPopup.style.display = "none";
+}
function popup(contents) {
if (!globalPopup) {
globalPopup = document.createElement('div');
- globalPopup.onclick = function() {
- globalPopup.style.display = "none";
- };
+ globalPopup.onclick = closePopup;
globalPopup.classList.add('global-popup');
var close = document.createElement('div');
close.classList.add('global-popup-close');
- close.onclick = function() {
- globalPopup.style.display = "none";
- };
+ close.onclick = closePopup;
close.title = "Close";
globalPopup.appendChild(close);
@@ -263,3 +271,43 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) {
event.stopPropagation();
}
+
+var extraPageUserMetadataEditors = {};
+
+function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
+ var id = tabname + '_' + extraPage + '_edit_user_metadata';
+
+ var editor = extraPageUserMetadataEditors[id];
+ if (!editor) {
+ editor = {};
+ editor.page = gradioApp().getElementById(id);
+ editor.nameTextarea = gradioApp().querySelector("#" + id + "_name" + ' textarea');
+ editor.button = gradioApp().querySelector("#" + id + "_button");
+ extraPageUserMetadataEditors[id] = editor;
+ }
+
+ editor.nameTextarea.value = cardName;
+ updateInput(editor.nameTextarea);
+
+ editor.button.click();
+
+ popup(editor.page);
+
+ event.stopPropagation();
+}
+
+function extraNetworksRefreshSingleCard(page, tabname, name) {
+ requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
+ if (data && data.html) {
+ var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
+
+ var newDiv = document.createElement('DIV');
+ newDiv.innerHTML = data.html;
+ var newCard = newDiv.firstElementChild;
+
+ newCard.style = '';
+ card.parentElement.insertBefore(newCard, card);
+ card.parentElement.removeChild(card);
+ }
+ });
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index dc75ce31..4167cb28 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -84,8 +84,6 @@ var titles = {
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
- "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
-
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
@@ -110,7 +108,6 @@ var titles = {
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
- "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
diff --git a/modules/api/api.py b/modules/api/api.py
index 11045292..2a4cd8a2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,6 @@
import base64
import io
+import os
import time
import datetime
import uvicorn
@@ -98,14 +99,16 @@ def encode_pil_to_base64(image):
def api_middleware(app: FastAPI):
- rich_available = True
+ rich_available = False
try:
- import anyio # importing just so it can be placed on silent list
- import starlette # importing just so it can be placed on silent list
- from rich.console import Console
- console = Console()
+ if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
+ import anyio # importing just so it can be placed on silent list
+ import starlette # importing just so it can be placed on silent list
+ from rich.console import Console
+ console = Console()
+ rich_available = True
except Exception:
- rich_available = False
+ pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
@@ -116,14 +119,14 @@ def api_middleware(app: FastAPI):
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
- t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
- code = res.status_code,
- ver = req.scope.get('http_version', '0.0'),
- cli = req.scope.get('client', ('0:0.0.0', 0))[0],
- prot = req.scope.get('scheme', 'err'),
- method = req.scope.get('method', 'err'),
- endpoint = endpoint,
- duration = duration,
+ t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code=res.status_code,
+ ver=req.scope.get('http_version', '0.0'),
+ cli=req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot=req.scope.get('scheme', 'err'),
+ method=req.scope.get('method', 'err'),
+ endpoint=endpoint,
+ duration=duration,
))
return res
@@ -134,7 +137,7 @@ def api_middleware(app: FastAPI):
"body": vars(e).get('body', ''),
"errors": str(e),
}
- if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
diff --git a/modules/cache.py b/modules/cache.py
new file mode 100644
index 00000000..28d42a8c
--- /dev/null
+++ b/modules/cache.py
@@ -0,0 +1,97 @@
+import json
+import os.path
+import threading
+
+from modules.paths import data_path, script_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+cache_lock = threading.Lock()
+
+
+def dump_cache():
+ """
+ Saves all cache data to a file.
+ """
+
+ with cache_lock:
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ """
+ Retrieves or initializes a cache for a specific subsection.
+
+ Parameters:
+ subsection (str): The subsection identifier for the cache.
+
+ Returns:
+ dict: The cache data for the specified subsection.
+ """
+
+ global cache_data
+
+ if cache_data is None:
+ with cache_lock:
+ if cache_data is None:
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+ cache_data = {}
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def cached_data_for_file(subsection, title, filename, func):
+ """
+ Retrieves or generates data for a specific file, using a caching mechanism.
+
+ Parameters:
+ subsection (str): The subsection of the cache to use.
+ title (str): The title of the data entry in the subsection of the cache.
+ filename (str): The path to the file to be checked for modifications.
+ func (callable): A function that generates the data if it is not available in the cache.
+
+ Returns:
+ dict or None: The cached or generated data, or None if data generation fails.
+
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
+ It checks if the data associated with the given `title` is present in the cache and compares the
+ modification time of the file with the cached modification time. If the file has been modified,
+ the cache is considered invalid and the data is regenerated using the provided `func`.
+ Otherwise, the cached data is returned.
+
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
+ or cached data is returned as a dictionary.
+ """
+
+ existing_cache = cache(subsection)
+ ondisk_mtime = os.path.getmtime(filename)
+
+ entry = existing_cache.get(title)
+ if entry:
+ cached_mtime = entry.get("mtime", 0)
+ if ondisk_mtime > cached_mtime:
+ entry = None
+
+ if not entry:
+ value = func()
+ if value is None:
+ return None
+
+ entry = {'mtime': ondisk_mtime, 'value': value}
+ existing_cache[title] = entry
+
+ dump_cache()
+
+ return entry['value']
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 3b94f8a4..61aa240f 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60