diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-13 12:21:39 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 12:21:39 +0000 |
commit | b7c5b30f14aadffadd2d35cb3ecb3e91af00581d (patch) | |
tree | 0e51f517bb6ac010c0e3dc5937d112656ec9ee9a /modules | |
parent | 14501f56aaf3c97fb2c38633350dc747b9651f43 (diff) | |
parent | 262ec8ecdaf10d8fe49d0227e24bd3a1459e87b5 (diff) | |
download | stable-diffusion-webui-gfx803-b7c5b30f14aadffadd2d35cb3ecb3e91af00581d.tar.gz stable-diffusion-webui-gfx803-b7c5b30f14aadffadd2d35cb3ecb3e91af00581d.tar.bz2 stable-diffusion-webui-gfx803-b7c5b30f14aadffadd2d35cb3ecb3e91af00581d.zip |
Merge branch 'dev' into master
Diffstat (limited to 'modules')
33 files changed, 488 insertions, 293 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 4ea5d825..11045292 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models @@ -30,13 +30,7 @@ from modules import devices from typing import Dict, List, Any import piexif import piexif.helper - - -def upscaler_to_index(name: str): - try: - return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e +from contextlib import closing def script_name_to_index(name, scripts): @@ -84,6 +78,8 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + if image.mode == "RGBA": + image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } @@ -209,6 +205,11 @@ class Api: self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + if shared.cmd_opts.api_server_stop: + self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) + self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -324,19 +325,19 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - p.scripts = script_runner - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() + with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + + shared.state.begin(job="scripts_txt2img") + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -380,20 +381,20 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - p.init_images = [decode_base64_to_image(x) for x in init_images] - p.scripts = script_runner - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() + with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: + p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + + shared.state.begin(job="scripts_img2img") + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -517,6 +518,10 @@ class Api: return options def set_config(self, req: Dict[str, Any]): + checkpoint_name = req.get("sd_model_checkpoint", None) + if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + raise RuntimeError(f"model {checkpoint_name!r} not found") + for k, v in req.items(): shared.opts.set(k, v) @@ -598,44 +603,42 @@ class Api: def create_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used - shared.state.end() return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create embedding error: {e}") + finally: + shared.state.end() + def create_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding - shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create hypernetwork error: {e}") + finally: + shared.state.end() def preprocess(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return models.PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info='preprocess complete') except KeyError as e: - shared.state.end() return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except AssertionError as e: - shared.state.end() + except Exception as e: return models.PreprocessResponse(info=f"preprocess error: {e}") - except FileNotFoundError as e: + finally: shared.state.end() - return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_embedding") apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -648,15 +651,15 @@ class Api: finally: if not apply_optimizations: sd_hijack.apply_optimizations() - shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError as msg: - shared.state.end() + except Exception as msg: return models.TrainResponse(info=f"train embedding error: {msg}") + finally: + shared.state.end() def train_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_hypernetwork") shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None @@ -674,9 +677,10 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError: + except Exception as exc: + return models.TrainResponse(info=f"train embedding error: {exc}") + finally: shared.state.end() - return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: @@ -716,3 +720,16 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive) + + def kill_webui(self): + restart.stop_program() + + def restart_webui(self): + if restart.is_restartable(): + restart.restart_program() + return Response(status_code=501) + + def stop_webui(request): + shared.state.server_command = "stop" + return Response("Stopping.") + diff --git a/modules/api/models.py b/modules/api/models.py index b3a745f0..b5683071 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel): prompt: Optional[str] = Field(title="Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt") -class ArtistItem(BaseModel): - name: str = Field(title="Name") - score: float = Field(title="Score") - category: str = Field(title="Category") class EmbeddingItem(BaseModel): step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") diff --git a/modules/call_queue.py b/modules/call_queue.py index 1b5e5273..3b94f8a4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,3 +1,4 @@ +from functools import wraps
import html
import threading
import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None):
+ @wraps(func)
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 982d9055..ae78f469 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -107,4 +107,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index d974e4b8..da42b5e9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -15,7 +15,6 @@ model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
@@ -100,7 +99,7 @@ def setup_model(dirname): output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
- torch.cuda.empty_cache()
+ devices.torch_gc()
except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -123,9 +122,6 @@ def setup_model(dirname): return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
diff --git a/modules/devices.py b/modules/devices.py index 1ed6ffdc..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,13 +15,6 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - - return None - def get_cuda_device_string(): from modules import shared @@ -56,11 +49,15 @@ def get_device_for(task): def torch_gc(): + if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() + if has_mps(): + mac_specific.torch_mps_gc() + def enable_tf32(): if torch.cuda.is_available(): diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,15 +1,13 @@ -import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 1f093df2..41799b0a 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -103,6 +103,9 @@ def activate(p, extra_network_data): except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
+ if p.scripts is not None:
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
diff --git a/modules/extras.py b/modules/extras.py index 830b53aa..e9c0263e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -73,8 +73,7 @@ def to_half(tensor, enable): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index dd30a1b5..a3448be9 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -174,31 +174,6 @@ def send_image_and_dimensions(x): return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res
-settings_map = {}
-
-
-
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6ecd295c..8e0f13bd 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5d12b449..79670b87 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -3,6 +3,7 @@ import glob import html
import os
import inspect
+from contextlib import closing
import modules.textual_inversion.dataset
import torch
@@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None): shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -446,18 +436,6 @@ def statistics(data): return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/images.py b/modules/images.py index 7bbfc3e0..4bdedb7f 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,3 +1,5 @@ +from __future__ import annotations
+
import datetime
import pytz
@@ -10,7 +12,7 @@ import re import numpy as np
import piexif
import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib
@@ -139,6 +141,11 @@ class GridAnnotation: def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): fnt = get_font(fontsize)
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), "white")
+ calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows):
for col in range(cols):
@@ -302,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ if fill_height > 0:
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ if fill_width > 0:
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
return res
@@ -372,8 +378,8 @@ class FilenameGenerator: 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
-
}
default_time_format = '%Y%m%d%H%M%S'
@@ -497,13 +503,23 @@ def get_next_sequence_number(path, basename): return result + 1
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
if extension is None:
extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items():
@@ -622,7 +638,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """
temp_file_path = f"{filename_without_extension}.tmp"
- save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension)
@@ -639,12 +655,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except Exception:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -662,8 +684,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i return fullfn, txt_fullfn
-def read_info_from_image(image):
- items = image.info or {}
+IGNORED_INFO_KEYS = {
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+ items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
@@ -679,9 +708,7 @@ def read_info_from_image(image): items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
- 'icc_profile', 'chromaticity']:
+ for field in IGNORED_INFO_KEYS:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
diff --git a/modules/img2img.py b/modules/img2img.py index 2c497020..664e2688 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,23 +1,26 @@ import os
+from contextlib import closing
from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
+from modules.images import save_image
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)
- images = shared.listfiles(input_dir)
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
is_inpaint_batch = False
if inpaint_mask_dir:
@@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal state.job_count = len(images) * p.n_iter
+ # extract "default" params to use in case getting png info fails
+ prompt = p.prompt
+ negative_prompt = p.negative_prompt
+ seed = p.seed
+ cfg_scale = p.cfg_scale
+ sampler_name = p.sampler_name
+ steps = p.steps
+
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
+ if use_png_info:
+ try:
+ info_img = img
+ if png_info_dir:
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+ info_img = Image.open(info_img_path)
+ geninfo, _ = imgutil.read_info_from_image(info_img)
+ parsed_parameters = parse_generation_parameters(geninfo)
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+ except Exception:
+ parsed_parameters = {}
+
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+ p.seed = int(parsed_parameters.get("Seed", seed))
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+ p.steps = int(parsed_parameters.get("Steps", steps))
+
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = image_path.name
+ filename = image_path.stem
+ infotext = proc.infotext(p, n)
+ relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0:
- left, right = os.path.splitext(filename)
- filename = f"{left}-{n}{right}"
+ filename += f"-{n}"
if not save_normally:
- os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
if processed_image.mode == 'RGBA':
processed_image = processed_image.convert("RGB")
- processed_image.save(os.path.join(output_dir, filename))
+ save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -180,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.scripts = modules.scripts.scripts_img2img
p.script_args = args
+ p.user = request.username
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
- if is_batch:
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ with closing(p):
+ if is_batch:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
-
- processed = Processed(p, [], p.seed, "")
- else:
- processed = modules.scripts.scripts_img2img.run(p, *args)
- if processed is None:
- processed = process_images(p)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
- p.close()
+ processed = Processed(p, [], p.seed, "")
+ else:
+ processed = modules.scripts.scripts_img2img.run(p, *args)
+ if processed is None:
+ processed = process_images(p)
shared.total_tqdm.clear()
diff --git a/modules/interrogate.py b/modules/interrogate.py index 9b2c5b60..a3ae1dd5 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -184,8 +184,7 @@ class InterrogateModels: def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 609a181e..0e0dbca4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None): if commithash is None:
return
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d74c6b95..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,20 +1,43 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger(__name__) + -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +has_mps = check_for_mps() + + +def torch_mps_gc() -> None: try: - torch.zeros(1).to(torch.device("mps")) - return True + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return + from torch.mps import empty_cache + empty_cache() except Exception: - return False -has_mps = check_for_mps() + log.warning("MPS garbage collection failed", exc_info=True) # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 diff --git a/modules/modelloader.py b/modules/modelloader.py index 75f01247..098bcb79 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) @@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) diff --git a/modules/paths.py b/modules/paths.py index 5171df4f..bada804e 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs: else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..136e9c88 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -9,8 +9,7 @@ from modules.shared import opts def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, for image, name in zip(image_data, image_names):
shared.state.textinfo = name
- existing_pnginfo = image.info or {}
+ parameters, existing_pnginfo = images.read_info_from_image(image)
+ if parameters:
+ existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
diff --git a/modules/processing.py b/modules/processing.py index 8da73884..cd568a20 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -184,6 +184,8 @@ class StableDiffusionProcessing: self.uc = None
self.c = None
+ self.user = None
+
@property
def sd_model(self):
return shared.sd_model
@@ -549,7 +551,7 @@ def program_version(): return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -602,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
- if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
@@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)
@@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+ if self.scripts is not None:
+ self.scripts.before_hr(self)
+
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2d27b321..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,7 +2,6 @@ import os import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..7d9dd59f 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,6 +1,7 @@ import os
import re
import sys
+import inspect
from collections import namedtuple
import gradio as gr
@@ -116,6 +117,21 @@ class Script: pass
+ def after_extra_networks_activate(self, p, *args, **kwargs):
+ """
+ Calledafter extra networks activation, before conds calculation
+ allow modification of the network after extra networks activation been applied
+ won't be call if p.disable_extra_networks
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ - extra_network_data - list of ExtraNetworkParams for current stage
+ """
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -186,6 +202,11 @@ class Script: return f'script_{tabname}{title}_{item_id}'
+ def before_hr(self, p, *args):
+ """
+ This function is called before hires fix start.
+ """
+ pass
current_basedir = paths.script_path
@@ -249,7 +270,7 @@ def load_scripts(): def register_scripts_from_module(module):
for script_class in module.__dict__.values():
- if type(script_class) != type:
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
@@ -483,6 +504,14 @@ class ScriptRunner: except Exception:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
+ def after_extra_networks_activate(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
@@ -548,6 +577,15 @@ class ScriptRunner: self.scripts[si].args_to = args_to
+ def before_hr(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_hr(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ff5d17d..060e0007 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -23,7 +23,8 @@ model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {}
-checkpoint_alisases = {}
+checkpoint_aliases = {}
+checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
@@ -66,7 +67,7 @@ class CheckpointInfo: def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
- checkpoint_alisases[id] = self
+ checkpoint_aliases[id] = self
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@@ -112,7 +113,7 @@ def checkpoint_tiles(): def list_models():
checkpoints_list.clear()
- checkpoint_alisases.clear()
+ checkpoint_aliases.clear()
cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@@ -136,7 +137,7 @@ def list_models(): def get_closet_checkpoint_match(search_string):
- checkpoint_info = checkpoint_alisases.get(search_string, None)
+ checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -166,7 +167,7 @@ def select_checkpoint(): """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None): sd_model = None
gc.collect()
devices.torch_gc()
- torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")
diff --git a/modules/shared.py b/modules/shared.py index a0862055..48478a68 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,9 +1,11 @@ import datetime
import json
import os
+import re
import sys
import threading
import time
+import logging
import gradio as gr
import torch
@@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional
+log = logging.getLogger(__name__)
+
demo = None
parser = cmd_args.parser
@@ -144,12 +148,15 @@ class State: def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
+ log.info("Received restart request")
def skip(self):
self.skipped = True
+ log.info("Received skip request")
def interrupt(self):
self.interrupted = True
+ log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +180,7 @@ class State: return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,10 +194,13 @@ class State: self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
+ log.info("Starting job %s", job)
def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
@@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "font": OptionInfo("", "Font for image grids that have text"),
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -470,7 +485,6 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@@ -481,6 +495,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@@ -493,6 +508,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@@ -817,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) mem_mon.start()
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
@@ -843,8 +863,11 @@ def walk_files(path, allowed_extensions=None): if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, _, files in os.walk(path, followlinks=True):
- for filename in files:
+ items = list(os.walk(path, followlinks=True))
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+ for root, _, files in items:
+ for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 734a4b6f..45823eb1 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,11 +2,51 @@ import datetime import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+ "batch_size",
+ "clip_grad_mode",
+ "clip_grad_value",
+ "create_image_every",
+ "data_root",
+ "gradient_step",
+ "initial_step",
+ "latent_sampling_method",
+ "learn_rate",
+ "log_directory",
+ "model_hash",
+ "model_name",
+ "num_of_dataset_images",
+ "steps",
+ "template_file",
+ "training_height",
+ "training_width",
+}
+saved_params_ti = {
+ "embedding_name",
+ "num_vectors_per_token",
+ "save_embedding_every",
+ "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+ "activation_func",
+ "add_layer_norm",
+ "hypernetwork_name",
+ "layer_structure",
+ "save_hypernetwork_every",
+ "use_dropout",
+ "weight_init",
+}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+ "preview_cfg_scale",
+ "preview_height",
+ "preview_negative_prompt",
+ "preview_prompt",
+ "preview_sampler_index",
+ "preview_seed",
+ "preview_steps",
+ "preview_width",
+}
def save_settings_to_file(log_directory, all_params):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 0d4c3f84..dbd856bd 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb6f211c..cbe975b7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,5 +1,6 @@ import os
from collections import namedtuple
+from contextlib import closing
import torch
import tqdm
@@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
diff --git a/modules/txt2img.py b/modules/txt2img.py index 2e7d202d..d0be2e73 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,13 +1,15 @@ +from contextlib import closing
+
import modules.scripts
from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.ui import plaintext_to_html
+import gradio as gr
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
+ p.user = request.username
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
- processed = modules.scripts.scripts_txt2img.run(p, *args)
-
- if processed is None:
- processed = processing.process_images(p)
+ with closing(p):
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
- p.close()
+ if processed is None:
+ processed = processing.process_images(p)
shared.total_tqdm.clear()
diff --git a/modules/ui.py b/modules/ui.py index e2e3b6da..39d226ad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None]
@@ -733,6 +733,10 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
@@ -773,7 +777,7 @@ def create_ui(): selected_scale_tab = gr.State(value=0)
with gr.Tabs():
- with gr.Tab(label="Resize to") as tab_scale_to:
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +786,7 @@ def create_ui(): res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
- with gr.Tab(label="Resize by") as tab_scale_by:
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
@@ -934,6 +938,9 @@ def create_ui(): img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
+ img2img_batch_use_png_info,
+ img2img_batch_png_info_props,
+ img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index c7e0a866..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -138,7 +138,10 @@ def extension_table(): <table id="extensions">
<thead>
<tr>
- <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+ </th>
<th>URL</th>
<th>Branch</th>
<th>Version</th>
@@ -170,7 +173,7 @@ def extension_table(): code += f"""
<tr>
- <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
@@ -421,9 +424,19 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
+ (True, lambda x: x.get('commit_time', '')),
+ (True, lambda x: x.get('created_at', '')),
+ (True, lambda x: x.get('stars', 0)),
]
+def get_date(info: dict, key):
+ try:
+ return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+ except (ValueError, TypeError):
+ return ''
+
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ stars = int(ext.get("stars", 0))
added = ext.get('added', 'unknown')
+ update_time = get_date(ext, 'commit_time')
+ create_time = get_date(ext, 'created_at')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+ <td>{html.escape(description)}<p class="info">
+ <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
<td>{install_code}</td>
</tr>
@@ -559,7 +576,7 @@ def create_ui(): with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
- sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
@@ -568,9 +585,9 @@ def create_ui(): available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
)
install_extension_button.click(
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a7d3bc79..693cafb6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -30,8 +30,8 @@ def fetch_file(filename: str = ""): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".jpeg", ".webp"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+ if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -90,8 +90,8 @@ class ExtraNetworksPage: subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, _ in os.walk(parentdir, followlinks=True):
- for dirname in dirs:
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname)
if not os.path.isdir(x):
diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 0c560b30..a6076bf3 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -260,13 +260,20 @@ class UiSettings: component = self.component_dict[k]
info = opts.data_labels[k]
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: self.run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, self.text_settings],
- show_progress=info.refresh is not None,
- )
+ if isinstance(component, gr.Textbox):
+ methods = [component.submit, component.blur]
+ elif hasattr(component, 'release'):
+ methods = [component.release]
+ else:
+ methods = [component.change]
+
+ for method in methods:
+ method(
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, self.text_settings],
+ show_progress=info.refresh is not None,
+ )
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
|