136 files changed, 3007 insertions, 1256 deletions
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index a168be5b..d42965b1 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -18,22 +18,29 @@ jobs: steps: - name: Checkout Code uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: - python-version: 3.10.6 - cache: pip - cache-dependency-path: | - **/requirements*txt - - name: Install PyLint - run: | - python -m pip install --upgrade pip - pip install pylint - # This lets PyLint check to see if it can resolve imports - - name: Install dependencies - run: | - export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" - python launch.py - - name: Analysing the code with pylint - run: | - pylint $(git ls-files '*.py') + python-version: 3.11 + # NB: there's no cache: pip here since we're not installing anything + # from the requirements.txt file(s) in the repository; it's faster + # not to have GHA download an (at the time of writing) 4 GB cache + # of PyTorch and other dependencies. + - name: Install Ruff + run: pip install ruff==0.0.265 + - name: Run Ruff + run: ruff . + +# The rest are currently disabled pending fixing of e.g. installing the torch dependency. + +# - name: Install PyLint +# run: | +# python -m pip install --upgrade pip +# pip install pylint +# # This lets PyLint check to see if it can resolve imports +# - name: Install dependencies +# run: | +# export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" +# python launch.py +# - name: Analysing the code with pylint +# run: | +# pylint $(git ls-files '*.py') @@ -32,4 +32,5 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt -/cache.json +/cache.json* +/config_states/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..cf3fef3d --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,106 @@ +## Upcoming 1.2.0
+
+### Features:
+ * do not wait for stable diffusion model to load at startup
+ * add filename patterns: [denoising]
+ * directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
+ * Lora: for the `<...>` text in prompt, use the name of the Lora that is in the metadata of the file, if present, instead of the filename (both can be used to activate the lora)
+ * Lora: read infotext params from kohya-ss's extension parameters if they are present and if that extension is not active
+ * Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
+ * Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
+ * add version to infotext, footer and console output when starting
+ * add links to wiki for filename pattern settings
+ * add extended info for quicksettings setting and use multiselect input instead of a text field
+
+### Minor:
+ * gradio bumped to 3.29.0
+ * torch bumped to 2.0.1
+ * --subpath option for gradio for use with reverse proxy
+ * linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
+ * possible frontend optimization: do not apply localizations if there are none
+ * Add extra `None` option for VAE in XYZ plot
+ * print error to console when batch processing in img2img fails
+ * create HTML for extra network pages only on demand
+ * allow directories starting with . to still list their models for lora, checkpoints, etc
+ * put infotext options into their own category in settings tab
+ * do not show licenses page when user selects Show all pages in settings
+
+### Extensions:
+ * Tooltip localization support
+ * Add api method to get LoRA models with prompt
+
+### Bug Fixes:
+ * re-add /docs endpoint
+ * fix gamepad navigation
+ * make the lightbox fullscreen image function properly
+ * fix squished thumbnails in extras tab
+ * keep "search" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed)
+ * fix webui showing the same image if you configure the generation to always save results into same file
+ * fix bug with upscalers not working properly
+ * Fix MPS on PyTorch 2.0.1, Intel Macs
+ * make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
+ * prevent Reload UI button/link from reloading the page when it's not yet ready
+
+
+## 1.1.1
+### Bug Fixes:
+ * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
+
+## 1.1.0
+### Features:
+ * switch to torch 2.0.0 (except for AMD GPUs)
+ * visual improvements to custom code scripts
+ * add filename patterns: [clip_skip], [hasprompt<>], [batch_number], [generation_number]
+ * add support for saving init images in img2img, and record their hashes in infotext for reproducibility
+ * automatically select current word when adjusting weight with ctrl+up/down
+ * add dropdowns for X/Y/Z plot
+ * setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
+ * support Gradio's theme API
+ * use TCMalloc on Linux by default; possible fix for memory leaks
+ * (optimization) option to remove negative conditioning at low sigma values #9177
+ * embed model merge metadata in .safetensors file
+ * extension settings backup/restore feature #9169
+ * add "resize by" and "resize to" tabs to img2img
+ * add option "keep original size" to textual inversion images preprocess
+ * image viewer scrolling via analog stick
+ * button to restore the progress from session lost / tab reload
+
+### Minor:
+ * gradio bumped to 3.28.1
+ * in extra tab, change extras "scale to" to sliders
+ * add labels to tool buttons to make it possible to hide them
+ * add tiled inference support for ScuNET
+ * add branch support for extension installation
+ * change linux installation script to install into current directory rather than /home/username
+ * sort textual inversion embeddings by name (case insensitive)
+ * allow styles.csv to be symlinked or mounted in docker
+ * remove the "do not add watermark to images" option
+ * make selected tab configurable with UI config
+ * extra networks UI is now fixed height and scrollable
+ * add disable_tls_verify arg for use with self-signed certs
+
+### Extensions:
+ * Add reload callback
+ * add is_hr_pass field for processing
+
+### Bug Fixes:
+ * fix broken batch image processing on 'Extras/Batch Process' tab
+ * add "None" option to extra networks dropdowns
+ * fix FileExistsError for CLIP Interrogator
+ * fix /sdapi/v1/txt2img endpoint not working on Linux #9319
+ * fix disappearing live previews and progressbar during slow tasks
+ * fix fullscreen image view not working properly in some cases
+ * prevent alwayson_scripts args param from resizing the script_arg list when they are inserted into it
+ * fix prompt schedule for second order samplers
+ * fix image mask/composite for weird resolutions #9628
+ * use correct images for previews when using AND (see #9491)
+ * one broken image in img2img batch won't stop all processing
+ * fix image orientation bug in train/preprocess
+ * fix Ngrok recreating tunnels every reload
+ * fix --realesrgan-models-path and --ldsr-models-path not working
+ * fix --skip-install not working
+ * outpainting Mk2 & Poorman should use the SAMPLE file format to save images, not GRID file format
+ * do not fail all Loras if some have failed to load when making a picture
+
+## 1.0.0
+ * everything
@@ -100,7 +100,7 @@ Alternatively, use online services (like Google Colab): - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Automatic Installation on Windows
-1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
+1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (newer versions of Python do not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
@@ -115,11 +115,12 @@ sudo dnf install wget git python3 # Arch-based:
sudo pacman -S wget git python3
```
-2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
+2. Navigate to the directory you would like the webui to be installed in and execute the following command:
```bash
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
```
3. Run `webui.sh`.
+4. Check `webui-user.sh` for options.
### Installation on Apple Silicon
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
@@ -158,4 +159,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
-- (You)
\ No newline at end of file +- (You)
diff --git a/environment-wsl2.yaml b/environment-wsl2.yaml index f8872750..0c4ae680 100644 --- a/environment-wsl2.yaml +++ b/environment-wsl2.yaml @@ -4,8 +4,8 @@ channels: - defaults dependencies: - python=3.10 - - pip=22.2.2 - - cudatoolkit=11.3 - - pytorch=1.12.1 - - torchvision=0.13.1 - - numpy=1.23.1
\ No newline at end of file + - pip=23.0 + - cudatoolkit=11.8 + - pytorch=2.0 + - torchvision=0.15 + - numpy=1.23 diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index bc11cc6e..2173de79 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -88,7 +88,7 @@ class LDSR: x_t = None logs = None - for n in range(n_runs): + for _ in range(n_runs): if custom_shape is not None: x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) @@ -110,7 +110,6 @@ class LDSR: diffusion_steps = int(steps) eta = 1.0 - down_sample_method = 'Lanczos' gc.collect() if torch.cuda.is_available: @@ -158,7 +157,7 @@ class LDSR: def get_cond(selected_path): - example = dict() + example = {} up_f = 4 c = selected_path.convert('RGB') c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) @@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s @torch.no_grad() def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() + log = {} z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, @@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: + except Exception: pass log["sample"] = x_sample diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index b8cff29b..fbbe9005 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks -import sd_hijack_autoencoder, sd_hijack_ddpm_v1 +import sd_hijack_autoencoder # noqa: F401 +import sd_hijack_ddpm_v1 # noqa: F401 class UpscalerLDSR(Upscaler): @@ -25,22 +26,28 @@ class UpscalerLDSR(Upscaler): yaml_path = os.path.join(self.model_path, "project.yaml") old_model_path = os.path.join(self.model_path, "model.pth") new_model_path = os.path.join(self.model_path, "model.ckpt") - safetensors_model_path = os.path.join(self.model_path, "model.safetensors") + + local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"]) + local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None) + local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None) + local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None) + if os.path.exists(yaml_path): statinfo = os.stat(yaml_path) if statinfo.st_size >= 10485760: print("Removing invalid LDSR YAML file.") os.remove(yaml_path) + if os.path.exists(old_model_path): print("Renaming model from model.pth to model.ckpt") os.rename(old_model_path, new_model_path) - if os.path.exists(safetensors_model_path): - model = safetensors_model_path + + if 
local_safetensors_path is not None and os.path.exists(local_safetensors_path): + model = local_safetensors_path else: - model = load_file_from_url(url=self.model_url, model_dir=self.model_path, - file_name="model.ckpt", progress=True) - yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, - file_name="project.yaml", progress=True) + model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True) + + yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True) try: return LDSR(model, yaml) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8e03c7f8..81c5101b 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -76,11 +81,11 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -165,7 +170,7 @@ class VQModel(pl.LightningModule): def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): @@ -232,7 +237,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -249,7 +254,8 @@ 
class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -264,7 +270,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): @@ -282,5 +288,5 @@ class VQModelInterface(VQModel): dec = self.decoder(quant) return dec -setattr(ldm.models.autoencoder, "VQModel", VQModel) -setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) +ldm.models.autoencoder.VQModel = VQModel +ldm.models.autoencoder.VQModelInterface = VQModelInterface diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 5c0488e5..57c02d12 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) @@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, 
*args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, @@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): use_ddim = ddim_steps is not None - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. 
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] @@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): logs['bbox_image'] = cond_img return logs -setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) -setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) -setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) -setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) +ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1 +ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1 +ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1 +ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1 diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 6be6ef73..ccb249ac 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,6 +1,7 @@ from modules import extra_networks, shared
import lora
+
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
@@ -8,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
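As an aside, a minimal sketch (hypothetical option values, not webui code) of what activate() now appends to each prompt when the option is set to anything other than the new "None" sentinel:

```python
additional = "myLora"    # hypothetical value of shared.opts.sd_lora ("None" disables this)
multiplier = 0.9         # hypothetical shared.opts.extra_networks_default_multiplier
all_prompts = ["a photo of a cat", "a photo of a dog"]

# mirrors the activate() logic: inject the default Lora into every prompt
if additional != "None":
    all_prompts = [p + f"<lora:{additional}:{multiplier}>" for p in all_prompts]
print(all_prompts[0])    # a photo of a cat<lora:myLora:0.9>
```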
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index d3eb0d3b..7b56136f 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -1,10 +1,9 @@ -import glob
import os
import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -93,6 +92,7 @@ class LoraOnDisk: self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
+ self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule:
@@ -165,12 +165,14 @@ def load_lora(name, filename): module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.Conv2d:
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
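A simplified, weight-shape-only sketch of the dispatch above (the real code also checks the type of the matched sd module; the shapes here are hypothetical): the kernel size of a Lora weight decides which torch module holds it.

```python
import torch

def module_for(weight: torch.Tensor) -> torch.nn.Module:
    # Linear and attention weights are 2-D; conv weights carry a kernel size
    if weight.dim() == 2:
        return torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
    if weight.shape[2:] == (1, 1):
        return torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
    if weight.shape[2:] == (3, 3):  # the newly supported case
        return torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
    raise AssertionError(f"unsupported Lora weight shape {tuple(weight.shape)}")

down = torch.randn(4, 320, 3, 3)         # hypothetical rank-4 factor for a 3x3 conv
print(type(module_for(down)).__name__)   # Conv2d
```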
@@ -182,7 +184,7 @@ def load_lora(name, filename): elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -199,11 +201,11 @@ def load_loras(names, multipliers=None): loaded_loras.clear()
- loras_on_disk = [available_loras.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+ if any(x is None for x in loras_on_disk):
list_available_loras()
- loras_on_disk = [available_loras.get(name, None) for name in names]
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
@@ -211,7 +213,11 @@ def load_loras(names, multipliers=None): lora_on_disk = loras_on_disk[i]
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
- lora = load_lora(name, lora_on_disk.filename)
+ try:
+ lora = load_lora(name, lora_on_disk.filename)
+ except Exception as e:
+ errors.display(e, f"loading Lora {lora_on_disk.filename}")
+ continue
if lora is None:
print(f"Couldn't find Lora with name {name}")
@@ -228,6 +234,8 @@ def lora_calc_updown(lora, module, target): if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
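For reference, the new 3x3 branch is the low-rank update contracted into a single convolution weight: `conv2d` sums `up` and `down` over the rank dimension. A minimal sketch with made-up sizes (only the one-liner itself comes from the diff):

```python
import torch

out_ch, in_ch, rank = 8, 4, 2                  # hypothetical layer sizes
up = torch.randn(out_ch, rank, 1, 1)           # lora_up.weight
down = torch.randn(rank, in_ch, 3, 3)          # lora_down.weight

# treat `down` as a batch of rank-channel maps and `up` as a 1x1 kernel,
# then permute back into Conv2d weight layout (out_ch, in_ch, 3, 3)
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

# equivalent to an explicit contraction over the rank dimension
reference = torch.einsum("or,rihw->oihw", up.squeeze(-1).squeeze(-1), down)
assert torch.allclose(updown, reference, atol=1e-6)
```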
@@ -236,6 +244,19 @@ def lora_calc_updown(lora, module, target): return updown
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ weights_backup = getattr(self, "lora_weights_backup", None)
+
+ if weights_backup is None:
+ return
+
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
@@ -260,12 +281,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu self.lora_weights_backup = weights_backup
if current_names != wanted_names:
- if weights_backup is not None:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
- else:
- self.weight.copy_(weights_backup)
+ lora_restore_weights_from_backup(self)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
@@ -293,15 +309,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
+
+
+def lora_forward(module, input, original_forward):
+ """
+ Old way of applying Lora by executing operations during layer's forward.
+ Stacking many loras this way results in big performance degradation.
+ """
+
+ if len(loaded_loras) == 0:
+ return original_forward(module, input)
+
+ input = devices.cond_cast_unet(input)
+
+ lora_restore_weights_from_backup(module)
+ lora_reset_cached_weight(module)
+
+ res = original_forward(module, input)
+
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is None:
+ continue
+
+ module.up.to(device=devices.device)
+ module.down.to(device=devices.device)
+
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return res
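The docstring's performance note is easiest to see side by side: merged mode (lora_apply_weights) pays the up @ down cost once per weight change, while functional mode recomputes the low-rank path on every forward. A toy single-layer sketch, with hypothetical sizes and none of the webui plumbing:

```python
import torch

layer = torch.nn.Linear(4, 4)
up, down = torch.randn(4, 2), torch.randn(2, 4)  # rank-2 Lora factors
x = torch.randn(1, 4)

# merged: bake the update into the weight once
merged = torch.nn.Linear(4, 4)
merged.load_state_dict(layer.state_dict())
with torch.no_grad():
    merged.weight += up @ down

# functional: extra matmuls on every call
functional = layer(x) + x @ down.t() @ up.t()
assert torch.allclose(merged(x), functional, atol=1e-5)
```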
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
@@ -314,6 +363,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs): def lora_Conv2d_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
@@ -339,24 +391,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): def list_available_loras():
available_loras.clear()
+ available_lora_aliases.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
- candidates = \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
-
+ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
+ entry = LoraOnDisk(name, filename)
+
+ available_loras[name] = entry
+
+ available_lora_aliases[name] = entry
+ available_lora_aliases[entry.alias] = entry
+
+
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+
+ added = []
+
+ for k in params:
+ if not k.startswith("AddNet Model "):
+ continue
+
+ num = k[13:]
+
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+
+ m = re_lora_name.match(name)
+ if m:
+ name = m.group(1)
+
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
- available_loras[name] = LoraOnDisk(name, filename)
+ added.append(f"<lora:{name}:{multiplier}>")
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
available_loras = {}
+available_lora_aliases = {}
loaded_loras = []
list_available_loras()
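A hedged usage sketch of the infotext_pasted hook above, with hypothetical AddNet parameter values (the "name(hash)" form is what re_lora_name strips):

```python
import re

re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

params = {
    "Prompt": "a photo of a cat",
    "AddNet Module 1": "LoRA",
    "AddNet Model 1": "myLora(ab12cd34)",   # hypothetical model name plus hash
    "AddNet Weight A 1": "0.8",
}

added = []
for k in list(params):
    if not k.startswith("AddNet Model "):
        continue
    num = k[13:]
    if params.get("AddNet Module " + num) != "LoRA":
        continue
    name = params[k]
    m = re_lora_name.match(name)
    if m:
        name = m.group(1)                   # drop the trailing "(hash)"
    added.append(f"<lora:{name}:{params.get('AddNet Weight A ' + num, '1.0')}>")

params["Prompt"] += "\n" + "".join(added)
print(params["Prompt"])                     # a photo of a cat\n<lora:myLora:0.8>
```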
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 0adab225..13d297d7 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,12 +1,12 @@ import torch
import gradio as gr
+from fastapi import FastAPI
import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
-
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@@ -49,8 +49,33 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+}))
+
+
+shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
+ "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
+
+
+def create_lora_json(obj: lora.LoraOnDisk):
+ return {
+ "name": obj.name,
+ "alias": obj.alias,
+ "path": obj.filename,
+ "metadata": obj.metadata,
+ }
+
+
+def api_loras(_: gr.Blocks, app: FastAPI):
+ @app.get("/sdapi/v1/loras")
+ async def get_loras():
+ return [create_lora_json(obj) for obj in lora.available_loras.values()]
+
+
+script_callbacks.on_app_started(api_loras)
+
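Usage sketch for the new endpoint, assuming a local instance on Gradio's default port 7860; each entry mirrors create_lora_json above:

```python
import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/loras")
resp.raise_for_status()
for entry in resp.json():
    print(entry["name"], entry["alias"], entry["path"])
```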
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 68b11332..a0edbc1e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
- "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+ "prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index e0fbf3a3..1f5ea0d3 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -5,11 +5,14 @@ import traceback import PIL.Image import numpy as np import torch +from tqdm import tqdm + from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader from scunet_model_arch import SCUNet as net +from modules.shared import opts class UpscalerScuNET(modules.upscaler.Upscaler): @@ -42,28 +45,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - def do_upscale(self, img: PIL.Image, selected_file): + @staticmethod + @torch.no_grad() + def tiled_inference(img, model): + # test the image tile by tile + h, w = img.shape[2:] + tile = opts.SCUNET_tile + tile_overlap = opts.SCUNET_tile_overlap + if tile == 0: + return model(img) + + device = devices.get_device_for('scunet') + assert tile % 8 == 0, "tile size should be a multiple of window_size" + sf = 1 + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) + + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: + for h_idx in h_idx_list: + + for w_idx in w_idx_list: + + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) + output = E.div_(W) + + return output + + def do_upscale(self, img: PIL.Image.Image, selected_file): + torch.cuda.empty_cache() model = self.load_model(selected_file) if model is None: + print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) return img device = devices.get_device_for('scunet') - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) - - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. 
* np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] + tile = opts.SCUNET_tile + h, w = img.height, img.width + np_img = np.array(img) + np_img = np_img[:, :, ::-1] # RGB to BGR + np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW + torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore + + if tile > h or tile > w: + _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) + _img[:, :, :h, :w] = torch_img # pad image + torch_img = _img + + torch_output = self.tiled_inference(torch_img, model).squeeze(0) + torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any + np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() + del torch_img, torch_output torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') + + output = np_output.transpose((1, 2, 0)) # CHW to HWC + output = output[:, :, ::-1] # BGR to RGB + return PIL.Image.fromarray((output * 255).astype(np.uint8)) def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -79,9 +132,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler): model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() - for k, v in model.named_parameters(): + for _, v in model.named_parameters(): v.requires_grad = False model = model.to(device) return model - diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 43ca8d36..8028918a 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -61,7 +61,9 @@ class WMSA(nn.Module): Returns: output: tensor shape [b h w c] """ - if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + if self.type != 'W': + x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) @@ -85,8 +87,9 @@ class WMSA(nn.Module): output = self.linear(output) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), - dims=(1, 2)) + if self.type != 'W': + output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) + return output def relative_embedding(self): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index e8783bca..55dd94ab 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,3 @@ -import contextlib import os import numpy as np @@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts, state +from modules.shared import opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler): img = upscale(img, model) try: torch.cuda.empty_cache() - except: + except Exception: pass return img diff --git 
a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..73e37cfa 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,7 +644,7 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, @@ -844,7 +844,7 @@ class SwinIR(nn.Module): H, W = self.patches_resolution flops += H * W * 3 * self.embed_dim * 9 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() flops += H * W * 3 * self.embed_dim * self.embed_dim flops += self.upsample.flops() diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..3ca9be78 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,7 +74,7 @@ class WindowAttention(nn.Module): """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -698,7 +698,7 @@ class Swin2SR(nn.Module): """
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module): H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
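Looping back to the ScuNET change earlier in this diff: tiled_inference is the usual overlap-and-average stitch. A toy 1-D sketch (assumed tile/overlap values, a stand-in model) of why dividing the output accumulator E by the coverage count W removes seams:

```python
import torch

length, tile, overlap = 10, 4, 2
x = torch.arange(float(length))
model = lambda t: 2 * t                     # stand-in for the network

stride = tile - overlap
idx = list(range(0, length - tile, stride)) + [length - tile]  # [0, 2, 4, 6]
E = torch.zeros(length)                     # accumulated outputs
W = torch.zeros(length)                     # how many tiles hit each position
for i in idx:
    E[i:i + tile] += model(x[i:i + tile])
    W[i:i + tile] += 1
out = E / W                                 # overlaps are averaged, not seamed
assert torch.allclose(out, 2 * x)
```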
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index f0918e26..5c7a836a 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -1,103 +1,42 @@ // Stable Diffusion WebUI - Bracket checker -// Version 1.0 -// By Hingashi no Florin/Bwin4L +// By Hingashi no Florin/Bwin4L & @akx // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. -function checkBrackets(evt, textArea, counterElt) { - errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n'; - errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n'; - errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n'; - - openBracketRegExp = /\(/g; - closeBracketRegExp = /\)/g; - - openSquareBracketRegExp = /\[/g; - closeSquareBracketRegExp = /\]/g; - - openCurlyBracketRegExp = /\{/g; - closeCurlyBracketRegExp = /\}/g; - - totalOpenBracketMatches = 0; - totalCloseBracketMatches = 0; - totalOpenSquareBracketMatches = 0; - totalCloseSquareBracketMatches = 0; - totalOpenCurlyBracketMatches = 0; - totalCloseCurlyBracketMatches = 0; - - openBracketMatches = textArea.value.match(openBracketRegExp); - if(openBracketMatches) { - totalOpenBracketMatches = openBracketMatches.length; - } - - closeBracketMatches = textArea.value.match(closeBracketRegExp); - if(closeBracketMatches) { - totalCloseBracketMatches = closeBracketMatches.length; - } - - openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp); - if(openSquareBracketMatches) { - totalOpenSquareBracketMatches = openSquareBracketMatches.length; - } - - closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp); - if(closeSquareBracketMatches) { - totalCloseSquareBracketMatches = closeSquareBracketMatches.length; - } - - openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp); - if(openCurlyBracketMatches) { - totalOpenCurlyBracketMatches = openCurlyBracketMatches.length; - } - - closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp); - if(closeCurlyBracketMatches) { - totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length; - } - - if(totalOpenBracketMatches != totalCloseBracketMatches) { - if(!counterElt.title.includes(errorStringParen)) { - counterElt.title += errorStringParen; - } - } else { - counterElt.title = counterElt.title.replace(errorStringParen, ''); - } - - if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) { - if(!counterElt.title.includes(errorStringSquare)) { - counterElt.title += errorStringSquare; - } - } else { - counterElt.title = counterElt.title.replace(errorStringSquare, ''); - } - - if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) { - if(!counterElt.title.includes(errorStringCurly)) { - counterElt.title += errorStringCurly; +function checkBrackets(textArea, counterElt) { + var counts = {}; + (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => { + counts[bracket] = (counts[bracket] || 0) + 1; + }); + var errors = []; + + function checkPair(open, close, kind) { + if (counts[open] !== counts[close]) { + errors.push( + 
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` + ); } - } else { - counterElt.title = counterElt.title.replace(errorStringCurly, ''); } - if(counterElt.title != '') { - counterElt.classList.add('error'); - } else { - counterElt.classList.remove('error'); - } + checkPair('(', ')', 'round brackets'); + checkPair('[', ']', 'square brackets'); + checkPair('{', '}', 'curly brackets'); + counterElt.title = errors.join('\n'); + counterElt.classList.toggle('error', errors.length !== 0); } -function setupBracketChecking(id_prompt, id_counter){ - var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); - var counter = gradioApp().getElementById(id_counter) +function setupBracketChecking(id_prompt, id_counter) { + var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); + var counter = gradioApp().getElementById(id_counter) - textarea.addEventListener("input", function(evt){ - checkBrackets(evt, textarea, counter) - }); + if (textarea && counter) { + textarea.addEventListener("input", () => checkBrackets(textarea, counter)); + } } -onUiLoaded(function(){ - setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') - setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') - setupBracketChecking('img2img_prompt', 'img2img_token_counter') - setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') -})
\ No newline at end of file +onUiLoaded(function () { + setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); + setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); + setupBracketChecking('img2img_prompt', 'img2img_token_counter'); + setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); +}); diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index ef4b613a..1d546217 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -6,7 +6,7 @@ <ul> <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a> </ul> - <span style="display:none" class='search_term'>{search_term}</span> + <span style="display:none" class='search_term{serach_only}'>{search_term}</span> </div> <span class='name'>{name}</span> <span class='description'>{description}</span> diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index a8278cca..5160081d 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){ var viewportOffset = targetElement.getBoundingClientRect();
- viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
+ var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
- scaledx = targetElement.naturalWidth*viewportscale
- scaledy = targetElement.naturalHeight*viewportscale
+ var scaledx = targetElement.naturalWidth*viewportscale
+ var scaledy = targetElement.naturalHeight*viewportscale
- cleintRectTop = (viewportOffset.top+window.scrollY)
- cleintRectLeft = (viewportOffset.left+window.scrollX)
- cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
- cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
+ var cleintRectTop = (viewportOffset.top+window.scrollY)
+ var cleintRectLeft = (viewportOffset.left+window.scrollX)
+ var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
+ var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
- viewRectTop = cleintRectCentreY-(scaledy/2)
- viewRectLeft = cleintRectCentreX-(scaledx/2)
- arRectWidth = scaledx
- arRectHeight = scaledy
+ var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
+ var arscaledx = currentWidth*arscale
+ var arscaledy = currentHeight*arscale
- arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
- arscaledx = currentWidth*arscale
- arscaledy = currentHeight*arscale
-
- arRectTop = cleintRectCentreY-(arscaledy/2)
- arRectLeft = cleintRectCentreX-(arscaledx/2)
- arRectWidth = arscaledx
- arRectHeight = arscaledy
+ var arRectTop = cleintRectCentreY-(arscaledy/2)
+ var arRectLeft = cleintRectCentreX-(arscaledx/2)
+ var arRectWidth = arscaledx
+ var arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px';
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 06f505b0..b2bdf053 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -4,7 +4,7 @@ contextMenuInit = function(){ let menuSpecs = new Map();
const uid = function(){
- return Date.now().toString(36) + Math.random().toString(36).substr(2);
+ return Date.now().toString(36) + Math.random().toString(36).substring(2);
}
function showContextMenu(event,element,menuEntries){
@@ -16,8 +16,7 @@ contextMenuInit = function(){ oldMenu.remove()
}
- let tabButton = uiCurrentTab
- let baseStyle = window.getComputedStyle(tabButton)
+ let baseStyle = window.getComputedStyle(uiCurrentTab)
const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu"
@@ -36,7 +35,7 @@ contextMenuInit = function(){ menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name']
- contextMenuEntry.addEventListener("click", function(e) {
+ contextMenuEntry.addEventListener("click", function() {
entry['func']();
})
contextMenuList.append(contextMenuEntry);
@@ -63,7 +62,7 @@ contextMenuInit = function(){ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
- currentItems = menuSpecs.get(targetElementSelector)
+ var currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){
currentItems = []
@@ -79,7 +78,7 @@ contextMenuInit = function(){ }
function removeContextMenuOption(uid){
- menuSpecs.forEach(function(v,k) {
+ menuSpecs.forEach(function(v) {
let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){
@@ -93,8 +92,7 @@ contextMenuInit = function(){ return;
}
gradioApp().addEventListener("click", function(e) {
- let source = e.composedPath()[0]
- if(source.id && source.id.indexOf('check_progress')>-1){
+ if(! e.isTrusted){
return
}
@@ -112,7 +110,6 @@ contextMenuInit = function(){ if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v)
e.preventDefault()
- return
}
})
});
@@ -161,14 +158,6 @@ addContextMenuEventListener = initResponse[2]; appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#roll','Roll three',
- function(){
- let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
- setTimeout(function(){rollbutton.click()},100)
- setTimeout(function(){rollbutton.click()},200)
- setTimeout(function(){rollbutton.click()},300)
- }
- )
})();
//End example Context Menu Items
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 20a5aadf..d2c2f190 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -17,7 +17,7 @@ function keyupEditAttention(event){ // Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
- if (beforeParen == -1) return false;
+ if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@@ -27,7 +27,7 @@ function keyupEditAttention(event){ // Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
- if (afterParen == -1) return false;
+ if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1);
@@ -43,16 +43,34 @@ function keyupEditAttention(event){
 target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
+
+ function selectCurrentWord(){
+ if (selectionStart !== selectionEnd) return false;
+ const delimiters = opts.keyedit_delimiters + " \r\n\t";
+
+ // seek backward to find the beginning
+ while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+ selectionStart--;
+ }
+
+ // seek forward to find end
+ while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+ selectionEnd++;
+ }
- // If the user hasn't selected anything, let's select their current parenthesis block
- if(! selectCurrentParenthesisBlock('<', '>')){
- selectCurrentParenthesisBlock('(', ')')
+ target.setSelectionRange(selectionStart, selectionEnd);
+ return true;
+ }
+
+ // If the user hasn't selected anything, let's select their current parenthesis block or word
+ if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+ selectCurrentWord();
}
event.preventDefault();
- closeCharacter = ')'
- delta = opts.keyedit_precision_attention
+ var closeCharacter = ')'
+ var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>'
@@ -73,15 +91,21 @@ function keyupEditAttention(event){
 selectionEnd += 1;
}
- end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0"
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+ if (closeCharacter == ')' && weight == 1) {
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+ selectionStart--;
+ selectionEnd--;
+ } else {
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+ }
target.focus();
target.value = text;
@@ -93,4 +117,4 @@ function keyupEditAttention(event){
 addEventListener('keydown', (event) => {
keyupEditAttention(event);
-});
\ No newline at end of file
+});
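The selectCurrentWord addition above seeks outward from the cursor until it hits a delimiter on either side. The same rule as a standalone sketch (the delimiter string is an assumption; the webui builds it from opts.keyedit_delimiters plus whitespace):

function wordBoundsAt(text, cursor, delimiters){
    var start = cursor, end = cursor;
    // walk left until a delimiter or the start of the text
    while (start > 0 && !delimiters.includes(text[start - 1])) start--;
    // walk right until a delimiter or the end of the text
    while (end < text.length && !delimiters.includes(text[end])) end++;
    return [start, end];
}
// wordBoundsAt("masterpiece, best quality", 3, ", ") returns [0, 11]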
diff --git a/javascript/extensions.js b/javascript/extensions.js index 72924a28..2a2d2f8e 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js
@@ -1,14 +1,14 @@
-function extensions_apply(_, _, disable_all){
+function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = []
var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
+ disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked)
- update.push(x.name.substr(7))
+ update.push(x.name.substring(7))
})
restart_reload()
@@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
 return [JSON.stringify(disable), JSON.stringify(update), disable_all]
}
-function extensions_check(_, _){
+function extensions_check(){
var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
+ disable.push(x.name.substring(7))
})
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
@@ -41,9 +41,31 @@ function install_extension_from_index(button, url){
 button.disabled = "disabled"
button.value = "Installing..."
- textarea = gradioApp().querySelector('#extension_to_install textarea')
+ var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
updateInput(textarea)
gradioApp().querySelector('#install_extension_button').click()
}
+
+function config_state_confirm_restore(_, config_state_name, config_restore_type) {
+ if (config_state_name == "Current") {
+ return [false, config_state_name, config_restore_type];
+ }
+ let restored = "";
+ if (config_restore_type == "extensions") {
+ restored = "all saved extension versions";
+ } else if (config_restore_type == "webui") {
+ restored = "the webui version";
+ } else {
+ restored = "the webui version and all saved extension versions";
+ }
+ let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
+ if (confirmed) {
+ restart_reload();
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
+ x.innerHTML = "Loading..."
+ })
+ }
+ return [confirmed, config_state_name, config_restore_type];
+}
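config_state_confirm_restore above follows a pattern used throughout the webui: a JavaScript callback wired to a Gradio button receives the component values, asks the user to confirm, and passes a boolean back so the Python handler can bail out. Reduced to a skeleton with hypothetical names:

function confirmThenForward(_, value){
    var confirmed = confirm("Are you sure you want to apply " + value + "?");
    return [confirmed, value]; // the Python side is assumed to proceed only when confirmed is true
}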
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 25322138..c85bc79a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js
@@ -1,4 +1,3 @@
-
function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
@@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
 tabs.appendChild(search)
tabs.appendChild(refresh)
- search.addEventListener("input", function(evt){
- searchTerm = search.value.toLowerCase()
+ var applyFilter = function(){
+ var searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
- elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
+ var searchOnly = elem.querySelector('.search_only')
+ var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
+
+ var visible = text.indexOf(searchTerm) != -1
+
+ if(searchOnly && searchTerm.length < 4){
+ visible = false
+ }
+
+ elem.style.display = visible ? "" : "none"
})
- });
+ }
+
+ search.addEventListener("input", applyFilter);
+ applyFilter();
+
+ extraNetworksApplyFilter[tabname] = applyFilter;
+}
+
+function applyExtraNetworkFilter(tabname){
+ setTimeout(extraNetworksApplyFilter[tabname], 1);
}
+var extraNetworksApplyFilter = {}
var activePromptTextarea = {};
function setupExtraNetworks(){
@@ -55,7 +72,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
 var partToSearch = m[1]
var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
+ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
m = found.match(re_extranet);
if(m[1] == partToSearch){
replaced = true;
@@ -96,9 +113,9 @@ function saveCardPreview(event, tabname, filename){
 }
function extraNetworksSearchButton(tabs_id, event){
- searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
- button = event.target
- text = button.classList.contains("search-all") ? "" : button.textContent.trim()
+ var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
+ var button = event.target
+ var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text
updateInput(searchTextarea)
@@ -133,7 +150,7 @@ function popup(contents){
 }
function extraNetworksShowMetadata(text){
- elem = document.createElement('pre')
+ var elem = document.createElement('pre')
elem.classList.add('popup-metadata');
elem.textContent = text;
@@ -165,7 +182,7 @@ function requestGet(url, data, handler, errorHandler){
 }
function extraNetworksRequestMetadata(event, extraPage, cardName){
- showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
+ var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){
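The applyFilter logic introduced above reduces to a case-insensitive substring test, with one guard: cards marked .search_only (for example those from dot-prefixed directories) stay hidden until the query is at least four characters. In isolation, under those assumptions:

function cardVisible(cardText, isSearchOnly, query){
    var q = query.toLowerCase();
    if (isSearchOnly && q.length < 4) return false; // search-only cards need a longer query
    return cardText.toLowerCase().indexOf(q) != -1;
}
// cardVisible("my-lora lora/my-lora", false, "LoRA") -> true
// cardVisible(".hidden/model .hidden/model", true, "hid") -> false (query too short)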
diff --git a/javascript/generationParams.js b/javascript/generationParams.js index 95f05093..ef64ee2e 100644 --- a/javascript/generationParams.js +++ b/javascript/generationParams.js
@@ -16,14 +16,14 @@ onUiUpdate(function(){
 let modalObserver = new MutationObserver(function(mutations) {
 mutations.forEach(function(mutationRecord) {
- let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
- if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
- gradioApp().getElementById(selectedTab+"_generation_info_button").click()
+ let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
+ if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
+ gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
 });
 });
 function attachGalleryListeners(tab_name) {
- gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
+ var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
 gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
 gallery?.addEventListener('keydown', (e) => {
 if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
diff --git a/javascript/hints.js b/javascript/hints.js index f48a0eb6..3746df99 100644 --- a/javascript/hints.js +++ b/javascript/hints.js
@@ -22,6 +22,7 @@ titles = {
 "\u{1f4cb}": "Apply selected styles to current prompt",
 "\u{1f4d2}": "Paste available values into the field",
 "\u{1f3b4}": "Show/hide extra networks",
+ "\u{1f300}": "Restore progress",
 "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
 "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@@ -65,8 +66,8 @@ titles = {
 "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
- "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
- "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+ "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
+ "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
 "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
 "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -85,7 +86,6 @@ titles = {
 "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
 "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
- "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
 "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
 "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
@@ -111,22 +111,25 @@ titles = {
 "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
 "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
 "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
- "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
+ "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
+ "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
 }
 onUiUpdate(function(){
 gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
- tooltip = titles[span.textContent];
+ if (span.title) return; // already has a title
- if(!tooltip){
- tooltip = titles[span.value];
+ let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+
+ if(!tooltip){
+ tooltip = localization[titles[span.value]] || titles[span.value];
 }
 if(!tooltip){
 for (const c of span.classList) {
 if (c in titles) {
- tooltip = titles[c];
+ tooltip = localization[titles[c]] || titles[c];
 break;
 }
 }
@@ -141,7 +144,7 @@ onUiUpdate(function(){
 if (select.onchange != null) return;
 select.onchange = function(){
- select.title = titles[select.value] || "";
+ select.title = localization[titles[select.value]] || titles[select.value] || "";
 }
 })
 })
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js index 0629475f..48196be4 100644 --- a/javascript/hires_fix.js +++ b/javascript/hires_fix.js
@@ -1,16 +1,12 @@
-function setInactive(elem, inactive){
- if(inactive){
- elem.classList.add('inactive')
- } else{
- elem.classList.remove('inactive')
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
+ function setInactive(elem, inactive){
+ elem.classList.toggle('inactive', !!inactive)
}
-}
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
- hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
- hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
+ var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
+ var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
+ var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js index 9fe7a603..a612705d 100644 --- a/javascript/imageMaskFix.js +++ b/javascript/imageMaskFix.js @@ -2,11 +2,10 @@ * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 * @see https://github.com/gradio-app/gradio/issues/1721 */ -window.addEventListener( 'resize', () => imageMaskResize()); function imageMaskResize() { const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); if ( ! canvases.length ) { - canvases_fixed = false; + canvases_fixed = false; // TODO: this is unused..? window.removeEventListener( 'resize', imageMaskResize ); return; } @@ -15,7 +14,7 @@ function imageMaskResize() { const previewImage = wrapper.previousElementSibling; if ( ! previewImage.complete ) { - previewImage.addEventListener( 'load', () => imageMaskResize()); + previewImage.addEventListener( 'load', imageMaskResize); return; } @@ -24,7 +23,6 @@ function imageMaskResize() { const nw = previewImage.naturalWidth; const nh = previewImage.naturalHeight; const portrait = nh > nw; - const factor = portrait; const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); @@ -40,6 +38,7 @@ function imageMaskResize() { c.style.maxHeight = '100%'; c.style.objectFit = 'contain'; }); - } - - onUiUpdate(() => imageMaskResize()); +} + +onUiUpdate(imageMaskResize); +window.addEventListener( 'resize', imageMaskResize); diff --git a/javascript/imageParams.js b/javascript/imageParams.js index 67404a89..64aee93b 100644 --- a/javascript/imageParams.js +++ b/javascript/imageParams.js @@ -1,7 +1,6 @@ window.onload = (function(){ window.addEventListener('drop', e => { const target = e.composedPath()[0]; - const idx = selected_gallery_index(); if (target.placeholder.indexOf("Prompt") == -1) return; let prompt_target = get_tab_index('tabs') == 1 ? 
"img2img_prompt_image" : "txt2img_prompt_image"; diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index d6483562..32066ab8 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -57,7 +57,7 @@ function modalImageSwitch(offset) { }) if (result != -1) { - nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] + var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] nextButton.click() const modalImage = gradioApp().getElementById("modalImage"); const modal = gradioApp().getElementById("lightboxModal"); @@ -144,15 +144,11 @@ function setupImageForLightbox(e) { } function modalZoomSet(modalImage, enable) { - if (enable) { - modalImage.classList.add('modalImageFullscreen'); - } else { - modalImage.classList.remove('modalImageFullscreen'); - } + if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable); } function modalZoomToggle(event) { - modalImage = gradioApp().getElementById("modalImage"); + var modalImage = gradioApp().getElementById("modalImage"); modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) event.stopPropagation() } @@ -179,7 +175,7 @@ function galleryImageHandler(e) { } onUiUpdate(function() { - fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') + var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') if (fullImg_preview != null) { fullImg_preview.forEach(setupImageForLightbox); } @@ -251,8 +247,11 @@ document.addEventListener("DOMContentLoaded", function() { modal.appendChild(modalNext) - gradioApp().appendChild(modal) - + try { + gradioApp().appendChild(modal); + } catch (e) { + gradioApp().body.appendChild(modal); + } document.body.appendChild(modal); diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js new file mode 100644 index 00000000..6297a12b --- /dev/null +++ b/javascript/imageviewerGamepad.js @@ -0,0 +1,57 @@ +window.addEventListener('gamepadconnected', (e) => { + const index = e.gamepad.index; + let isWaiting = false; + setInterval(async () => { + if (!opts.js_modal_lightbox_gamepad || isWaiting) return; + const gamepad = navigator.getGamepads()[index]; + const xValue = gamepad.axes[0]; + if (xValue <= -0.3) { + modalPrevImage(e); + isWaiting = true; + } else if (xValue >= 0.3) { + modalNextImage(e); + isWaiting = true; + } + if (isWaiting) { + await sleepUntil(() => { + const xValue = navigator.getGamepads()[index].axes[0] + if (xValue < 0.3 && xValue > -0.3) { + return true; + } + }, opts.js_modal_lightbox_gamepad_repeat); + isWaiting = false; + } + }, 10); +}); + +/* +Primarily for vr controller type pointer devices. +I use the wheel event because there's currently no way to do it properly with web xr. 
+ */ +let isScrolling = false; +window.addEventListener('wheel', (e) => { + if (!opts.js_modal_lightbox_gamepad || isScrolling) return; + isScrolling = true; + + if (e.deltaX <= -0.6) { + modalPrevImage(e); + } else if (e.deltaX >= 0.6) { + modalNextImage(e); + } + + setTimeout(() => { + isScrolling = false; + }, opts.js_modal_lightbox_gamepad_repeat); +}); + +function sleepUntil(f, timeout) { + return new Promise((resolve) => { + const timeStart = new Date(); + const wait = setInterval(function() { + if (f() || new Date() - timeStart > timeout) { + clearInterval(wait); + resolve(); + } + }, 20); + }); +} diff --git a/javascript/localization.js b/javascript/localization.js index 1a5a1dbb..0123b877 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u original_lines = {}
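sleepUntil, defined at the end of the new file above, resolves once its predicate returns a truthy value or the timeout elapses, polling every 20 ms. A usage sketch (the selector and the 500 ms budget are arbitrary choices, not webui code):

async function waitForElement(selector){
    var found = null;
    await sleepUntil(function(){
        found = document.querySelector(selector);
        return found != null;
    }, 500);
    return found; // may still be null if the timeout won the race
}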
diff --git a/javascript/localization.js b/javascript/localization.js index 1a5a1dbb..0123b877 100644 --- a/javascript/localization.js +++ b/javascript/localization.js
@@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
 original_lines = {}
 translated_lines = {}
+function hasLocalization() {
+ return window.localization && Object.keys(window.localization).length > 0;
+}
+
function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n);
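textNodesUnder above uses document.createTreeWalker with the SHOW_TEXT filter to enumerate every text node in a subtree. The same API in a self-contained form, here just counting non-empty text nodes:

function countTextNodes(root){
    var walker = document.createTreeWalker(root, NodeFilter.SHOW_TEXT, null, false);
    var count = 0, node;
    while ((node = walker.nextNode())){
        if (node.textContent.trim()) count++;
    }
    return count;
}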
@@ -35,11 +39,11 @@ function canBeTranslated(node, text){
 if(! text) return false;
if(! node.parentElement) return false;
- parentType = node.parentElement.nodeName
+ var parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){
- pnode = node
+ var pnode = node
for(var level=0; level<4; level++){
pnode = pnode.parentElement
if(! pnode) break;
@@ -69,7 +73,7 @@ function getTranslation(text){
 }
function processTextNode(node){
- text = node.textContent.trim()
+ var text = node.textContent.trim()
if(! canBeTranslated(node, text)) return
@@ -105,7 +109,7 @@ function processNode(node){
 }
function dumpTranslations(){
- dumped = {}
+ var dumped = {}
if (localization.rtl) {
dumped.rtl = true
}
@@ -119,39 +123,8 @@ function dumpTranslations(){
 return dumped
}
-onUiUpdate(function(m){
- m.forEach(function(mutation){
- mutation.addedNodes.forEach(function(node){
- processNode(node)
- })
- });
-})
-
-
-document.addEventListener("DOMContentLoaded", function() {
- processNode(gradioApp())
-
- if (localization.rtl) { // if the language is from right to left,
- (new MutationObserver((mutations, observer) => { // wait for the style to load
- mutations.forEach(mutation => {
- mutation.addedNodes.forEach(node => {
- if (node.tagName === 'STYLE') {
- observer.disconnect();
-
- for (const x of node.sheet.rules) { // find all rtl media rules
- if (Array.from(x.media || []).includes('rtl')) {
- x.media.appendMedium('all'); // enable them
- }
- }
- }
- })
- });
- })).observe(gradioApp(), { childList: true });
- }
-})
-
function download_localization() {
- text = JSON.stringify(dumpTranslations(), null, 4)
+ var text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
@@ -163,3 +136,36 @@ function download_localization() {
 document.body.removeChild(element);
}
+
+if(hasLocalization()) {
+ onUiUpdate(function (m) {
+ m.forEach(function (mutation) {
+ mutation.addedNodes.forEach(function (node) {
+ processNode(node)
+ })
+ });
+ })
+
+
+ document.addEventListener("DOMContentLoaded", function () {
+ processNode(gradioApp())
+
+ if (localization.rtl) { // if the language is from right to left,
+ (new MutationObserver((mutations, observer) => { // wait for the style to load
+ mutations.forEach(mutation => {
+ mutation.addedNodes.forEach(node => {
+ if (node.tagName === 'STYLE') {
+ observer.disconnect();
+
+ for (const x of node.sheet.rules) { // find all rtl media rules
+ if (Array.from(x.media || []).includes('rtl')) {
+ x.media.appendMedium('all'); // enable them
+ }
+ }
+ }
+ })
+ });
+ })).observe(gradioApp(), { childList: true });
+ }
+ })
+}
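The hasLocalization() gate above and the localization[...] || fallbacks that hints.js gained earlier in this patch share one idea: consult the translation table first and fall back to the English string. As a standalone sketch (titles and localization are the webui's globals; the function itself is hypothetical):

function localizedTooltip(key, titles, localization){
    var english = titles[key];
    return localization[english] || english; // translated text when available, English otherwise
}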
diff --git a/javascript/notification.js b/javascript/notification.js index 8ddd4c5d..83fce1f8 100644 --- a/javascript/notification.js +++ b/javascript/notification.js
@@ -2,15 +2,15 @@
 let lastHeadImg = null;
-notificationButton = null
+let notificationButton = null;
 onUiUpdate(function(){
 if(notificationButton == null){
 notificationButton = gradioApp().getElementById('request_notifications')
 if(notificationButton != null){
- notificationButton.addEventListener('click', function (evt) {
- Notification.requestPermission();
+ notificationButton.addEventListener('click', () => {
+ void Notification.requestPermission();
 },true);
 }
 }
diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 4ac9b8db..8d2c3492 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js
@@ -1,16 +1,15 @@
 // code related to showing and updating progressbar shown as the image is being made
-function rememberGallerySelection(id_gallery){
+function rememberGallerySelection(){
 }
-function getGallerySelectedIndex(id_gallery){
+function getGallerySelectedIndex(){
 }
 function request(url, data, handler, errorHandler){
 var xhr = new XMLHttpRequest();
- var url = url;
 xhr.open("POST", url, true);
 xhr.setRequestHeader("Content-Type", "application/json");
 xhr.onreadystatechange = function () {
@@ -66,7 +65,7 @@ function randomId(){
 // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
 // preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
 // calls onProgress every time there is a progress update
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
 var dateStart = new Date()
 var wasEverActive = false
 var parentProgressbar = progressbarContainer.parentNode
@@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
 divProgress.style.width = rect.width + "px";
 }
- progressText = ""
+ let progressText = ""
 divInner.style.width = ((res.progress || 0) * 100.0) + '%'
 divInner.style.background = res.progress ? "" : "transparent"
@@ -138,7 +137,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
 return
 }
- if(elapsedFromStart > 5 && !res.queued && !res.active){
+ if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
 removeProgressBar()
 return
 }
 }
localStorage.getItem("txt2img_task_id")) + showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id")) +}); + + function modelmerger(){ var id = randomId() requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) @@ -204,7 +254,7 @@ function modelmerger(){ function ask_for_style_name(_, prompt_text, negative_prompt_text) { - name_ = prompt('Style name:') + var name_ = prompt('Style name:') return [name_, prompt_text, negative_prompt_text] } @@ -239,11 +289,11 @@ function recalculate_prompts_img2img(){ } -opts = {} +var opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; - json_elem = gradioApp().getElementById('settings_json') + var json_elem = gradioApp().getElementById('settings_json') if(json_elem == null) return; var textarea = json_elem.querySelector('textarea') @@ -292,12 +342,15 @@ onUiUpdate(function(){ registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button') registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button') - show_all_pages = gradioApp().getElementById('settings_show_all_pages') - settings_tabs = gradioApp().querySelector('#settings div') + var show_all_pages = gradioApp().getElementById('settings_show_all_pages') + var settings_tabs = gradioApp().querySelector('#settings div') if(show_all_pages && settings_tabs){ settings_tabs.appendChild(show_all_pages) show_all_pages.onclick = function(){ gradioApp().querySelectorAll('#settings > div').forEach(function(elem){ + if(elem.id == "settings_tab_licenses") + return; + elem.style.display = "block"; }) } @@ -305,9 +358,9 @@ onUiUpdate(function(){ }) onOptionsChanged(function(){ - elem = gradioApp().getElementById('sd_checkpoint_hash') - sd_checkpoint_hash = opts.sd_checkpoint_hash || "" - shorthash = sd_checkpoint_hash.substr(0,10) + var elem = gradioApp().getElementById('sd_checkpoint_hash') + var sd_checkpoint_hash = opts.sd_checkpoint_hash || "" + var shorthash = sd_checkpoint_hash.substring(0,10) if(elem && elem.textContent != shorthash){ elem.textContent = shorthash @@ -342,7 +395,16 @@ function update_token_counter(button_id) { function restart_reload(){ document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; - setTimeout(function(){location.reload()},2000) + + var requestPing = function(){ + requestGet("./internal/ping", {}, function(data){ + location.reload(); + }, function(){ + setTimeout(requestPing, 500); + }) + } + + setTimeout(requestPing, 2000); return [] } @@ -361,3 +423,19 @@ function selectCheckpoint(name){ desiredCheckpointName = name; gradioApp().getElementById('change_checkpoint').click() } + +function currentImg2imgSourceResolution(_, _, scaleBy){ + var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img') + return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy] +} + +function updateImg2imgResizeToTextAfterChangingImage(){ + // At the time this is called from gradio, the image has no yet been replaced. + // There may be a better solution, but this is simple and straightforward so I'm going with it. 
+ + setTimeout(function() { + gradioApp().getElementById('img2img_update_resize_to').click() + }, 500); + + return [] +} diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js new file mode 100644 index 00000000..87a289d3 --- /dev/null +++ b/javascript/ui_settings_hints.js @@ -0,0 +1,41 @@ +// various hints and extra info for the settings tab
+
+onUiLoaded(function(){
+ createLink = function(elem_id, text, href){
+ var a = document.createElement('A')
+ a.textContent = text
+ a.target = '_blank';
+
+ elem = gradioApp().querySelector('#'+elem_id)
+ elem.insertBefore(a, elem.querySelector('label'))
+
+ return a
+ }
+
+ createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
+ createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
+
+ createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
+ requestGet("./internal/quicksettings-hint", {}, function(data){
+ var table = document.createElement('table')
+ table.className = 'settings-value-table'
+
+ data.forEach(function(obj){
+ var tr = document.createElement('tr')
+ var td = document.createElement('td')
+ td.textContent = obj.name
+ tr.appendChild(td)
+
+ var td = document.createElement('td')
+ td.textContent = obj.label
+ tr.appendChild(td)
+
+ table.appendChild(tr)
+ })
+
+ popup(table);
+ })
+ });
+})
+
+
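Stepping back, the restore-progress flow added to ui.js above amounts to: persist the task id in localStorage on submit, and on the next page load re-attach the progress bar if an id is still present. A simplified standalone form (requestProgress, gradioApp and the element ids are the webui's own; the wrapper function is hypothetical):

function resumeTaskIfAny(){
    var id = localStorage.getItem("txt2img_task_id");
    if (!id) return;
    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'),
        gradioApp().getElementById('txt2img_gallery'),
        function(){ localStorage.removeItem("txt2img_task_id"); },
        null, 0); // inactivityTimeout of 0: a task the server no longer reports is dropped at once
}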
diff --git a/launch.py b/launch.py
@@ -19,7 +19,7 @@ python = sys.executable
 git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
-skip_install = False
+stored_git_tag = None
dir_repos = "repositories"
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
@@ -49,7 +49,7 @@ or any other error regarding unsuccessful package (library) installation,
 please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
@@ -71,6 +71,20 @@ def commit_hash():
 return stored_commit_hash
+def git_tag():
+ global stored_git_tag
+
+ if stored_git_tag is not None:
+ return stored_git_tag
+
+ try:
+ stored_git_tag = run(f"{git} describe --tags").strip()
+ except Exception:
+ stored_git_tag = "<none>"
+
+ return stored_git_tag
+
+
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
@@ -121,12 +135,12 @@ def run_python(code, desc=None, errdesc=None):
 return run(f'"{python}" -c "{code}"', desc, errdesc)
-def run_pip(args, desc=None):
- if skip_install:
+def run_pip(command, desc=None, live=False):
+ if args.skip_install:
return
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
- return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run_python(code):
@@ -223,12 +237,10 @@ def run_extensions_installers(settings_file):
 def prepare_environment():
- global skip_install
-
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@@ -249,8 +261,10 @@ def prepare_environment():
 check_python_version()
commit = commit_hash()
+ tag = git_tag()
print(f"Python {sys.version}")
+ print(f"Version: {tag}")
print(f"Commit hash: {commit}")
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
@@ -271,7 +285,7 @@ def prepare_environment():
 if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
@@ -296,7 +310,7 @@ def prepare_environment():
 if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
- run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
diff --git a/modules/api/api.py b/modules/api/api.py index 518b2a61..594fa655 100644 --- a/modules/api/api.py +++ b/modules/api/api.py
@@ -6,7 +6,6 @@ import uvicorn
 import gradio as gr
 from threading import Lock
 from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
 from fastapi import APIRouter, Depends, FastAPI, Request, Response
 from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from fastapi.exceptions import HTTPException
@@ -16,7 +15,8 @@ from secrets import compare_digest
 import modules.shared as shared
 from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules.api import models
+from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
@@ -26,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import List
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
+
 def upscaler_to_index(name: str):
 try:
 return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+
 def script_name_to_index(name, scripts):
 try:
 return [script.title().lower() for script in scripts].index(name.lower())
- except:
- raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+ except Exception as e:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
+
 def validate_sampler_name(name):
 config = sd_samplers.all_samplers_map.get(name, None)
@@ -49,20 +52,23 @@ def validate_sampler_name(name):
 return name
+
 def setUpscalers(req: dict):
 reqDict = vars(req)
 reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
 reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
 return reqDict
+
 def decode_base64_to_image(encoding):
 if encoding.startswith("data:image/"):
 encoding = encoding.split(";")[1].split(",")[1]
 try:
 image = Image.open(BytesIO(base64.b64decode(encoding)))
 return image
- except Exception as err:
- raise HTTPException(status_code=500, detail="Invalid encoded image")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
 def encode_pil_to_base64(image):
 with io.BytesIO() as output_bytes:
@@ -93,6 +99,7 @@ def encode_pil_to_base64(image):
 return base64.b64encode(bytes_data)
+
 def api_middleware(app: FastAPI):
 rich_available = True
 try:
@@ -100,7 +107,7 @@
 import starlette # importing just so it can be placed on silent list
 from rich.console import Console
 console = Console()
- except:
+ except Exception:
 import traceback
 rich_available = False
@@ -131,8 +138,8 @@
 "body": vars(e).get('body', ''),
 "errors": str(e),
 }
- print(f"API error: {request.method}: {request.url} {err}")
 if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ print(f"API error: {request.method}: {request.url} {err}")
 if rich_available:
 console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
 else:
@@ -158,7 +165,7 @@
 class Api:
 def __init__(self, app: FastAPI, queue_lock: Lock):
 if shared.cmd_opts.api_auth:
- self.credentials = dict()
+ self.credentials = {}
 for auth in shared.cmd_opts.api_auth.split(","):
 user, password = auth.split(":")
 self.credentials[user] = password
@@ -167,36 +174,36 @@ class Api:
 self.app = app
 self.queue_lock = queue_lock
 api_middleware(self.app)
- self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
- self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
- self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
 self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
 self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
 self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
- self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
 self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
- self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
- self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
 self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
- self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
- self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
 self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
 self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
- self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
 self.default_script_arg_txt2img = []
 self.default_script_arg_img2img = []
@@ -225,7 +232,7 @@ class Api:
 t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
 i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
- return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+ return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
 def get_script(self, script_name, script_runner):
 if script_name is None or script_name == "":
@@ -265,17 +272,19 @@ class Api:
 if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
 for alwayson_script_name in request.alwayson_scripts.keys():
 alwayson_script = self.get_script(alwayson_script_name, script_runner)
- if alwayson_script == None:
+ if alwayson_script is None:
 raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
 # Selectable script in always on script param check
- if alwayson_script.alwayson == False:
- raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ if alwayson_script.alwayson is False:
+ raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
 # always on script with no arg should always run so you don't really need to add them to the requests
 if "args" in request.alwayson_scripts[alwayson_script_name]:
- script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ # min between arg length in scriptrunner and arg length in the request
+ for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
+ script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
 return script_args
- def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
 script_runner = scripts.scripts_txt2img
 if not script_runner.scripts:
 script_runner.initialize_scripts(False)
@@ -309,7 +318,7 @@ class Api:
 p.outpath_samples = opts.outdir_txt2img_samples
 shared.state.begin()
- if selectable_scripts != None:
+ if selectable_scripts is not None:
 p.script_args = script_args
 processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
 else:
@@ -319,9 +328,9 @@ class Api:
 b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+ return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
- def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
 init_images = img2imgreq.init_images
 if init_images is None:
 raise HTTPException(status_code=404, detail="Init image not found")
@@ -366,7 +375,7 @@ class Api:
 p.outpath_samples = opts.outdir_img2img_samples
 shared.state.begin()
- if selectable_scripts != None:
+ if selectable_scripts is not None:
 p.script_args = script_args
 processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
 else:
@@ -380,9 +389,9 @@ class Api:
 img2imgreq.init_images = None
 img2imgreq.mask = None
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+ return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
- def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
 reqDict = setUpscalers(req)
 reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -390,31 +399,26 @@ class Api:
 with self.queue_lock:
 result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+ return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
- def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
 reqDict = setUpscalers(req)
- def prepareFiles(file):
- file = decode_base64_to_file(file.data, file_path=file.name)
- file.orig_name = file.name
- return file
-
- reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
- reqDict.pop('imageList')
+ image_list = reqDict.pop('imageList', [])
+ image_folder = [decode_base64_to_image(x.data) for x in image_list]
 with self.queue_lock:
- result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+ return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self, req: PNGInfoRequest):
+ def pnginfoapi(self, req: models.PNGInfoRequest):
 if(not req.image.strip()):
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
 image = decode_base64_to_image(req.image.strip())
 if image is None:
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
 geninfo, items = images.read_info_from_image(image)
 if geninfo is None:
@@ -422,13 +426,13 @@ class Api:
 items = {**{'parameters': geninfo}, **items}
- return PNGInfoResponse(info=geninfo, items=items)
+ return models.PNGInfoResponse(info=geninfo, items=items)
- def progressapi(self, req: ProgressRequest = Depends()):
+ def progressapi(self, req: models.ProgressRequest = Depends()):
 # copy from check_progress_call of ui.py
 if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
 # avoid dividing zero
 progress = 0.01
@@ -450,9 +454,9 @@ class Api:
 if shared.state.current_image and not req.skip_current_image:
 current_image = encode_pil_to_base64(shared.state.current_image)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
- def interrogateapi(self, interrogatereq: InterrogateRequest):
+ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
 image_b64 = interrogatereq.image
 if image_b64 is None:
 raise HTTPException(status_code=404, detail="Image not found")
@@ -469,7 +473,7 @@ class Api:
 else:
 raise HTTPException(status_code=404, detail="Model not found")
- return InterrogateResponse(caption=processed)
+ return models.InterrogateResponse(caption=processed)
 def interruptapi(self):
 shared.state.interrupt()
@@ -574,36 +578,36 @@ class Api:
 filename = create_embedding(**args) # create empty embedding
 sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
 shared.state.end()
- return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+ return models.CreateResponse(info=f"create embedding filename: {filename}")
 except AssertionError as e:
 shared.state.end()
- return TrainResponse(info = "create embedding error: {error}".format(error = e))
+ return models.TrainResponse(info=f"create embedding error: {e}")
 def create_hypernetwork(self, args: dict):
 try:
 shared.state.begin()
 filename = create_hypernetwork(**args) # create empty embedding
 shared.state.end()
- return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+ return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
 except AssertionError as e:
 shared.state.end()
- return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+ return models.TrainResponse(info=f"create hypernetwork error: {e}")
 def preprocess(self, args: dict):
 try:
 shared.state.begin()
 preprocess(**args) # quick operation unless blip/booru interrogation is enabled
 shared.state.end()
- return PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info = 'preprocess complete')
 except KeyError as e:
 shared.state.end()
- return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+ return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
 except AssertionError as e:
 shared.state.end()
- return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+ return models.PreprocessResponse(info=f"preprocess error: {e}")
 except FileNotFoundError as e:
 shared.state.end()
- return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+ return models.PreprocessResponse(info=f'preprocess error: {e}')
 def train_embedding(self, args: dict):
 try:
@@ -621,10 +625,10 @@ class Api:
 if not apply_optimizations:
 sd_hijack.apply_optimizations()
 shared.state.end()
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
 except AssertionError as msg:
 shared.state.end()
- return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+ return models.TrainResponse(info=f"train embedding error: {msg}")
 def train_hypernetwork(self, args: dict):
 try:
@@ -645,14 +649,15 @@ class Api:
 if not apply_optimizations:
 sd_hijack.apply_optimizations()
 shared.state.end()
- return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
- except AssertionError as msg:
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+ except AssertionError:
 shared.state.end()
- return TrainResponse(info="train embedding error: {error}".format(error=error))
+ return models.TrainResponse(info=f"train embedding error: {error}")
 def get_memory(self):
 try:
- import os, psutil
+ import os
+ import psutil
 process = psutil.Process(os.getpid())
 res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on
other values ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe @@ -679,10 +684,10 @@ class Api: 'events': warnings, } else: - cuda = { 'error': 'unavailable' } + cuda = {'error': 'unavailable'} except Exception as err: - cuda = { 'error': f'{err}' } - return MemoryResponse(ram = ram, cuda = cuda) + cuda = {'error': f'{err}'} + return models.MemoryResponse(ram=ram, cuda=cuda) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 4a70f440..4d291076 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -223,8 +223,9 @@ for key in _options: if(_options[key].dest != 'help'): flag = _options[key] _type = str - if _options[key].default is not None: _type = type(_options[key].default) - flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) + if _options[key].default is not None: + _type = type(_options[key].default) + flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))}) FlagsModel = create_model("Flags", **flags) diff --git a/modules/call_queue.py b/modules/call_queue.py index 92097c15..447bb764 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): try:
res = func(*args, **kwargs)
+ progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
@@ -59,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ argStr = f"Arguments: {args} {kwargs}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@@ -72,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
if extra_outputs_array is None:
extra_outputs_array = [None, '']
- res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+ error_message = f'{type(e).__name__}: {e}'
+ res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
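For readers skimming the diff: the call_queue.py hunks above tighten the Gradio wrapper's error path. A minimal, self-contained sketch of the same pattern (hypothetical names, not the project's actual wrapper):

    import html
    import sys

    def wrap_call(func, extra_outputs=None):
        # On failure, return placeholder outputs plus an HTML-escaped error box
        # instead of letting the exception escape into the UI layer.
        def wrapped(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                print("Error completing request", file=sys.stderr)
                print(f"Arguments: {args} {kwargs}"[:131072], file=sys.stderr)
                outputs = extra_outputs if extra_outputs is not None else [None, '']
                error_message = f'{type(e).__name__}: {e}'
                return outputs + [f"<div class='error'>{html.escape(error_message)}</div>"]
        return wrapped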
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 81c0b82a..e01ca655 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,6 +1,6 @@
import argparse
import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser()
@@ -95,9 +95,11 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
+parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
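Note on the new --disable-tls-verify flag above: with action="store_false" and default=None, the attribute is None when the flag is absent and False when it is passed, so downstream code can distinguish "not specified" from "explicitly disabled". A quick argparse sketch of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-tls-verify", action="store_false", default=None)

    assert parser.parse_args([]).disable_tls_verify is None          # flag absent
    assert parser.parse_args(["--disable-tls-verify"]).disable_tls_verify is False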
\ No newline at end of file
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index 11dcc3ee..45c70f84 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -1,14 +1,12 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
-import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
@@ -163,8 +161,8 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256,
- connect_list=['32', '64', '128', '256'],
- fix_modules=['quantize','generator']):
+ connect_list=('32', '64', '128', '256'),
+ fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index e7293683..b24a0394 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,11 +5,9 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
@@ -328,7 +326,7 @@ class Generator(nn.Module):
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
@@ -339,7 +337,7 @@ class VQAutoEncoder(nn.Module):
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
- self.attn_resolutions = attn_resolutions
+ self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8d84bbc9..ececdbae 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -33,11 +33,9 @@ def setup_model(dirname):
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils.download_util import load_file_from_url
- from basicsr.utils import imwrite, img2tensor, tensor2img
+ from basicsr.utils import img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
- from modules.shared import cmd_opts
net_class = CodeFormer
@@ -96,7 +94,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
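The connect_list/fix_modules and attn_resolutions changes above all remove mutable default arguments, which Python evaluates once at function definition time. A minimal sketch of the pitfall and the None-sentinel fix used here:

    def bad(values=[16]):
        values.append(1)   # mutates the one shared default list
        return values

    def good(values=None):
        values = values if values is not None else [16]   # fresh list per call
        values.append(1)
        return values

    bad()
    assert bad() == [16, 1, 1]   # default state leaked across calls
    good()
    assert good() == [16, 1]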
diff --git a/modules/config_states.py b/modules/config_states.py
new file mode 100644
index 00000000..75da862a
--- /dev/null
+++ b/modules/config_states.py
@@ -0,0 +1,200 @@
+"""
+Supports saving and restoring webui and extensions from a known working set of commits
+"""
+
+import os
+import sys
+import traceback
+import json
+import time
+import tqdm
+
+from datetime import datetime
+from collections import OrderedDict
+import git
+
+from modules import shared, extensions
+from modules.paths_internal import script_path, config_states_dir
+
+
+all_config_states = OrderedDict()
+
+
+def list_config_states():
+ global all_config_states
+
+ all_config_states.clear()
+ os.makedirs(config_states_dir, exist_ok=True)
+
+ config_states = []
+ for filename in os.listdir(config_states_dir):
+ if filename.endswith(".json"):
+ path = os.path.join(config_states_dir, filename)
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ j["filepath"] = path
+ config_states.append(j)
+
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
+
+ for cs in config_states:
+ timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ name = cs.get("name", "Config")
+ full_name = f"{name}: {timestamp}"
+ all_config_states[full_name] = cs
+
+ return all_config_states
+
+
+def get_webui_config():
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ webui_remote = None
+ webui_commit_hash = None
+ webui_commit_date = None
+ webui_branch = None
+ if webui_repo and not webui_repo.bare:
+ try:
+ webui_remote = next(webui_repo.remote().urls, None)
+ head = webui_repo.head.commit
+ webui_commit_date = webui_repo.head.commit.committed_date
+ webui_commit_hash = head.hexsha
+ webui_branch = webui_repo.active_branch.name
+
+ except Exception:
+ webui_remote = None
+
+ return {
+ "remote": webui_remote,
+ "commit_hash": webui_commit_hash,
+ "commit_date": webui_commit_date,
+ "branch": webui_branch,
+ }
+
+
+def get_extension_config():
+ ext_config = {}
+
+ for ext in extensions.extensions:
+ entry = {
+ "name": ext.name,
+ "path": ext.path,
+ "enabled": ext.enabled,
+ "is_builtin": ext.is_builtin,
+ "remote": ext.remote,
+ "commit_hash": ext.commit_hash,
+ "commit_date": ext.commit_date,
+ "branch": ext.branch,
+ "have_info_from_repo": ext.have_info_from_repo
+ }
+
+ ext_config[ext.name] = entry
+
+ return ext_config
+
+
+def get_config():
+ creation_time = datetime.now().timestamp()
+ webui_config = get_webui_config()
+ ext_config = get_extension_config()
+
+ return {
+ "created_at": creation_time,
+ "webui": webui_config,
+ "extensions": ext_config
+ }
+
+
+def restore_webui_config(config):
+ print("* Restoring webui state...")
+
+ if "webui" not in config:
+ print("Error: No webui data saved to config")
+ return
+
+ webui_config = config["webui"]
+
+ if "commit_hash" not in webui_config:
+ print("Error: No commit saved to webui config")
+ return
+
+ webui_commit_hash = webui_config.get("commit_hash", None)
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return
+
+ try:
+ webui_repo.git.fetch(all=True)
+ webui_repo.git.reset(webui_commit_hash, hard=True)
+ print(f"* Restored webui to commit {webui_commit_hash}.")
+ except Exception:
+ print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+def restore_extension_config(config):
+ print("* Restoring extension state...")
+
+ if "extensions" not in config:
+ print("Error: No extension data saved to config")
+ return
+
+ ext_config = config["extensions"]
+
+ results = []
+ disabled = []
+
+ for ext in tqdm.tqdm(extensions.extensions):
+ if ext.is_builtin:
+ continue
+
+ ext.read_info_from_repo()
+ current_commit = ext.commit_hash
+
+ if ext.name not in ext_config:
+ ext.disabled = True
+ disabled.append(ext.name)
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
+ continue
+
+ entry = ext_config[ext.name]
+
+ if "commit_hash" in entry and entry["commit_hash"]:
+ try:
+ ext.fetch_and_reset_hard(entry["commit_hash"])
+ ext.read_info_from_repo()
+ if current_commit != entry["commit_hash"]:
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
+ except Exception as ex:
+ results.append((ext, current_commit[:8], False, ex))
+ else:
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
+
+ if not entry.get("enabled", False):
+ ext.disabled = True
+ disabled.append(ext.name)
+ else:
+ ext.disabled = False
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ print("* Finished restoring extensions. Results:")
+ for ext, prev_commit, success, result in results:
+ if success:
+ print(f" + {ext.name}: {prev_commit} -> {result}")
+ else:
+ print(f" ! {ext.name}: FAILURE ({result})")
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 122fce7f..547e1b4c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,7 +2,6 @@ import os
import re
import torch
-from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:
res = []
- filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
diff --git a/modules/devices.py b/modules/devices.py
index 52c3e7cd..d8a34a0f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -65,7 +65,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
+ from modules.shared import opts
+
torch.manual_seed(seed)
- if device.type == 'mps':
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
- if device.type == 'mps':
+ from modules.shared import opts
+
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 9a9c38f1..a009eb42 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -6,7 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -16,9 +16,7 @@ def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
def load_model(self, path: str):
if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="%s.pth" % self.model_name,
- progress=True)
+ filename = load_file_from_url(
+ url=self.model_url,
+ model_dir=self.model_path,
+ file_name=f"{self.model_name}.pth",
+ progress=True,
+ )
else:
filename = path
if not os.path.exists(filename) or filename is None:
- print("Unable to load %s from %s" % (self.model_path, filename))
+ print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
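The items = list(state_dict) simplification above relies on the fact that iterating a dict yields its keys in insertion order, so the manual key-collection loop was redundant. For example:

    state_dict = {'conv_first.weight': 0, 'conv_first.bias': 1}

    items = []
    for k, v in state_dict.items():   # old form
        items.append(k)

    # new form: list() over a dict gives its keys directly
    assert items == list(state_dict) == ['conv_first.weight', 'conv_first.bias']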
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index 1b52b0f5..4de9dd8d 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -2,7 +2,6 @@ from collections import OrderedDict
import math
-import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block
else:
- raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else:
@@ -261,10 +260,10 @@ class Upsample(nn.Module):
def extra_repr(self):
if self.scale_factor is not None:
- info = 'scale_factor=' + str(self.scale_factor)
+ info = f'scale_factor={self.scale_factor}'
else:
- info = 'size=' + str(self.size)
- info += ', mode=' + self.mode
+ info = f'size={self.size}'
+ info += f', mode={self.mode}'
return info
@@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid()
else:
- raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ raise NotImplementedError(f'activation layer [{act_type}] is not found')
return layer
@@ -372,7 +371,7 @@ def norm(norm_type, nc):
elif norm_type == 'none':
def norm_layer(x): return Identity()
else:
- raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return layer
@@ -388,7 +387,7 @@ def pad(pad_type, padding):
elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding)
else:
- raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
return layer
@@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False):
""" Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D':
+ from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D':
+ from torchvision.ops import DeformConv2d # not tested
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D':
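The PartialConv2d/DeformConv2d hunks above move imports into the branch that needs them, so unused (or broken) backends are never imported at module load time; the diff's own comment concedes the PartialConv2d import is only there to quiet the static analyzer. A sketch of the lazy-import pattern, with make_conv as a hypothetical name:

    import torch.nn as nn

    def make_conv(convtype, *args, **kwargs):
        if convtype == 'DeformConv2D':
            # Resolved only when this branch actually runs.
            from torchvision.ops import DeformConv2d
            return DeformConv2d(*args, **kwargs)
        return nn.Conv2d(*args, **kwargs)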
diff --git a/modules/extensions.py b/modules/extensions.py
index 3a7a0372..bc2c0450 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -6,7 +6,7 @@ import time
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -31,12 +31,15 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
self.version = ''
+ self.branch = None
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
- if self.have_info_from_repo:
+ if self.is_builtin or self.have_info_from_repo:
return
self.have_info_from_repo = True
@@ -56,10 +59,15 @@ class Extension:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
-
- except Exception:
+ self.commit_date = repo.head.commit.committed_date
+ ts = time.asctime(time.gmtime(self.commit_date))
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ self.commit_hash = head.hexsha
+ self.version = f'{self.commit_hash[:8]} ({ts})'
+
+ except Exception as ex:
+ print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
def list_files(self, subdir, extension):
@@ -82,18 +90,30 @@ class Extension:
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
- self.status = "behind"
+ self.status = "new commits"
+ return
+
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
self.can_update = False
self.status = "latest"
- def fetch_and_reset_hard(self):
+ def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
def list_extensions():
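fetch_and_reset_hard now accepts an arbitrary commit, which is what lets config states pin an extension to a saved revision. In GitPython, repo.git.<cmd>(...) shells out to git, with keyword arguments becoming flags; a sketch of the equivalent standalone helper:

    import git

    def fetch_and_reset_hard(path, commit='origin'):
        repo = git.Repo(path)
        repo.git.fetch(all=True)           # git fetch --all
        repo.git.reset(commit, hard=True)  # git reset --hard <commit>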
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1978673d..f9db41bc 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
- for extra_network_name, extra_network_args in extra_network_data.items():
+ for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index d3a4d7ad..aa2a14ef 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -1,4 +1,4 @@
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork
@@ -9,8 +9,9 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
- p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+ if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
+ p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
diff --git a/modules/extras.py b/modules/extras.py
index d8ece955..eb4f0b42 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,6 +1,7 @@
import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -135,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_instruct_pix2pix_model = False
if theta_func2:
- shared.state.textinfo = f"Loading B"
+ shared.state.textinfo = "Loading B"
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else:
theta_1 = None
if theta_func1:
- shared.state.textinfo = f"Loading C"
+ shared.state.textinfo = "Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
@@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "type": "webui", # indicate this model was merged with webui's built-in merger
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ checkpoint_info.calculate_shorthash()
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+ }
+
+ metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
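The json.dumps calls above exist because safetensors metadata must be a flat str-to-str mapping; nested structures such as the merge recipe have to be serialized into a single string value first. A minimal sketch:

    import json
    import torch
    import safetensors.torch

    tensors = {"weight": torch.zeros(2, 2)}
    recipe = {"type": "webui", "multiplier": 0.5}  # nested data, not storable as-is

    # metadata values must all be strings, hence the JSON encoding
    metadata = {"format": "pt", "sd_merge_recipe": json.dumps(recipe)}
    safetensors.torch.save_file(tensors, "merged.safetensors", metadata=metadata)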
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 6df76858..b0e945a1 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,15 +1,11 @@
import base64
-import html
import io
-import math
import os
import re
-from pathlib import Path
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
-import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +19,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
- self.paste_field_names = paste_field_names
+ self.paste_field_names = paste_field_names or []
def reset():
@@ -59,6 +55,7 @@ def image_from_url_text(filedata):
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
+ filename = filename.rsplit('?', 1)[0]
return Image.open(filename)
if type(filedata) == list:
@@ -129,6 +126,7 @@ def connect_paste_params_buttons():
_js=jsfunc,
inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ show_progress=False,
)
if binding.source_text_component is not None and fields is not None:
@@ -140,6 +138,7 @@ def connect_paste_params_buttons():
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names],
+ show_progress=False,
)
binding.paste_button.click(
@@ -147,6 +146,7 @@ def connect_paste_params_buttons():
_js=f"switch_to_{binding.tabname}",
inputs=None,
outputs=None,
+ show_progress=False,
)
@@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
lines.append(lastline)
lastline = ''
- for i, line in enumerate(lines):
+ for line in lines:
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
@@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
- res[k+"-1"] = m.group(1)
- res[k+"-2"] = m.group(2)
+ res[f"{k}-1"] = m.group(1)
+ res[f"{k}-2"] = m.group(2)
else:
res[k] = v
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
+ # Missing RNG means the default was set, which is GPU RNG
+ if "RNG" not in res:
+ res["RNG"] = "GPU"
+
return res
@@ -304,6 +308,8 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('RNG', 'randn_source'),
+ ('NGMS', 's_min_uncond'),
]
@@ -403,12 +409,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
fn=paste_func,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
+ show_progress=False,
)
button.click(
fn=None,
_js=f"recalculate_prompts_{tabname}",
inputs=[],
outputs=[],
+ show_progress=False,
)
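The RNG backfill above means infotexts written before the randn_source setting existed restore the historical GPU default instead of whatever the user currently has configured. Sketched as a standalone helper (hypothetical name):

    def backfill_rng(res: dict) -> dict:
        # An absent key implies the image predates the setting,
        # when GPU RNG was the only behavior.
        if "RNG" not in res:
            res["RNG"] = "GPU"
        return res

    assert backfill_rng({"Steps": "20"})["RNG"] == "GPU"
    assert backfill_rng({"RNG": "CPU"})["RNG"] == "CPU"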
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index fbe6215a..0131dea4 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -78,7 +78,7 @@ def setup_model(dirname):
try:
from gfpgan import GFPGANer
- from facexlib import detection, parsing
+ from facexlib import detection, parsing # noqa: F401
global user_path
global have_gfpgan
global gfpgan_constructor
diff --git a/modules/hashes.py b/modules/hashes.py
index 83272a07..032120f4 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -13,7 +13,7 @@ cache_data = None
def dump_cache():
- with filelock.FileLock(cache_filename+".lock"):
+ with filelock.FileLock(f"{cache_filename}.lock"):
with open(cache_filename, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
@@ -22,7 +22,7 @@ def cache(subsection):
global cache_data
if cache_data is None:
- with filelock.FileLock(cache_filename+".lock"):
+ with filelock.FileLock(f"{cache_filename}.lock"):
if not os.path.isfile(cache_filename):
cache_data = {}
else:
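The hashes.py hunks only switch string concatenation to f-strings; the surrounding pattern worth noting is a file lock guarding the JSON cache so concurrent processes cannot interleave writes. A sketch of that pattern using the same filelock library:

    import json
    import filelock

    cache_filename = "cache.json"  # illustrative path

    def dump_cache(cache_data):
        # The .lock suffix file serializes writers across processes.
        with filelock.FileLock(f"{cache_filename}.lock"):
            with open(cache_filename, "w", encoding="utf8") as file:
                json.dump(cache_data, file, indent=4)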
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 1fc49537..38ef074f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,4 +1,3 @@
-import csv
import datetime
import glob
import html
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-from collections import defaultdict, deque
+from collections import deque
from statistics import stdev, mean
@@ -178,34 +177,34 @@ class Hypernetwork:
def weights(self):
res = []
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
res += layer.parameters()
return res
def train(self, mode=True):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.train(mode=mode)
for param in layer.parameters():
param.requires_grad = mode
def to(self, device):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.eval()
for param in layer.parameters():
@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
k = self.to_k(context_k)
v = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 76599f5a..8b6255e2 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,19 +1,17 @@
import html
-import os
-import re
import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index b3535070..c4e98c75 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -19,7 +19,7 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -149,7 +149,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
return ImageFont.truetype(Roboto, fontsize)
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for i, line in enumerate(lines):
+ for line in lines:
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -318,6 +318,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
+NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
def sanitize_filename_part(text, replace_spaces=True):
@@ -352,6 +353,11 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
+ 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
}
default_time_format = '%Y%m%d%H%M%S'
@@ -360,6 +366,22 @@ class FilenameGenerator:
self.seed = seed
self.prompt = prompt
self.image = image
+
+ def hasprompt(self, *args):
+ lower = self.prompt.lower()
+ if self.p is None or self.prompt is None:
+ return None
+ outres = ""
+ for arg in args:
+ if arg != "":
+ division = arg.split("|")
+ expected = division[0].lower()
+ default = division[1] if len(division) > 1 else ""
+ if lower.find(expected) >= 0:
+ outres = f'{outres}{expected}'
+ else:
+ outres = outres if default == "" else f'{outres}{default}'
+ return sanitize_filename_part(outres)
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -387,13 +409,13 @@ class FilenameGenerator:
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
- except pytz.exceptions.UnknownTimeZoneError as _:
+ except pytz.exceptions.UnknownTimeZoneError:
time_zone = None
time_zone_time = time_datetime.astimezone(time_zone)
try:
formatted_time = time_zone_time.strftime(time_format)
- except (ValueError, TypeError) as _:
+ except (ValueError, TypeError):
formatted_time = time_zone_time.strftime(self.default_time_format)
return sanitize_filename_part(formatted_time, replace_spaces=False)
@@ -403,9 +425,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
- res += text
if pattern is None:
+ res += text
continue
pattern_args = []
@@ -426,11 +448,13 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is not None:
- res += str(replacement)
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
+ continue
+ elif replacement is not None:
+ res += text + str(replacement)
continue
- res += f'[{pattern}]'
+ res += f'{text}[{pattern}]'
return res
@@ -443,14 +467,14 @@ def get_next_sequence_number(path, basename):
"""
result = -1
if basename != '':
- basename = basename + "-"
+ basename = f"{basename}-"
prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
- result = max(int(l[0]), result)
+ result = max(int(parts[0]), result)
except ValueError:
pass
@@ -512,7 +536,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number:
- file_decoration = "-" + file_decoration
+ file_decoration = f"-{file_decoration}"
file_decoration = namegen.apply(file_decoration) + suffix
@@ -542,7 +566,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
def _atomically_save_image(image_to_save, filename_without_extension, extension):
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
- temp_file_path = filename_without_extension + ".tmp"
+ temp_file_path = f"{filename_without_extension}.tmp"
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
@@ -602,7 +626,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
with open(txt_fullfn, "w", encoding="utf8") as file:
- file.write(info + "\n")
+ file.write(f"{info}\n")
else:
txt_fullfn = None
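The NOTHING_AND_SKIP_PREVIOUS_TEXT sentinel above lets a filename pattern suppress the literal text preceding it, so e.g. [batch_number] can vanish cleanly for single-image batches. The core idea, reduced to a sketch with hypothetical names:

    SKIP = object()  # unique sentinel, distinguishable from any real replacement

    def expand(pairs):
        # pairs: (preceding_literal, replacement) tuples; SKIP drops the literal too.
        res = ""
        for text, replacement in pairs:
            if replacement is SKIP:
                continue
            res += text + str(replacement)
        return res

    assert expand([("img-", 7), ("-batch", SKIP)]) == "img-7"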
diff --git a/modules/img2img.py b/modules/img2img.py
index 953ac5d2..d704bf90 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,19 +1,15 @@
-import math
import os
-import sys
-import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
-from modules import devices, sd_samplers
+from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-import modules.images as images
import modules.scripts
@@ -46,7 +42,11 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
if state.interrupted:
break
- img = Image.open(image)
+ try:
+ img = Image.open(image)
+ except UnidentifiedImageError as e:
+ print(e)
+ continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
@@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
# try to find corresponding mask for an image using simple filename matching
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
# if not found use first one ("same mask for all images" use-case)
- if not mask_image_path in inpaint_masks:
+ if mask_image_path not in inpaint_masks:
mask_image_path = inpaint_masks[0]
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
@@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -114,6 +114,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)
+ if selected_scale_tab == 1:
+ assert image, "Can't scale by because no image is selected"
+
+ width = int(image.width * scale_by)
+ height = int(image.height * scale_by)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
@@ -151,7 +157,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings,
)
- p.scripts = modules.scripts.scripts_txt2img
+ p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
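The new selected_scale_tab/scale_by parameters derive the target resolution from the input image rather than from explicit width/height fields, as the hunk above computes. In isolation:

    from PIL import Image

    def scaled_size(image, scale_by):
        # "Scale by" mode: target dimensions follow from the source image.
        return int(image.width * scale_by), int(image.height * scale_by)

    assert scaled_size(Image.new("RGB", (512, 768)), 1.5) == (768, 1152)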
diff --git a/modules/interrogate.py b/modules/interrogate.py
index cbb80683..111b1322 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,7 +11,6 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384
@@ -28,11 +27,11 @@ def category_types():
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
- tmpdir = content_dir + "_tmp"
+ tmpdir = f"{content_dir}_tmp"
category_types = ["artists", "flavors", "mediums", "movements"]
try:
- os.makedirs(tmpdir)
+ os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@@ -41,7 +40,7 @@ def download_default_clip_interrogate_categories(content_dir):
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
- os.remove(tmpdir)
+ os.removedirs(tmpdir)
class InterrogateModels:
@@ -160,7 +159,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -208,13 +207,13 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
- for name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
+ for cat in self.categories():
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
else:
- res += ", " + match
+ res += f", {match}"
except Exception:
print("Error interrogating", file=sys.stderr)
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 6fe8dea0..5c2f92a1 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,5 @@
import torch
import platform
-from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -54,6 +53,11 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
- if version.parse(torch.__version__) == version.parse("2.0"):
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
- CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
+ if platform.processor() == 'i386':
+ for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
\ No newline at end of file
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 522affc6..2a479bcb 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,4 +1,3 @@
-import glob
import os
import shutil
import importlib
@@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
"""
output = []
- if ext_filter is None:
- ext_filter = []
-
try:
places = []
@@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(model_path)
for place in places:
- if os.path.exists(place):
- for file in glob.iglob(place + '**/**', recursive=True):
- full_path = file
- if os.path.isdir(full_path):
- continue
- if os.path.islink(full_path) and not os.path.exists(full_path):
- print(f"Skipping broken symlink: {full_path}")
- continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
- continue
- if len(ext_filter) != 0:
- model_name, extension = os.path.splitext(file)
- if extension not in ext_filter:
- continue
- if file not in output:
- output.append(full_path)
+ for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
+ if os.path.islink(full_path) and not os.path.exists(full_path):
+ print(f"Skipping broken symlink: {full_path}")
+ continue
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
+ continue
+ if full_path not in output:
+ output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
@@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
- except:
+ except Exception:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
- except:
+ except Exception:
pass
-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
- load_upscalers()
-
- builtin_upscaler_classes.clear()
- builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-
-def forbid_loaded_nonbuiltin_upscalers():
- for cls in Upscaler.__subclasses__():
- if cls not in builtin_upscaler_classes:
- forbidden_upscaler_classes.add(cls)
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -155,15 +126,22 @@ def load_upscalers():
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
- except:
+ except Exception:
pass
datas = []
commandline_options = vars(shared.cmd_opts)
- for cls in Upscaler.__subclasses__():
- if cls in forbidden_upscaler_classes:
- continue
+ # some of upscaler classes will not go away after reloading their modules, and we'll end
+ # up with two copies of those classes. The newest copy will always be the last in the list,
+ # so we go from end to beginning and ignore duplicates
+ used_classes = {}
+ for cls in reversed(Upscaler.__subclasses__()):
+ classname = str(cls)
+ if classname not in used_classes:
+ used_classes[classname] = cls
+
+ for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index f3d49c44..3fb76b65 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
# If initialing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema:
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
+ ignore_keys = ignore_keys or []
+
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
+ print(f"Deleting key {k} from state_dict.")
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
@@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
@@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-
except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -891,16 +893,6 @@ class LatentDiffusion(DDPM): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, @@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM): use_ddim = False - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. 
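The ddpm_edit.py hunks above also replace mutable default arguments (ignore_keys=[] and, later in uni_pc.py, model_kwargs={}) with None plus an `or` fallback. A minimal sketch of the pitfall being avoided, with illustrative names rather than the patch's own:

    def init_bad(path, ignore_keys=[]):
        # the default list is built once, at definition time, and shared
        # by every call that omits the argument
        ignore_keys.append("model_ema")
        return ignore_keys

    def init_good(path, ignore_keys=None):
        ignore_keys = ignore_keys or []   # fresh list on every call
        ignore_keys.append("model_ema")
        return ignore_keys

    assert init_bad("a.ckpt") == ["model_ema"]
    assert init_bad("b.ckpt") == ["model_ema", "model_ema"]   # state leaked across calls
    assert init_good("b.ckpt") == ["model_ema"]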
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py index e1265e3f..dbb35964 100644 --- a/modules/models/diffusion/uni_pc/__init__.py +++ b/modules/models/diffusion/uni_pc/__init__.py @@ -1 +1 @@ -from .sampler import UniPCSampler +from .sampler import UniPCSampler # noqa: F401 diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index a241c8a7..0a9defa1 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -54,7 +54,8 @@ class UniPCSampler(object): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 1d1b07bd..d257a728 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F import math import tqdm @@ -94,7 +93,7 @@ class NoiseScheduleVP: """ if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'") self.schedule = schedule if schedule == 'discrete': @@ -179,13 +178,13 @@ def model_wrapper( model, noise_schedule, model_type="noise", - model_kwargs={}, + model_kwargs=None, guidance_type="uncond", #condition=None, #unconditional_condition=None, guidance_scale=1., classifier_fn=None, - classifier_kwargs={}, + classifier_kwargs=None, ): """Create a wrapper function for the noise prediction model. @@ -276,6 +275,9 @@ def model_wrapper( A noise prediction model that accepts the noised data and the continuous time as the inputs. """ + model_kwargs = model_kwargs or {} + classifier_kwargs = classifier_kwargs or {} + def get_model_input_time(t_continuous): """ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. 
@@ -342,7 +344,7 @@ def model_wrapper( t_in = torch.cat([t_continuous] * 2) if isinstance(condition, dict): assert isinstance(unconditional_condition, dict) - c_in = dict() + c_in = {} for k in condition: if isinstance(condition[k], list): c_in[k] = [torch.cat([ @@ -353,7 +355,7 @@ def model_wrapper( unconditional_condition[k], condition[k]]) elif isinstance(condition, list): - c_in = list() + c_in = [] assert isinstance(unconditional_condition, list) for i in range(len(condition)): c_in.append(torch.cat([unconditional_condition[i], condition[i]])) @@ -469,7 +471,7 @@ class UniPC: t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) return t else: - raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'") def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ diff --git a/modules/ngrok.py b/modules/ngrok.py index 3df2c06b..7a7b4b26 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -7,12 +7,24 @@ def connect(token, port, region): else: if ':' in token: # token = authtoken:username:password - account = token.split(':')[1] + ':' + token.split(':')[-1] - token = token.split(':')[0] + token, username, password = token.split(':', 2) + account = f"{username}:{password}" config = conf.PyngrokConfig( auth_token=token, region=region ) + + # Guard for existing tunnels + existing = ngrok.get_tunnels(pyngrok_config=config) + if existing: + for established in existing: + # Extra configuration in the case that the user is also using ngrok for other tunnels + if established.config['addr'][-4:] == str(port): + public_url = existing[0].public_url + print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n' + 'You can use this link after the launch is complete.') + return + try: if account is None: public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url diff --git a/modules/paths.py b/modules/paths.py index 0e1e00e7..5f6474c0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,8 +1,8 @@ import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
-import modules.safe
+import modules.safe # noqa: F401
# data_path = cmd_opts_pre.data
@@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths: sd_path = os.path.abspath(possible_sd_path)
break
-assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
+assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 926ec3bb..6765bafe 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
+config_states_dir = os.path.join(script_path, "config_states")
diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 09d8e605..736315e2 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, if extras_mode == 1:
for img in image_folder:
- image = Image.open(img)
+ if isinstance(img, Image.Image):
+ image = img
+ fn = ''
+ else:
+ image = Image.open(os.path.abspath(img.name))
+ fn = os.path.splitext(img.orig_name)[0]
image_data.append(image)
- image_names.append(os.path.splitext(img.orig_name)[0])
+ image_names.append(fn)
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'
diff --git a/modules/processing.py b/modules/processing.py
index 6d9c6a8d..c3932d6b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -2,7 +2,7 @@ import json
 import math
import os
import sys
-import warnings
+import hashlib
import torch
import numpy as np
@@ -10,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -105,7 +105,7 @@ class StableDiffusionProcessing: """
The first set of parameters: sd_models -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -140,6 +140,7 @@ class StableDiffusionProcessing: self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
+ self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
@@ -162,6 +163,8 @@ class StableDiffusionProcessing: self.all_seeds = None
self.all_subseeds = None
self.iteration = 0
+ self.is_hr_pass = False
+
@property
def sd_model(self):
@@ -454,6 +457,16 @@ def fix_seed(p): p.subseed = get_fixed_seed(p.subseed)
+def program_version():
+ import launch
+
+ res = launch.git_tag()
+ if res == "<none>":
+ res = None
+
+ return res
+
+
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
@@ -476,13 +489,17 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "Init image hash": getattr(p, 'init_img_hash', None),
+ "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ "Version": program_version() if opts.add_version_to_infotext else None,
}
generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
+ negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -491,6 +508,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
+ # if there is no checkpoint override, or the override checkpoint can't be found, remove the override entry and load the checkpoint configured in opts
+ if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ p.override_settings.pop('sd_model_checkpoint', None)
+ sd_models.reload_model_weights()
+
for k, v in p.override_settings.items():
setattr(opts, k, v)
@@ -507,8 +529,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights()
if k == 'sd_vae':
sd_vae.reload_vae_weights()
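The checkpoint override is now resolved once, up front: an unrecognized sd_model_checkpoint override is dropped and the configured checkpoint reloaded, instead of failing mid-generation, and the restore loop no longer reloads model weights it never switched. The general save/apply/restore shape, condensed into a sketch (with_overrides and fn are illustrative names; the real code restores conditionally and reloads VAE weights when sd_vae changes):

    def with_overrides(opts, overrides, fn):
        stored = {k: getattr(opts, k) for k in overrides}   # snapshot current values
        try:
            for k, v in overrides.items():
                setattr(opts, k, v)
            return fn()
        finally:
            for k, v in stored.items():                     # always put settings back
                setattr(opts, k, v)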
@@ -639,8 +659,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
+ step_multiplier = 1
+ if not shared.opts.dont_fix_second_order_samplers_schedule:
+ try:
+ step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
+ except Exception:
+ pass
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
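step_multiplier compensates for second-order samplers (Heun, DPM2, DPM++ 2S a, DPM++ SDE and their Karras variants), which evaluate the model twice per scheduled step; without it, prompt-editing schedules keyed to p.steps would finish halfway through sampling. The idea in isolation:

    SECOND_ORDER_ALIASES = {
        'k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka',
        'k_dpm_2', 'k_dpm_2_a', 'k_heun',
    }

    def conditioning_steps(sampler_alias, steps):
        # two model evaluations per step means the conditioning schedule
        # must cover twice the nominal step count
        return steps * (2 if sampler_alias in SECOND_ORDER_ALIASES else 1)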
@@ -670,6 +696,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
+ p.batch_index = i
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
@@ -706,9 +734,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text
output_images.append(image)
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
@@ -718,7 +746,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.return_mask:
output_images.append(image_mask)
-
+
if opts.return_mask_composite:
output_images.append(image_mask_composite)
@@ -751,7 +779,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc()
- res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+ res = Processed(
+ p,
+ images_list=output_images,
+ seed=p.all_seeds[0],
+ info=infotext(),
+ comments="".join(f"\n\n{comment}" for comment in comments),
+ subseed=p.all_subseeds[0],
+ index_of_first_image=index_of_first_image,
+ infotexts=infotexts,
+ )
if p.scripts is not None:
p.scripts.postprocess(p, res)
@@ -871,6 +908,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr:
return samples
+ self.is_hr_pass = True
+
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
@@ -940,6 +979,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ self.is_hr_pass = False
+
return samples
@@ -1007,6 +1048,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.color_corrections = []
imgs = []
for img in self.init_images:
+
+ # Save init image
+ if opts.save_init_img:
+ self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+
image = images.flatten(img, opts.img2img_background_color)
if crop_region is None and self.resize_mode != 3:
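With opts.save_init_img enabled, the img2img source image is stored under the MD5 of its raw pixel bytes, which deduplicates repeat runs and lets the new "Init image hash" infotext field identify the source later. The naming scheme by itself, as a sketch:

    import hashlib
    from PIL import Image

    def init_image_filename(img: Image.Image) -> str:
        # the same source image always hashes to the same name,
        # so re-running img2img does not pile up duplicate copies
        return hashlib.md5(img.tobytes()).hexdigest()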
diff --git a/modules/progress.py b/modules/progress.py
index c69ecf3d..289dd311 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -13,6 +13,8 @@ import modules.shared as shared
 current_task = None
pending_tasks = {}
finished_tasks = []
+recorded_results = []
+recorded_results_limit = 2
def start_task(id_task):
@@ -33,6 +35,12 @@ def finish_task(id_task): finished_tasks.pop(0)
+def record_results(id_task, res):
+ recorded_results.append((id_task, res))
+ if len(recorded_results) > recorded_results_limit:
+ recorded_results.pop(0)
+
+
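recorded_results acts as a tiny ring buffer: only the newest recorded_results_limit entries survive, and restore_progress (added at the end of this file) searches it once a task leaves the queue. An equivalent idiom using collections.deque, for illustration:

    from collections import deque

    recorded = deque(maxlen=2)   # oldest entry falls off automatically

    def record(id_task, res):
        recorded.append((id_task, res))

    def lookup(id_task):
        return next((res for tid, res in recorded if tid == id_task), None)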
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
@@ -87,8 +95,9 @@ def progressapi(req: ProgressRequest): image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
- image.save(buffered, format="png")
- live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+ image.save(buffered, format=opts.live_previews_format)
+ base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+ live_preview = f"data:image/{opts.live_previews_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None
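Live previews are now encoded with the user-selected opts.live_previews_format rather than hard-coded PNG, and the same string doubles as the MIME subtype of the data: URI (which lines up for png, jpeg and webp). A self-contained version of the encoding step:

    import base64
    import io

    def live_preview_data_uri(image, fmt="png"):
        # encode a PIL image into a data: URI the browser can render directly;
        # fmt serves as both the PIL encoder name and the MIME subtype
        buffered = io.BytesIO()
        image.save(buffered, format=fmt)
        encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
        return f"data:image/{fmt};base64,{encoded}"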
@@ -97,3 +106,13 @@ def progressapi(req: ProgressRequest): return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
+def restore_progress(id_task):
+ while id_task == current_task or id_task in pending_tasks:
+ time.sleep(0.1)
+
+ res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
+ if res is not None:
+ return res
+
+ return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 69665372..b4aff704 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): """
def collect_steps(steps, tree):
- l = [steps]
+ res = [steps]
+
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
- l.append(tree.children[-1])
+ res.append(tree.children[-1])
+
def alternate(self, tree):
- l.extend(range(1, steps+1))
+ res.extend(range(1, steps+1))
+
CollectSteps().visit(tree)
- return sorted(set(l))
+ return sorted(set(res))
def at_step(step, tree):
class AtStep(lark.Transformer):
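collect_steps gathers every sampling step at which the prompt text changes: a [from:to:when] edit with when < 1 is scaled by the total step count, and alternation ([a|b]) adds a boundary at every step. Expected behavior, going by the module's docstring:

    schedules = get_learned_conditioning_prompt_schedules(["a [cat:dog:0.5]"], 10)
    # expected: [[[5, "a cat"], [10, "a dog"]]] -- "a cat" through step 5, then "a dog"
    # alternation such as "a [cat|dog]" instead yields a schedule entry per step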
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError as e:
+ except lark.exceptions.LarkError:
if 0:
import traceback
traceback.print_exc()
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps): conds = model.get_learned_conditioning(texts)
cond_schedule = []
- for i, (end_at_step, text) in enumerate(prompt_schedule):
+ for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
- for current, (end_at, cond) in enumerate(cond_schedule):
- if current_step <= end_at:
+ for current, entry in enumerate(cond_schedule):
+ if current_step <= entry.end_at_step:
target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): tensors = []
conds_list = []
- for batch_no, composable_prompts in enumerate(c.batch):
+ for composable_prompts in c.batch:
conds_for_batch = []
- for cond_index, composable_prompt in enumerate(composable_prompts):
+ for composable_prompt in composable_prompts:
target_index = 0
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
- if current_step <= end_at:
+ for current, entry in enumerate(composable_prompt.schedules):
+ if current_step <= entry.end_at_step:
target_index = current
break
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index aad4a629..c24d8dbb 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
 from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-
+from modules import modelloader
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -17,13 +17,21 @@ class UpscalerRealESRGAN(Upscaler): self.user_path = path
super().__init__()
try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
+ from realesrgan import RealESRGANer # noqa: F401
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
+
+ local_model_paths = self.find_models(ext_filter=[".pth"])
for scaler in scalers:
+ if scaler.local_data_path.startswith("http"):
+ filename = modelloader.friendly_name(scaler.local_data_path)
+ local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+ if local_model_candidates:
+ scaler.local_data_path = local_model_candidates[0]
+
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
@@ -39,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler): info = self.load_model(path)
if not os.path.exists(info.local_data_path):
- print("Unable to load RealESRGAN model: %s" % info.name)
+ print(f"Unable to load RealESRGAN model: {info.name}")
return img
upsampler = RealESRGANer(
@@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler): print(f"Unable to find model info: {path}")
return None
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+ if info.local_data_path.startswith("http"):
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
@@ -124,6 +134,6 @@ def get_realesrgan_models(scaler): ),
]
return models
- except Exception as e:
+ except Exception:
print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/safe.py b/modules/safe.py
index 82d44be3..1e791c5b 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -1,6 +1,5 @@
 # this code is adapted from the script contributed by anon from /h/
-import io
import pickle
import collections
import sys
@@ -12,11 +11,9 @@ import _codecs import zipfile
import re
-
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
-
def encode(*args):
out = _codecs.encode(*args)
return out
@@ -27,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return TypedStorage()
+
+ try:
+ return TypedStorage(_internal=True)
+ except TypeError:
+ return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
def find_class(self, module, name):
if self.extra_handler is not None:
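Rather than comparing version strings, the unpickler feature-detects the TypedStorage constructor: PyTorch 2.0 expects _internal=True, while older releases reject the keyword with a TypeError. The pattern on its own:

    def make_storage():
        # try the new signature first and fall back on TypeError,
        # so the code needs no torch.__version__ parsing
        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()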
@@ -94,16 +95,16 @@ def check_pt(filename, extra_handler): except zipfile.BadZipfile:
- # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
- for i in range(5):
+ for _ in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
- return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+ return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
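The safe.load change just above, like the matching edits in sd_samplers_compvis.py and ddpm_edit.py, fixes the same flagged pattern: a keyword argument written before *args unpacking. The call is legal, but arguments bind in a surprising order and it breaks as soon as the unpacked tuple reaches the keyword's slot. For example:

    def g(a, b=None, c=None):
        return (a, b, c)

    args = (1,)
    assert g(b=2, *args) == (1, 2, None)   # works, but *args binds before the keyword reads
    # g(b=2, *(1, 3)) -> TypeError: g() got multiple values for argument 'b'
    assert g(*args, b=2) == (1, 2, None)   # same result, unambiguous ordering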
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 07911876..17109732 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -93,6 +93,7 @@ callback_map = dict(
     callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
callbacks_before_ui=[],
+ callbacks_on_reload=[],
)
@@ -109,6 +110,14 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): report_exception(c, 'app_started_callback')
+def app_reload_callback():
+ for c in callback_map['callbacks_on_reload']:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'callbacks_on_reload')
+
+
def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']:
try:
@@ -254,6 +263,11 @@ def on_app_started(callback): add_callback(callback_map['callbacks_app_started'], callback)
+def on_before_reload(callback):
+ """register a function to be called just before the server reloads."""
+ add_callback(callback_map['callbacks_on_reload'], callback)
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument; this function is also called when the script is reloaded. """
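on_before_reload follows this file's uniform recipe: one list per event in callback_map, a fire-everything helper that isolates failures per callback, and a small registration function. Boiled down to a sketch with illustrative names:

    callback_lists = {'on_reload': []}

    def on_before_reload(callback):
        callback_lists['on_reload'].append(callback)

    def app_reload():
        for c in callback_lists['on_reload']:
            try:
                c()
            except Exception:
                import traceback
                traceback.print_exc()   # one broken extension must not block the reload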
diff --git a/modules/script_loading.py b/modules/script_loading.py
index a7d2203f..57b15862 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -2,7 +2,6 @@ import os
 import sys
import traceback
import importlib.util
-from types import ModuleType
def load_module(path):
diff --git a/modules/scripts.py b/modules/scripts.py index 4d0bbd66..0c12ebd5 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -163,7 +163,8 @@ class Script: """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+ tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{tabname}{title}_{item_id}'
@@ -230,7 +231,7 @@ def load_scripts(): syspath = sys.path
def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@@ -294,9 +295,9 @@ class ScriptRunner: auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
+ for script_data in auto_processing_scripts + scripts_data:
+ script = script_data.script_class()
+ script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -491,7 +492,7 @@ class ScriptRunner: module = script_loading.load_module(script.filename)
cache[filename] = module
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
@@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp): this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
- comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py index 30d6d658..d63078de 100644 --- a/modules/scripts_auto_postprocessing.py +++ b/modules/scripts_auto_postprocessing.py @@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script): return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
- args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+ args_dict = dict(zip(self.postprocessing_controls, args))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index b11568c0..bac1335d 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -66,9 +66,9 @@ class ScriptPostprocessingRunner: def initialize_scripts(self, scripts_data):
self.scripts = []
- for script_class, path, basedir, script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
+ for script_data in scripts_data:
+ script: ScriptPostprocessing = script_data.script_class()
+ script.filename = script_data.path
if script.name == "Simple Upscale":
continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner: script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, component), value in zip(script.controls.items(), script_args):
+ for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value
script.process(pp, **process_args)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c4a09d15..9fc89dc6 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -61,7 +61,7 @@ class DisableInitialization: if res is None:
res = original(url, *args, local_files_only=False, **kwargs)
return res
- except Exception as e:
+ except Exception:
return original(url, *args, local_files_only=False, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f4bb0266..e374aeb8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,7 +3,7 @@ from torch.nn.functional import silu from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -37,7 +37,7 @@ def apply_optimizations(): optimization_method = None
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
@@ -118,7 +118,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): try:
#Delete temporary weights if appended
del sd_model._custom_loss_weight
- except AttributeError as e:
+ except AttributeError:
pass
#If we have an old loss function, reset the loss function to the original one
@@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model): def undo_weighted_forward(sd_model):
try:
del sd_model.weighted_forward
- except AttributeError as e:
+ except AttributeError:
pass
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 9fa5c5c5..cc6e8c21 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes:
- for position, embedding in fixes:
+ for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers)
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index 6d9fbbe6..a3476e95 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+ embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
+ self.hijack.comments.append(f"Used embeddings: {embedding_names}")
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 55a2ce4d..c1977b19 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,16 +1,10 @@ -import os import torch -from einops import repeat -from omegaconf import ListConfig - import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.ddim import DDIMSampler, noise_like +from ldm.models.diffusion.ddim import noise_like from ldm.models.diffusion.sampling_util import norm_thresholding @@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F if isinstance(c, dict): assert isinstance(unconditional_conditioning, dict) - c_in = dict() + c_in = {} for k in c: if isinstance(c[k], list): c_in[k] = [ diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py index 3c727d3b..6fe6b6ff 100644 --- a/modules/sd_hijack_ip2p.py +++ b/modules/sd_hijack_ip2p.py @@ -1,8 +1,5 @@ -import collections import os.path -import sys -import gc -import time + def should_hijack_ip2p(checkpoint_info): from modules import sd_models_config @@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info): ckpt_basename = os.path.basename(checkpoint_info.filename).lower() cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() - return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename + return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 372555ff..a174bbe1 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -98,7 +98,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): del context, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@@ -229,7 +229,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ if q.device.type == 'mps':
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
- query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
@@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x): k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x): k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x): k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 15858263..ca1daf45 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -18,7 +18,7 @@ class TorchHijackForUnet: if hasattr(torch, item):
return getattr(torch, item)
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2:
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py index 4ac51c38..28528329 100644 --- a/modules/sd_hijack_xlmr.py +++ b/modules/sd_hijack_xlmr.py @@ -1,8 +1,6 @@ -import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
-from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ea874df..d1e946a5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,8 @@ import collections import os.path
import sys
import gc
+import threading
+
import torch
import re
import safetensors.torch
@@ -13,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
@@ -45,20 +46,29 @@ class CheckpointInfo: self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.metadata = {}
+
+ _, ext = os.path.splitext(self.filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading checkpoint metadata: {filename}")
+
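Reading .safetensors metadata eagerly is cheap because of the file layout: an 8-byte little-endian header length, then a JSON header whose optional __metadata__ key holds user-supplied strings, with tensor data only after that. A simplified reader along the lines of read_metadata_from_safetensors:

    import json
    import struct

    def read_safetensors_metadata(filename):
        with open(filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]   # little-endian uint64
            header = json.loads(f.read(header_size))
        return header.get("__metadata__", {})   # absent if the file carries no metadata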
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None:
return
@@ -76,8 +86,7 @@ class CheckpointInfo: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -228,7 +237,7 @@ def read_metadata_from_safetensors(filename): if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -395,13 +404,39 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+ def __init__(self):
+ self.sd_model = None
+ self.lock = threading.Lock()
+
+ def get_sd_model(self):
+ if self.sd_model is None:
+ with self.lock:
+ try:
+ load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load", file=sys.stderr)
+ self.sd_model = None
+
+ return self.sd_model
+
+ def set_sd_model(self, v):
+ self.sd_model = v
+
+
+model_data = SdModelData()
+
+
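SdModelData defers checkpoint loading to the first access of sd_model and serializes it behind a threading.Lock, which is what allows the UI to start before the model is ready. A generic sketch of the lazy-load-under-lock idea (the patch itself assigns the model through load_model()'s side effects and only checks before taking the lock):

    import threading

    class LazyModel:
        def __init__(self, loader):
            self._loader = loader
            self._model = None
            self._lock = threading.Lock()

        def get(self):
            if self._model is None:              # fast path, no locking
                with self._lock:
                    if self._model is None:      # re-check once the lock is held
                        self._model = self._loader()
            return self._model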
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- if shared.sd_model:
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
gc.collect()
devices.torch_gc()
@@ -430,7 +465,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -455,7 +490,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ timer.record("hijack")
sd_model.eval()
- shared.sd_model = sd_model
+ model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -475,7 +510,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_info = info or select_checkpoint()
if not sd_model:
- sd_model = shared.sd_model
+ sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
@@ -503,11 +538,11 @@ def reload_model_weights(sd_model=None, info=None): del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
- return shared.sd_model
+ return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -526,17 +561,15 @@ def reload_model_weights(sd_model=None, info=None): return sd_model
+
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
- if shared.sd_model:
-
- # shared.sd_model.cond_stage_model.to(devices.cpu)
- # shared.sd_model.first_stage_model.to(devices.cpu)
- shared.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ model_data.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
@@ -544,4 +577,4 @@ def unload_model_weights(sd_model=None, info=None): print(f"Unloaded weights {timer.summary()}.")
- return sd_model
\ No newline at end of file + return sd_model
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9398f528..9bfe1237 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,4 +1,3 @@
-import re
import os
import torch
@@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info): if info is None:
return None
- config = os.path.splitext(info.filename)[0] + ".yaml"
+ config = f"{os.path.splitext(info.filename)[0]}.yaml"
if os.path.exists(config):
return config
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index ff361f22..4f1bf21d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,7 +1,7 @@
 from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index a1aac7cf..bc074238 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -60,3 +60,13 @@ def store_latent(decoded): class InterruptedException(BaseException):
pass
+
+
+if opts.randn_source == "CPU":
+ import torchsde._brownian.brownian_interval
+
+ def torchsde_randn(size, dtype, device, seed):
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+
+ torchsde._brownian.brownian_interval._randn = torchsde_randn
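Replacing torchsde's internal _randn makes the SDE samplers draw their Brownian noise on the CPU from an explicitly seeded generator, so results stop depending on which GPU produced them. The substitute generator on its own:

    import torch

    def seeded_cpu_randn(size, dtype, device, seed):
        # draw on the CPU with an explicit generator, then move to the target
        # device; the sequence depends only on the seed, not the GPU model
        generator = torch.Generator("cpu").manual_seed(int(seed))
        return torch.randn(size, dtype=dtype, device="cpu", generator=generator).to(device)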
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bfcc5574..b1ee3be7 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
     def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler: conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..2f733cf5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
 from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -76,7 +75,7 @@ class CFGDenoiser(torch.nn.Module): return denoised
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -87,7 +86,7 @@ class CFGDenoiser(torch.nn.Module): conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -115,12 +114,21 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
+ skip_uncond = False
- if tensor.shape[1] == uncond.shape[1]:
- if not is_edit_model:
- cond_in = torch.cat([tensor, uncond])
- else:
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
@@ -144,7 +152,13 @@ class CFGDenoiser(torch.nn.Module): x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params)
@@ -152,20 +166,21 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- else:
+ if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
-
return denoised
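A minimal sketch of the negative-guidance-minimum-sigma logic above, with illustrative names and a stand-in denoise callable: on alternating steps, once sigma drops below the threshold, the uncond pass is skipped and the cond output is reused in its place, so the CFG combination collapses to the cond result (scale 1.0).

    def cfg_step(denoise, x, sigma, cond, uncond, cfg_scale, s_min_uncond, step):
        # skip the uncond model call on every other step once sigma is small
        skip_uncond = step % 2 == 1 and s_min_uncond > 0 and sigma < s_min_uncond
        out_cond = denoise(x, sigma, cond)
        if skip_uncond:
            out_uncond, scale = out_cond, 1.0  # uncond + (cond - uncond) * 1.0 == cond
        else:
            out_uncond, scale = denoise(x, sigma, uncond), cfg_scale
        return out_uncond + (out_cond - out_uncond) * scale

    toy = lambda x, sigma, c: c  # toy "model" that just returns its conditioning
    assert cfg_step(toy, 0.0, 0.2, 5.0, -5.0, 7.0, 0.5, step=3) == 5.0  # uncond skipped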
@@ -182,7 +197,7 @@ class TorchHijack: if hasattr(torch, item):
return getattr(torch, item)
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
if self.sampler_noises:
@@ -190,7 +205,7 @@ class TorchHijack: if noise.shape == x.shape:
return noise
- if x.device.type == 'mps':
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
@@ -210,6 +225,7 @@ class KDiffusionSampler: self.eta = None
self.config = None
self.last_latent = None
+ self.s_min_uncond = None
self.conditioning_key = sd_model.model.conditioning_key
@@ -244,6 +260,7 @@ class KDiffusionSampler: self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -326,6 +343,7 @@ class KDiffusionSampler: 'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -359,7 +377,8 @@ class KDiffusionSampler: 'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9b00f76e..b7176125 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,8 +1,5 @@ -import torch
-import safetensors.torch
import os
import collections
-from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
@@ -89,7 +86,7 @@ def refresh_vae_list(): def find_vae_near_checkpoint(checkpoint_file):
checkpoint_path = os.path.splitext(checkpoint_file)[0]
- for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+ for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
if os.path.isfile(vae_location):
return vae_location
diff --git a/modules/shared.py b/modules/shared.py index 5fd0eecb..fc39161e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,11 +1,9 @@ -import argparse
import datetime
import json
import os
import sys
import time
-from PIL import Image
import gradio as gr
import tqdm
@@ -14,7 +12,8 @@ import modules.memmon import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None
@@ -39,6 +38,7 @@ restricted_opts = { "outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
+ "outdir_init_images"
}
ui_reorder_categories = [
@@ -54,6 +54,21 @@ ui_reorder_categories = [ "scripts",
]
+# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
+gradio_hf_hub_themes = [
+ "gradio/glass",
+ "gradio/monochrome",
+ "gradio/seafoam",
+ "gradio/soft",
+ "freddyaboulton/dracula_revamped",
+ "gradio/dracula_test",
+ "abidlabs/dracula_test",
+ "abidlabs/pakistan",
+ "dawood/microsoft_windows",
+ "ysharma/steampunk"
+]
+
+
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@@ -196,7 +211,7 @@ class OptionInfo: def options_section(section_identifier, options_dict):
- for k, v in options_dict.items():
+ for v in options_dict.values():
v.section = section_identifier
return options_dict
@@ -252,7 +267,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
- "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
+ "save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
@@ -268,6 +283,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
@@ -283,6 +299,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -331,6 +349,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+ "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -338,6 +357,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
+ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -361,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", hypernetworks]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -369,29 +389,38 @@ options_templates.update(options_section(('ui', "User interface"), { "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
- "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+ "js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
+ "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
+ "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
+ "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
+}))
+
+options_templates.update(options_section(('infotext', "Infotext"), {
+ "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
+ "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
}))
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "live_previews_format": OptionInfo("jpeg", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
@@ -405,6 +434,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
@@ -424,6 +454,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), { options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
+ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
@@ -516,6 +547,10 @@ class Options: with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
+ # 1.1.1 quicksettings list migration
+ if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
+ self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
@@ -545,11 +580,11 @@ class Options: section_ids = {}
settings_items = self.data_labels.items()
- for k, item in settings_items:
+ for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
- self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
@@ -574,13 +609,37 @@ class Options: return value
-
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+
+class Shared(sys.modules[__name__].__class__):
+ """
+ this class is here to provide the sd_model field as a property, so that it can be created and loaded on demand rather than
+ at program startup.
+ """
+
+ sd_model_val = None
+
+ @property
+ def sd_model(self):
+ import modules.sd_models
+
+ return modules.sd_models.model_data.get_sd_model()
+
+ @sd_model.setter
+ def sd_model(self, value):
+ import modules.sd_models
+
+ modules.sd_models.model_data.set_sd_model(value)
+
+
+sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
+sys.modules[__name__].__class__ = Shared
+
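The sys.modules[__name__].__class__ swap above is a general Python trick for giving a module lazily computed, property-backed attributes; a self-contained sketch of the pattern (the attribute name and payload are made up):

    import sys
    import types

    class _LazyModule(types.ModuleType):
        _model = None

        @property
        def model(self):
            # built on first attribute access instead of at import time
            if type(self)._model is None:
                type(self)._model = object()  # stand-in for an expensive load
            return type(self)._model

    sys.modules[__name__].__class__ = _LazyModule
    # importers now trigger the load only when they first touch module.model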
settings_components = None
-"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
+"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
@@ -594,12 +653,28 @@ latent_upscale_modes = { sd_upscalers = []
-sd_model = None
-
clip_model = None
progress_print_out = sys.stdout
+gradio_theme = gr.themes.Base()
+
+
+def reload_gradio_theme(theme_name=None):
+ global gradio_theme
+ if not theme_name:
+ theme_name = opts.gradio_theme
+
+ if theme_name == "Default":
+ gradio_theme = gr.themes.Default()
+ else:
+ try:
+ gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
+ except Exception as e:
+ errors.display(e, "changing gradio theme")
+ gradio_theme = gr.themes.Default()
+
+
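A hedged usage sketch for the theme machinery above: the theme name here is one of the listed hub themes, loading it needs network access, and this assumes the resulting shared.gradio_theme object is passed to gr.Blocks when the UI is built.

    import gradio as gr
    from modules import shared

    shared.reload_gradio_theme("gradio/soft")  # falls back to Default on any error

    with gr.Blocks(theme=shared.gradio_theme) as demo:
        gr.Markdown("themed UI")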
class TotalTQDM:
def __init__(self):
@@ -657,3 +732,20 @@ def html(filename): return file.read()
return ""
+
+
+def walk_files(path, allowed_extensions=None):
+ if not os.path.exists(path):
+ return
+
+ if allowed_extensions is not None:
+ allowed_extensions = set(allowed_extensions)
+
+ for root, _, files in os.walk(path):
+ for filename in files:
+ if allowed_extensions is not None:
+ _, ext = os.path.splitext(filename)
+ if ext not in allowed_extensions:
+ continue
+
+ yield os.path.join(root, filename)
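A small usage sketch for the walk_files helper added above; the directory is illustrative, and note that os.path.splitext keeps the leading dot, so filter entries must include it.

    from modules.shared import walk_files

    for path in walk_files("models/Stable-diffusion", allowed_extensions={".ckpt", ".safetensors"}):
        print(path)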
diff --git a/modules/styles.py b/modules/styles.py index 990d5623..c22769cf 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,18 +1,9 @@ -# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-
import csv
import os
import os.path
import typing
-import collections.abc as abc
-import tempfile
import shutil
-if typing.TYPE_CHECKING:
- # Only import this when code is being type-checked, it doesn't have any effect at runtime
- from .processing import StableDiffusionProcessing
-
class PromptStyle(typing.NamedTuple):
name: str
@@ -72,16 +63,14 @@ class StyleDatabase: return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
def save_styles(self, path: str) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
+ # Always keep a backup file around
+ if os.path.exists(path):
+ shutil.copy(path, f"{path}.bak")
+
+ fd = os.open(path, os.O_RDWR | os.O_CREAT | os.O_TRUNC)  # truncate, or a shorter rewrite would leave stale bytes from the old file
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
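The sub_quadratic_attention.py diff below replaces torch.cat over per-chunk attention results with writes into a buffer preallocated via torch.zeros_like, avoiding a second full-size allocation at the end. A standalone sketch of the pattern, with toy shapes and a doubling op standing in for the per-chunk attention call:

    import math
    import torch

    batch, q_tokens, channels, chunk = 1, 10, 4, 3
    query = torch.randn(batch, q_tokens, channels)

    res = torch.zeros_like(query)  # allocated once, up front
    for i in range(math.ceil(q_tokens / chunk)):
        part = query[:, i * chunk:(i + 1) * chunk, :] * 2  # stand-in for attention
        res[:, i * chunk:i * chunk + part.shape[1], :] = part  # write in place, no cat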
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 05595323..cc38debd 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -201,14 +201,15 @@ def efficient_dot_product_attention( key=key,
value=value,
)
-
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
+
+ res = torch.zeros_like(query)
+ for i in range(math.ceil(q_tokens / query_chunk_size)):
+ attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
+ )
+
+ res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
+
return res
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 68e1103c..7770d22f 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,10 +1,8 @@ import cv2
import requests
import os
-from collections import defaultdict
-from math import log, sqrt
import numpy as np
-from PIL import Image, ImageDraw
+from PIL import ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
@@ -111,7 +109,7 @@ def focal_point(im, settings): if corner_centroid is not None:
color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight)
- d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(corner_points) > 1:
for f in corner_points:
@@ -119,7 +117,7 @@ def focal_point(im, settings): if entropy_centroid is not None:
color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
- d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(entropy_points) > 1:
for f in entropy_points:
@@ -127,7 +125,7 @@ def focal_point(im, settings): if face_centroid is not None:
color = RED
box = face_centroid.bounding(max_size * face_centroid.weight)
- d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(face_points) > 1:
for f in face_points:
@@ -185,7 +183,7 @@ def image_face_points(im, settings): try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
+ except Exception:
continue
if len(faces) > 0:
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index af9fbcf2..41610e03 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -72,7 +72,7 @@ class PersonalizedBase(Dataset): except Exception:
continue
- text_filename = os.path.splitext(path)[0] + ".txt"
+ text_filename = f"{os.path.splitext(path)[0]}.txt"
filename = os.path.basename(path)
if os.path.exists(text_filename):
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 5593f88c..d85a4888 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -2,7 +2,7 @@ import base64 import json
import numpy as np
import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
from modules.shared import opts
@@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder): class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index f63fc72f..c56bea45 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -12,7 +12,7 @@ class LearnScheduleIterator: self.it = 0
self.maxit = 0
try:
- for i, pair in enumerate(pairs):
+ for pair in pairs:
if not pair.strip():
continue
tmp = pair.split(':')
@@ -32,8 +32,8 @@ class LearnScheduleIterator: self.maxit += 1
return
assert self.rates
- except (ValueError, AssertionError):
- raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+ except (ValueError, AssertionError) as e:
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
def __iter__(self):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 2239cb84..d0cad09e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,17 +1,13 @@ import os
from PIL import Image, ImageOps
import math
-import platform
-import sys
import tqdm
-import time
from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -19,7 +15,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height, if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -63,9 +59,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti image.save(os.path.join(params.dstdir, f"{basename}.png"))
if params.preprocess_txt_action == 'prepend' and existing_caption:
- caption = existing_caption + ' ' + caption
+ caption = f"{existing_caption} {caption}"
elif params.preprocess_txt_action == 'append' and existing_caption:
- caption = caption + ' ' + existing_caption
+ caption = f"{caption} {existing_caption}"
elif params.preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
@@ -131,7 +127,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr return wh and center_crop(image, *wh)
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -161,7 +157,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.subindex = 0
filename = os.path.join(src, imagefile)
try:
- img = Image.open(filename).convert("RGB")
+ img = Image.open(filename)
+ img = ImageOps.exif_transpose(img)
+ img = img.convert("RGB")
except Exception:
continue
@@ -172,7 +170,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.src = filename
existing_caption = None
- existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
if os.path.exists(existing_caption_filename):
with open(existing_caption_filename, 'r', encoding="utf8") as file:
existing_caption = file.read()
@@ -223,6 +221,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
process_default_resize = False
+ if process_keep_original_size:
+ save_pic(img, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d2e62e58..9e1b2b9a 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,7 +1,6 @@ import os
import sys
import traceback
-import inspect
from collections import namedtuple
import torch
@@ -30,7 +29,7 @@ textual_inversion_templates = {} def list_textual_inversion_templates():
textual_inversion_templates.clear()
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)
@@ -69,7 +68,7 @@ class Embedding: 'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
- torch.save(optimizer_saved_dict, filename + '.optim')
+ torch.save(optimizer_saved_dict, f"{filename}.optim")
def checksum(self):
if self.cached_checksum is not None:
@@ -167,8 +166,7 @@ class EmbeddingDatabase: # textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
@@ -199,7 +197,7 @@ class EmbeddingDatabase: if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -216,7 +214,7 @@ class EmbeddingDatabase: def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -229,10 +227,16 @@ class EmbeddingDatabase: self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()
+ # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
+ # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
+ sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
+ self.word_embeddings.clear()
+ self.word_embeddings.update(sorted_word_embeddings)
+
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
@@ -431,8 +435,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
- if os.path.exists(filename + '.optim'):
- optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if os.path.exists(f"{filename}.optim"):
+ optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
@@ -464,7 +468,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -593,17 +597,17 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = f"<{data.get('name', '???')}>"
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
+ except Exception:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.shorthash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ footer_mid = f'[{checkpoint.shorthash}]'
+ footer_right = f'{vectorSize}v {steps_done}s'
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
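The word_embeddings re-sort in the hunk above deliberately mutates the existing dict instead of rebinding the name, so any other object holding a reference keeps seeing the sorted data. A minimal demonstration:

    embeddings = {"Zeta": 1, "alpha": 2}
    alias = embeddings  # another object holds this reference

    sorted_copy = {k: embeddings[k] for k in sorted(embeddings, key=str.lower)}
    embeddings.clear()
    embeddings.update(sorted_copy)

    assert alias is embeddings
    assert list(alias) == ["alpha", "Zeta"]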
diff --git a/modules/txt2img.py b/modules/txt2img.py index 16841d0f..f022381c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,18 +1,15 @@ import modules.scripts
-from modules import sd_samplers
+from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
- StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
-import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
- p = StableDiffusionProcessingTxt2Img(
+ p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -53,7 +50,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = processing.process_images(p)
p.close()
diff --git a/modules/ui.py b/modules/ui.py index 627fbe0b..1efb656a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,29 +1,23 @@ -import html
import json
-import math
import mimetypes
import os
-import platform
-import random
import sys
-import tempfile
-import time
import traceback
-from functools import partial, reduce
+from functools import reduce
import warnings
import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np
-from PIL import Image, PngImagePlugin
+from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import opts, cmd_opts
import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
@@ -34,7 +28,6 @@ import modules.shared as shared import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
-from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
@@ -81,6 +74,7 @@ apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
+restore_progress_symbol = '\U0001F300' # 🌀
def plaintext_to_html(text):
@@ -92,13 +86,6 @@ def send_gradio_gallery_to_image(x): return None
return image_from_url_text(x[0])
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
@@ -127,6 +114,16 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
+def resize_from_to_html(width, height, scale_by):
+ target_width = int(width * scale_by)
+ target_height = int(height * scale_by)
+
+ if not target_width or not target_height:
+ return "no image selected"
+
+ return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
+
+
def apply_styles(prompt, prompt_neg, styles):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
@@ -152,7 +149,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
return [gr.update(), None]
@@ -168,29 +165,29 @@ def interrogate_deepbooru(image): def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
seed.style(container=False)
- random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+ random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
+ reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+ with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
subseed.style(container=False)
- random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+ random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
+ reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -232,7 +229,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
- except json.decoder.JSONDecodeError as e:
+ except json.decoder.JSONDecodeError:
if gen_info_string != '':
print("Error parsing JSON generation info:", file=sys.stderr)
print(gen_info_string, file=sys.stderr)
@@ -312,6 +309,7 @@ def create_toprow(is_img2img): extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
+ restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
@@ -329,7 +327,7 @@ def create_toprow(is_img2img): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
def setup_progressbar(*args, **kwargs):
@@ -408,7 +406,7 @@ def create_sampler_and_steps_selection(choices, tabname): def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -446,7 +444,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
@@ -468,7 +466,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
@@ -578,6 +576,19 @@ def create_ui(): res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+ restore_progress_button.click(
+ fn=progress.restore_progress,
+ _js="restoreProgressTxt2img",
+ inputs=[dummy_component],
+ outputs=[
+ txt2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
txt_prompt_img.change(
fn=modules.images.image_data,
inputs=[
@@ -646,7 +657,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
@@ -673,6 +684,8 @@ def create_ui(): copy_image_buttons.append((button, name, elem))
with gr.Tabs(elem_id="mode_img2img"):
+ img2img_selected_tab = gr.State(0)
+
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
add_copy_image_controls('img2img', init_img)
@@ -706,8 +719,8 @@ def create_ui(): with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
- f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
- f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
@@ -715,6 +728,11 @@ def create_ui(): img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
+
+ for i, tab in enumerate(img2img_tabs):
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
+
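The tabnum=i default argument in the tab.select handlers above is deliberate: defaults are evaluated when each lambda is created, freezing the loop index, whereas a bare closure would late-bind and report the last tab for every handler.

    bad = [lambda: i for i in range(3)]       # closures late-bind the loop variable
    good = [lambda i=i: i for i in range(3)]  # default arg freezes the current value

    assert [f() for f in bad] == [2, 2, 2]
    assert [f() for f in good] == [0, 1, 2]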
def copy_image(img):
if isinstance(img, dict) and 'image' in img:
return img['image']
@@ -729,7 +747,7 @@ def create_ui(): )
button.click(
fn=lambda: None,
- _js="switch_to_"+name.replace(" ", "_"),
+ _js=f"switch_to_{name.replace(' ', '_')}",
inputs=[],
outputs=[],
)
@@ -744,11 +762,44 @@ def create_ui(): elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
-
- with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ selected_scale_tab = gr.State(value=0)
+
+ with gr.Tabs():
+ with gr.Tab(label="Resize to") as tab_scale_to:
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+
+ with gr.Tab(label="Resize by") as tab_scale_by:
+ scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
+
+ with FormRow():
+ scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
+ gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
+ button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
+
+ on_change_args = dict(
+ fn=resize_from_to_html,
+ _js="currentImg2imgSourceResolution",
+ inputs=[dummy_component, dummy_component, scale_by],
+ outputs=scale_by_html,
+ show_progress=False,
+ )
+
+ scale_by.release(**on_change_args)
+ button_update_resize_to.click(**on_change_args)
+
+ # the code below is meant to update the resolution label after the image in the image selection UI has changed.
+ # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
+ # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
+ for component in [init_img, sketch]:
+ component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
+
+ tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
+ tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
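The `selected_scale_tab` state decides downstream whether img2img honors the width/height sliders (0, "Resize to") or derives dimensions from `scale_by` (1, "Resize by"). Roughly, with illustrative names rather than the actual processing code:

if selected_scale_tab == 1:  # "Resize by"
    width = int(init_image.width * scale_by)
    height = int(init_image.height * scale_by)
# on tab 0 the explicit width/height slider values are used unchanged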
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="img2img_column_batch"):
@@ -759,7 +810,7 @@ def create_ui():
with FormGroup():
with FormRow():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
- image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed":
@@ -806,7 +857,7 @@ def create_ui():
def select_img2img_tab(tab):
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
- for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ for i, elem in enumerate(img2img_tabs):
elem.select(
fn=lambda tab=i: select_img2img_tab(tab),
inputs=[],
@@ -859,8 +910,10 @@ def create_ui():
denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ selected_scale_tab,
height,
width,
+ scale_by,
resize_mode,
inpaint_full_res,
inpaint_full_res_padding,
@@ -898,6 +951,19 @@ def create_ui():
submit.click(**img2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+ restore_progress_button.click(
+ fn=progress.restore_progress,
+ _js="restoreProgressImg2img",
+ inputs=[dummy_component],
+ outputs=[
+ img2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
@@ -1019,8 +1085,9 @@ def create_ui():
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+ save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
with FormRow():
with gr.Column():
@@ -1048,7 +1115,7 @@ def create_ui():
with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
- with gr.Tab(label="Create embedding"):
+ with gr.Tab(label="Create embedding", id="create_embedding"):
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
@@ -1061,7 +1128,7 @@ def create_ui():
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
- with gr.Tab(label="Create hypernetwork"):
+ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
@@ -1079,7 +1146,7 @@ def create_ui():
with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
- with gr.Tab(label="Preprocess images"):
+ with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
@@ -1087,6 +1154,7 @@ def create_ui():
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
+ process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
@@ -1144,16 +1212,16 @@ def create_ui():
)
def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
+ return sorted(textual_inversion.textual_inversion_templates)
- with gr.Tab(label="Train"):
+ with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
@@ -1204,8 +1272,8 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+ gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
create_embedding.click(
@@ -1253,6 +1321,7 @@ def create_ui():
process_width,
process_height,
preprocess_txt_action,
+ process_keep_original_size,
process_flip,
process_split,
process_caption,
@@ -1375,23 +1444,25 @@ def create_ui():
elif t == bool:
comp = gr.Checkbox
else:
- raise Exception(f'bad options item type: {str(t)} for key {key}')
+ raise Exception(f'bad options item type: {t} for key {key}')
- elem_id = "setting_"+key
+ elem_id = f"setting_{key}"
if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+
components = []
component_dict = {}
shared.settings_components = component_dict
@@ -1438,7 +1509,7 @@ def create_ui():
result = gr.HTML(elem_id="settings_result")
- quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = opts.quicksettings_list
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = []
@@ -1458,7 +1529,7 @@ def create_ui():
current_tab.__exit__()
gr.Group()
- current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__()
current_row = gr.Column(variant='compact')
current_row.__enter__()
@@ -1479,7 +1550,10 @@ def create_ui():
current_row.__exit__()
current_tab.__exit__()
- with gr.TabItem("Actions"):
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
@@ -1487,7 +1561,7 @@ def create_ui():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
- with gr.TabItem("Licenses"):
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1552,7 +1626,7 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
+ (train_interface, "Train", "train"),
]
interfaces += script_callbacks.ui_tabs_callback()
@@ -1565,9 +1639,9 @@ def create_ui():
for _interface, label, _ifid in interfaces:
shared.tab_names.append(label)
- with gr.Blocks(analytics_enabled=False, title="Stable Diffusion") as demo:
+ with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1577,11 +1651,21 @@ def create_ui():
for interface, label, ifid in interfaces:
if label in shared.opts.hidden_tabs:
continue
- with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
+ with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()
+ for interface, _label, ifid in interfaces:
+ if ifid in ["extensions", "settings"]:
+ continue
+
+ loadsave.add_block(interface, ifid)
+
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+ loadsave.setup_ui()
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
footer = footer.format(versions=versions_html())
@@ -1594,22 +1678,21 @@ def create_ui():
outputs=[text_settings, result],
)
- for i, k, item in quicksettings_list:
+ for _i, k, _item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]
- component.change(
+ change_handler = component.release if hasattr(component, 'release') else component.change
+ change_handler(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
show_progress=info.refresh is not None,
)
- text_settings.change(
- fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
- inputs=[],
- outputs=[image_cfg_scale],
- )
+ update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+ text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+ demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
@@ -1658,6 +1741,7 @@ def create_ui():
config_source,
bake_in_vae,
discard_weights,
+ save_metadata,
],
outputs=[
primary_model_name,
@@ -1668,82 +1752,8 @@ def create_ui():
]
)
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = path + "/" + field
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = 'customscript/' + obj.custom_script_source + '/' + key
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
+ loadsave.dump_defaults()
+ demo.ui_loadsave = loadsave
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
@@ -1821,7 +1831,7 @@ def versions_html():
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
commit = launch.commit_hash()
- short_commit = commit[0:8]
+ tag = launch.git_tag()
if shared.xformers_available:
import xformers
@@ -1830,6 +1840,8 @@ def versions_html():
xformers_version = "N/A"
return f"""
+version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
+ •
python: <span title="{sys.version}">{python_version}</span>
•
torch: {getattr(torch, '__long_version__', torch.__version__)}
@@ -1838,7 +1850,21 @@ xformers: {xformers_version}
 •
gradio: {gr.__version__}
•
-commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
- •
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
+
+
+def setup_ui_api(app):
+ from pydantic import BaseModel, Field
+ from typing import List
+
+ class QuicksettingsHint(BaseModel):
+ name: str = Field(title="Name of the quicksettings field")
+ label: str = Field(title="Label of the quicksettings field")
+
+ def quicksettings_hint():
+ return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
+
+ app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+
+ app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 3b11dcc8..27ab3ebb 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -125,7 +125,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None
with gr.Column():
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 2b1da2cb..64451df7 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -62,3 +62,13 @@ class DropdownMulti(FormComponent, gr.Dropdown):
def get_block_name(self):
return "dropdown"
+
+
+class DropdownEditable(FormComponent, gr.Dropdown):
+ """Same as gr.Dropdown but allows editing value"""
+ def __init__(self, **kwargs):
+ super().__init__(allow_custom_value=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index efd6cda2..ed70abe5 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -2,6 +2,7 @@ import json
import os.path
import sys
import time
+from datetime import datetime
import traceback
import git
@@ -11,10 +12,12 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths
+from modules import extensions, shared, paths, config_states
+from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []}
+STYLE_PRIMARY = ' style="color: var(--primary-400)"'
def check_access():
@@ -30,6 +33,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+ if update:
+ save_config_state("Backup (pre-update)")
+
update = set(update)
for ext in extensions.extensions:
@@ -50,6 +56,47 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.state.need_restart = True
+def save_config_state(name):
+ current_config_state = config_states.get_config()
+ if not name:
+ name = "Config"
+ current_config_state["name"] = name
+ timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
+ filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
+ print(f"Saving backup of webui/extension state to {filename}.")
+ with open(filename, "w", encoding="utf-8") as f:
+ json.dump(current_config_state, f)
+ config_states.list_config_states()
+ new_value = next(iter(config_states.all_config_states.keys()), "Current")
+ new_choices = ["Current"] + list(config_states.all_config_states.keys())
+ return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
+
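The JSON written here is what `update_config_states_table()` further down reads back. A sketch of its shape as a Python literal, limited to the fields that UI actually consumes (all values illustrative):

{
    "name": "Backup (pre-update)",
    "created_at": 1683900000,  # unix timestamp
    "filepath": "config_states/2023_05_13-09_41_22_Backup (pre-update).json",
    "webui": {"remote": "...", "branch": "...", "commit_hash": "...", "commit_date": 1683899000},
    "extensions": {
        "example-extension": {"enabled": True, "remote": "...", "branch": "...", "commit_hash": "...", "commit_date": 1683898000},
    },
}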
+
+def restore_config_state(confirmed, config_state_name, restore_type):
+ if config_state_name == "Current":
+ return "<span>Select a config to restore from.</span>"
+ if not confirmed:
+ return "<span>Cancelled.</span>"
+
+ check_access()
+
+ config_state = config_states.all_config_states[config_state_name]
+
+ print(f"*** Restoring webui state from backup: {restore_type} ***")
+
+ if restore_type == "extensions" or restore_type == "both":
+ shared.opts.restore_config_state_file = config_state["filepath"]
+ shared.opts.save(shared.config_filename)
+
+ if restore_type == "webui" or restore_type == "both":
+ config_states.restore_webui_config(config_state)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+ return ""
+
+
def check_updates(id_task, disable_list):
check_access()
@@ -76,6 +123,16 @@ def check_updates(id_task, disable_list):
return extension_table(), ""
+def make_commit_link(commit_hash, remote, text=None):
+ if text is None:
+ text = commit_hash[:8]
+ if remote.startswith("https://github.com/"):
+ href = os.path.join(remote, "commit", commit_hash)
+ return f'<a href="{href}" target="_blank">{text}</a>'
+ else:
+ return text
+
+
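A usage sketch for the helper above (hash and remote are illustrative): GitHub remotes become links, anything else falls back to plain text.

make_commit_link("abcdef1234567890", "https://github.com/AUTOMATIC1111/stable-diffusion-webui")
# -> '<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/abcdef1234567890" target="_blank">abcdef12</a>'
make_commit_link("abcdef1234567890", "https://example.org/repo.git")
# -> 'abcdef12'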
def extension_table():
code = f"""<!-- {time.time()} -->
<table id="extensions">
@@ -102,13 +159,17 @@ def extension_table():
style = ""
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
- style = ' style="color: var(--primary-400)"'
+ style = STYLE_PRIMARY
+
+ version_link = ext.version
+ if ext.commit_hash and ext.remote:
+ version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
code += f"""
<tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
- <td>{ext.version}</td>
+ <td>{version_link}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -121,6 +182,133 @@ def extension_table():
return code
+def update_config_states_table(state_name):
+ if state_name == "Current":
+ config_state = config_states.get_config()
+ else:
+ config_state = config_states.all_config_states[state_name]
+
+ config_name = config_state.get("name", "Config")
+ created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ filepath = config_state.get("filepath", "<unknown>")
+
+ code = f"""<!-- {time.time()} -->"""
+
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
+ else:
+ webui_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
+
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f"""<h2>Config Backup: {config_name}</h2>
+ <div><b>Filepath:</b> {filepath}</div>
+ <div><b>Created at:</b> {created_date}</div>"""
+
+ code += f"""<h2>WebUI State</h2>
+ <table id="config_state_webui">
+ <thead>
+ <tr>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{webui_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+ </tbody>
+ </table>
+ """
+
+ code += """<h2>Extension State</h2>
+ <table id="config_state_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ ext_map = {ext.name: ext for ext in extensions.extensions}
+
+ for ext_name, ext_conf in config_state["extensions"].items():
+ ext_remote = ext_conf["remote"] or ""
+ ext_branch = ext_conf["branch"] or "<unknown>"
+ ext_enabled = ext_conf["enabled"]
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
+ ext_commit_date = ext_conf["commit_date"]
+ if ext_commit_date:
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ else:
+ ext_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+
+ style_enabled = ""
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if ext_name in ext_map:
+ current_ext = ext_map[ext_name]
+ current_ext.read_info_from_repo()
+ if current_ext.enabled != ext_enabled:
+ style_enabled = STYLE_PRIMARY
+ if current_ext.remote != ext_remote:
+ style_remote = STYLE_PRIMARY
+ if current_ext.branch != ext_branch:
+ style_branch = STYLE_PRIMARY
+ if current_ext.commit_hash != ext_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f"""
+ <tr>
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{ext_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+ """
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ return code
+
+
def normalize_git_url(url):
if url is None:
return ""
@@ -129,7 +317,7 @@ def normalize_git_url(url):
return url
-def install_extension_from_url(dirname, url):
+def install_extension_from_url(dirname, url, branch_name=None):
check_access()
assert url, 'No URL specified'
@@ -150,10 +338,17 @@ def install_extension_from_url(dirname, url):
try:
shutil.rmtree(tmpdir, True)
- with git.Repo.clone_from(url, tmpdir) as repo:
- repo.remote().fetch()
- for submodule in repo.submodules:
- submodule.update()
+ if not branch_name:
+ # if no branch is specified, use the default branch
+ with git.Repo.clone_from(url, tmpdir) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
+ else:
+ with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
try:
os.rename(tmpdir, target_dir)
except OSError as err:
@@ -292,9 +487,11 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def create_ui():
import modules.ui
+ config_states.list_config_states()
+
with gr.Blocks(analytics_enabled=False) as ui:
- with gr.Tabs(elem_id="tabs_extensions") as tabs:
- with gr.TabItem("Installed"):
+ with gr.Tabs(elem_id="tabs_extensions"):
+ with gr.TabItem("Installed", id="installed"):
with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary")
@@ -327,7 +524,7 @@ def create_ui():
outputs=[extensions_table, info],
)
- with gr.TabItem("Available"):
+ with gr.TabItem("Available", id="available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
@@ -374,16 +571,41 @@ def create_ui():
outputs=[available_extensions_table, install_result]
)
- with gr.TabItem("Install from URL"):
+ with gr.TabItem("Install from URL", id="install_from_url"):
install_url = gr.Text(label="URL for extension's git repository")
+ install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
- inputs=[install_dirname, install_url],
+ inputs=[install_dirname, install_url, install_branch],
outputs=[extensions_table, install_result],
)
+ with gr.TabItem("Backup/Restore"):
+ with gr.Row(elem_id="extensions_backup_top_row"):
+ config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
+ modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
+ config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
+ config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
+ with gr.Row(elem_id="extensions_backup_top_row2"):
+ config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
+ config_save_button = gr.Button(value="Save Current Config")
+
+ config_states_info = gr.HTML("")
+ config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+
+ config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
+
+ dummy_component = gr.Label(visible=False)
+ config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
+
+ config_states_list.change(
+ fn=update_config_states_table,
+ inputs=[config_states_list],
+ outputs=[config_states_table],
+ )
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 25eb464b..2fd82e8e 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,4 +1,3 @@
-import glob
import os.path
import urllib.parse
from pathlib import Path
@@ -27,7 +26,7 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
@@ -69,7 +68,9 @@ class ExtraNetworksPage:
pass
def link_preview(self, filename):
- return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+ quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
+ mtime = os.path.getmtime(filename)
+ return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)
@@ -89,19 +90,22 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
- if not os.path.isdir(x):
- continue
+ for root, dirs, _ in os.walk(parentdir):
+ for dirname in dirs:
+ x = os.path.join(root, dirname)
- subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
- while subdir.startswith("/"):
- subdir = subdir[1:]
+ if not os.path.isdir(x):
+ continue
- is_empty = len(os.listdir(x)) == 0
- if not is_empty and not subdir.endswith("/"):
- subdir = subdir + "/"
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
- subdirs[subdir] = 1
+ is_empty = len(os.listdir(x)) == 0
+ if not is_empty and not subdir.endswith("/"):
+ subdir = subdir + "/"
+
+ subdirs[subdir] = 1
if subdirs:
subdirs = {"": 1, **subdirs}
@@ -157,8 +161,20 @@ class ExtraNetworksPage:
if metadata:
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
+ local_path = ""
+ filename = item.get("filename", "")
+ for reldir in self.allowed_directories_for_previews():
+ absdir = os.path.abspath(reldir)
+
+ if filename.startswith(absdir):
+ local_path = filename[len(absdir):]
+
+ # if this is true, the item must not be shown in the default view, and must instead only be
+ # shown when searching for it
+ search_only = "/." in local_path or "\\." in local_path
+
args = {
- "style": f"'{height}{width}{background_image}'",
+ "style": f"'display: none; {height}{width}{background_image}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
@@ -168,6 +184,7 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
+ "serach_only": " search_only" if serach_only else "",
}
return self.card_page.format(**args)
@@ -209,6 +226,11 @@ def intialize():
class ExtraNetworksUi:
def __init__(self):
self.pages = None
+ """gradio HTML components related to extra networks' pages"""
+
+ self.pages_contents = None
+ """HTML content of the above; empty initially, filled when extra pages have to be shown"""
+
self.stored_extra_pages = None
self.button_save_preview = None
@@ -236,17 +258,22 @@ def pages_in_preferred_order(pages):
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
+ ui.pages_contents = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
- with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
- with gr.Tab(page.title):
+ page_id = page.title.lower().replace(" ", "_")
- page_elem = gr.HTML(page.create_html(ui.tabname))
+ with gr.Tab(page.title, id=page_id):
+ elem_id = f"{tabname}_{page_id}_cards_html"
+ page_elem = gr.HTML('', elem_id=elem_id)
ui.pages.append(page_elem)
- filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+
+ gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -254,19 +281,22 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ if is_visible and not ui.pages_contents:
+ refresh()
+
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
def refresh():
- res = []
-
for pg in ui.stored_extra_pages:
pg.refresh()
- res.append(pg.create_html(ui.tabname))
- return res
+ ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
+
+ return ui.pages_contents
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
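Together, `toggle_visibility` and `refresh` make the extra-network HTML lazy: nothing is rendered until the panel is first opened, after which the cached `ui.pages_contents` is reused. The caching pattern in miniature (names illustrative):

cache = []

def get_contents(build):
    if not cache:              # the expensive build happens once, on first use
        cache.extend(build())
    return cache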
@@ -296,7 +326,7 @@ def setup_ui(ui, gallery):
is_allowed = False
for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
is_allowed = True
break
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
new file mode 100644
index 00000000..728fec9e
--- /dev/null
+++ b/modules/ui_loadsave.py
@@ -0,0 +1,208 @@
+import json
+import os
+
+import gradio as gr
+
+from modules import errors
+from modules.ui_components import ToolButton
+
+
+class UiLoadsave:
+ """allows saving and restorig default values for gradio components"""
+
+ def __init__(self, filename):
+ self.filename = filename
+ self.ui_settings = {}
+ self.component_mapping = {}
+ self.error_loading = False
+ self.finalized_ui = False
+
+ self.ui_defaults_view = None
+ self.ui_defaults_apply = None
+ self.ui_defaults_review = None
+
+ try:
+ if os.path.exists(self.filename):
+ self.ui_settings = self.read_from_file()
+ except Exception as e:
+ self.error_loading = True
+ errors.display(e, "loading settings")
+
+ def add_component(self, path, x):
+ """adds component to the registry of tracked components"""
+
+ assert not self.finalized_ui
+
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = f"{path}/{field}"
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = f"customscript/{obj.custom_script_source}/{key}"
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = self.ui_settings.get(key, None)
+ if saved_value is None:
+ self.ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if field == 'value' and key not in self.component_mapping:
+ self.component_mapping[key] = x
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all(value in x.choices for value in val)
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ def check_tab_id(tab_id):
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+ if type(tab_id) == str:
+ tab_ids = [t.id for t in tab_items]
+ return tab_id in tab_ids
+ elif type(tab_id) == int:
+ return 0 <= tab_id < len(tab_items)
+ else:
+ return False
+
+ if type(x) == gr.Tabs:
+ apply_field(x, 'selected', check_tab_id)
+
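The keys `apply_field` produces are slash-joined paths, so a typical ui-config.json ends up looking like this (entries illustrative):

{
    "txt2img/Width/value": 512,
    "img2img/Denoising strength/value": 0.75,
    "customscript/some_script.py/txt2img/Some label/value": ""
}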
+ def add_block(self, x, path=""):
+ """adds all components inside a gradio block x to the registry of tracked components"""
+
+ if hasattr(x, 'children'):
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
+ # Tabs element can't have a label, have to use elem_id instead
+ self.add_component(f"{path}/Tabs@{x.elem_id}", x)
+ for c in x.children:
+ self.add_block(c, path)
+ elif x.label is not None:
+ self.add_component(f"{path}/{x.label}", x)
+
+ def read_from_file(self):
+ with open(self.filename, "r", encoding="utf8") as file:
+ return json.load(file)
+
+ def write_to_file(self, current_ui_settings):
+ with open(self.filename, "w", encoding="utf8") as file:
+ json.dump(current_ui_settings, file, indent=4)
+
+ def dump_defaults(self):
+ """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+
+ if self.error_loading and os.path.exists(self.filename):
+ return
+
+ self.write_to_file(self.ui_settings)
+
+ def iter_changes(self, current_ui_settings, values):
+ """
+ given a dictionary with defaults from a file and current values from gradio elements, returns
+ an iterator over tuples of values that are not the same between the file and the current;
+ tuple contents are: path, old value, new value
+ """
+
+ for (path, component), new_value in zip(self.component_mapping.items(), values):
+ old_value = current_ui_settings.get(path)
+
+ choices = getattr(component, 'choices', None)
+ if isinstance(new_value, int) and choices:
+ if new_value >= len(choices):
+ continue
+
+ new_value = choices[new_value]
+
+ if new_value == old_value:
+ continue
+
+ if old_value is None and (new_value == '' or new_value == []):
+ continue
+
+ yield path, old_value, new_value
+
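The integer branch above covers components that report their value as an index into `choices`; in miniature:

choices = ["Euler a", "DDIM"]
new_value = 1  # an index handed back by gradio
if isinstance(new_value, int) and choices and new_value < len(choices):
    new_value = choices[new_value]
print(new_value)  # DDIM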
+ def ui_view(self, *values):
+ text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
+
+ for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
+ if old_value is None:
+ old_value = "<span class='ui-defaults-none'>None</span>"
+
+ text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
+
+ if len(text) == 1:
+ text.append("<tr><td colspan=3>No changes</td></tr>")
+
+ text.append("</tbody>")
+ return "".join(text)
+
+ def ui_apply(self, *values):
+ num_changed = 0
+
+ current_ui_settings = self.read_from_file()
+
+ for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
+ num_changed += 1
+ current_ui_settings[path] = new_value
+
+ if num_changed == 0:
+ return "No changes."
+
+ self.write_to_file(current_ui_settings)
+
+ return f"Wrote {num_changed} changes."
+
+ def create_ui(self):
+ """creates ui elements for editing defaults UI, without adding any logic to them"""
+
+ gr.HTML(
+ f"This page allows you to change default values in UI elements on other tabs.<br />"
+ f"Make your changes, press 'View changes' to review the changed default values,<br />"
+ f"then press 'Apply' to write them to {self.filename}.<br />"
+ f"New defaults will apply after you restart the UI.<br />"
+ )
+
+ with gr.Row():
+ self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
+ self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
+
+ self.ui_defaults_review = gr.HTML("")
+
+ def setup_ui(self):
+ """adds logic to elements created with create_ui; all add_block class must be made before this"""
+
+ assert not self.finalized_ui
+ self.finalized_ui = True
+
+ self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+ self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
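Taken together, the intended call order for the class (mirroring how create_ui() wires it up earlier in this diff) is roughly:

loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)  # reads saved defaults, if any
# ... build the gradio UI ...
loadsave.add_block(interface, ifid)  # register every component under a path
loadsave.create_ui()                 # widgets for the settings "Defaults" tab
loadsave.setup_ui()                  # wire View/Apply and freeze the registry
loadsave.dump_defaults()             # persist defaults for components seen this run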
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index b418d955..c7dc1154 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,5 +1,5 @@
import gradio as gr
-from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste
@@ -9,13 +9,13 @@ def create_ui():
with gr.Row().style(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+ with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 21945235..f05049e1 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
- return any([filename in fileset for fileset in gradio.temp_file_sets])
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
@@ -36,7 +36,7 @@ def save_pil_to_file(pil_image, dir=None):
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
- file_obj = Savedfile(already_saved_as)
+ file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
return file_obj
if shared.opts.temp_dir != "":
@@ -72,7 +72,7 @@ def cleanup_tmpdr():
if temp_dir == "" or not os.path.isdir(temp_dir):
return
- for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for root, _, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
diff --git a/modules/upscaler.py b/modules/upscaler.py
index e2eaa730..8acb6e96 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -2,8 +2,6 @@ import os
from abc import abstractmethod

import PIL
-import numpy as np
-import torch
from PIL import Image

import modules.shared
@@ -43,9 +41,9 @@ class Upscaler:
os.makedirs(self.model_path, exist_ok=True)

try:
- import cv2
+ import cv2  # noqa: F401
self.can_tile = True
- except:
+ except Exception:
pass

@abstractmethod
@@ -57,7 +55,7 @@ class Upscaler:
dest_w = int(img.width * scale)
dest_h = int(img.height * scale)

- for i in range(3):
+ for _ in range(3):
shape = (img.width, img.height)

img = self.do_upscale(img, selected_model)
diff --git a/modules/xlmr.py b/modules/xlmr.py
index beab3fdf..e056c3f6 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..c88907be
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,31 @@
+[tool.ruff]
+
+target-version = "py39"
+
+extend-select = [
+  "B",
+  "C",
+  "I",
+]
+
+exclude = [
+  "extensions",
+  "extensions-disabled",
+]
+
+ignore = [
+  "E501",  # Line too long
+  "E731",  # Do not assign a `lambda` expression, use a `def`
+
+  "I001",  # Import block is un-sorted or un-formatted
+  "C901",  # Function is too complex
+  "C408",  # Rewrite as a literal
+
+]
+
+[tool.ruff.per-file-ignores]
+"webui.py" = ["E402"]  # Module level import not at top of file
+
+[tool.ruff.flake8-bugbear]
+# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
+extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c72b2927..35987a13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,11 @@
+astunparse
blendmodes
accelerate
basicsr
fonts
font-roboto
gfpgan
-gradio==3.23
-invisible-watermark
+gradio==3.29.0
numpy
omegaconf
opencv-contrib-python
diff --git a/requirements_versions.txt b/requirements_versions.txt
index df65431a..7bce02e5 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,10 +1,10 @@
blendmodes==2022
transformers==4.25.1
-accelerate==0.12.0
+accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.23
-numpy==1.23.3
+gradio==3.29.0
+numpy==1.23.5
Pillow==9.4.0
realesrgan==0.3.0
torch
@@ -25,6 +25,6 @@ lark==1.1.2
inflection==0.5.1
GitPython==3.1.30
torchsde==0.2.5
-safetensors==0.3.0
+safetensors==0.3.1
httpcore<=0.15
fastapi==0.94.0
@@ -7,7 +7,7 @@ function gradioApp() {
}
function get_uiCurrentTab() {
- return gradioApp().querySelector('#tabs button:not(.border-transparent)')
+ return gradioApp().querySelector('#tabs button.selected')
}
function get_uiCurrentTabContent() {
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
index d29113e6..cc6f0d49 100644
--- a/scripts/custom_code.py
+++ b/scripts/custom_code.py
@@ -1,8 +1,39 @@
import modules.scripts as scripts
import gradio as gr
+import ast
+import copy
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import cmd_opts
+
+
+def convertExpr2Expression(expr):
+ expr.lineno = 0
+ expr.col_offset = 0
+ result = ast.Expression(expr.value, lineno=0, col_offset = 0)
+
+ return result
+
+
+def exec_with_return(code, module):
+ """
+ like exec() but can return values
+ https://stackoverflow.com/a/52361938/5862977
+ """
+ code_ast = ast.parse(code)
+
+ init_ast = copy.deepcopy(code_ast)
+ init_ast.body = code_ast.body[:-1]
+
+ last_ast = copy.deepcopy(code_ast)
+ last_ast.body = code_ast.body[-1:]
+
+ exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
+ if type(last_ast.body[0]) == ast.Expr:
+ return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
+ else:
+ exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
+
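A usage sketch, assuming `exec_with_return` as defined above: the value of the final statement is returned when that statement is an expression, which is exactly what the `__webuitemp__()` wrapper below relies on.

from types import ModuleType

module = ModuleType("demo")
print(exec_with_return("x = 2\nx * 21", module))  # 42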
class Script(scripts.Script):
@@ -13,12 +44,23 @@ class Script(scripts.Script):
return cmd_opts.allow_code
def ui(self, is_img2img):
- code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
+ example = """from modules.processing import process_images
+
+p.width = 768
+p.height = 768
+p.batch_size = 2
+p.steps = 10
+
+return process_images(p)
+"""
+
- return [code]
+ code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
+ indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
+ return [code, indent_level]
- def run(self, p, code):
+ def run(self, p, code, indent_level):
assert cmd_opts.allow_code, '--allow-code option must be enabled'
display_result_data = [[], -1, ""]
@@ -29,13 +71,20 @@ class Script(scripts.Script):
display_result_data[2] = i
from types import ModuleType
- compiled = compile(code, '', 'exec')
module = ModuleType("testmodule")
module.__dict__.update(globals())
module.p = p
module.display = display
- exec(compiled, module.__dict__)
+
+ indent = " " * indent_level
+ indented = code.replace('\n', f"\n{indent}")
+ body = f"""def __webuitemp__():
+{indent}{indented}
+__webuitemp__()"""
+
+ result = exec_with_return(body, module)
+
+ if isinstance(result, Processed):
+ return result
return Processed(p, *display_result_data)
-
-
\ No newline at end of file
diff --git a/scripts/loopback.py b/scripts/loopback.py
index d3065fe6..ad6609be 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -84,7 +84,7 @@ class Script(scripts.Script):
p.color_corrections = initial_color_corrections
if append_interrogation != "None":
- p.prompt = original_prompt + ", " if original_prompt != "" else ""
+ p.prompt = f"{original_prompt}, " if original_prompt else ""
if append_interrogation == "CLIP":
p.prompt += shared.interrogator.interrogate(p.init_images[0])
elif append_interrogation == "DeepBooru":
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 0906da6a..665dbe89 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -7,9 +7,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
@@ -72,7 +72,7 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
- np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ _np_src_image[:] * (1. - np_mask_rgb)
np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
@@ -275,7 +275,7 @@ class Script(scripts.Script):
if opts.samples_save:
for img in all_processed_images:
- images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)
if opts.grid_save and not unwanted_grid_because_of_img_count:
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index d8feda00..c0bbecc1 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images, devices
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -138,7 +138,7 @@ class Script(scripts.Script):
combined_image = images.combine_grid(grid)
if opts.samples_save:
- images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
processed = Processed(p, [combined_image], initial_seed, initial_info)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 11eab31a..edb70ac0 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -4,8 +4,8 @@ import numpy as np from modules import scripts_postprocessing, shared
import gradio as gr
-from modules.ui_components import FormRow
-
+from modules.ui_components import FormRow, ToolButton
+from modules.ui import switch_values_symbol
upscale_cache = {}
@@ -25,9 +25,12 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+ with gr.Column(elem_id="upscaling_column_size", scale=4):
+ upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
+ with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
@@ -36,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+ upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
@@ -94,13 +98,13 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
if upscaler2 and upscaler_2_visibility > 0:
second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
- pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+ pp.info["Postprocess upscaler 2"] = upscaler2.name
pp.image = upscaled_image
@@ -130,4 +134,4 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): assert upscaler1, f'could not find upscaler named {upscaler_name}'
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index e9b11517..fb06beab 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -1,14 +1,11 @@ import math
-from collections import namedtuple
-from copy import copy
-import random
import modules.scripts as scripts
import gradio as gr
from modules import images
-from modules.processing import process_images, Processed
-from modules.shared import opts, cmd_opts, state
+from modules.processing import process_images
+from modules.shared import opts, state
import modules.sd_samplers
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 76dc5778..27af5ff6 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,6 +1,4 @@ import copy
-import math
-import os
import random
import sys
import traceback
@@ -11,8 +9,7 @@ import gradio as gr from modules import sd_samplers
from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
+from modules.shared import state
def process_string_tag(tag):
@@ -159,7 +156,7 @@ class Script(scripts.Script): images = []
all_prompts = []
infotexts = []
- for n, args in enumerate(jobs):
+ for args in jobs:
state.job = f"{state.job_no + 1} out of {state.job_count}"
copy_p = copy.copy(p)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 332d76d9..0b1d3096 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -4,9 +4,9 @@ import modules.scripts as scripts import gradio as gr
from PIL import Image
-from modules import processing, shared, sd_samplers, images, devices
+from modules import processing, shared, images, devices
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -56,7 +56,7 @@ class Script(scripts.Script): work = []
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
work.append(tiledata[2])
@@ -85,7 +85,7 @@ class Script(scripts.Script): work_results += processed.images
image_index = 0
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 3895a795..38a20381 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -10,15 +10,13 @@ import numpy as np import modules.scripts as scripts
import gradio as gr
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
-import glob
-import os
import re
from modules.ui_components import ToolButton
@@ -86,7 +84,7 @@ def apply_checkpoint(p, x, xs): info = modules.sd_models.get_closet_checkpoint_match(x)
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
- modules.sd_models.reload_model_weights(shared.sd_model, info)
+ p.override_settings['sd_model_checkpoint'] = info.hash
def confirm_checkpoints(p, xs):
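For orientation, the axis "apply" callbacks used throughout this table share one shape: mutate the per-cell processing object `p` before the cell renders. A hedged sketch of `apply_field` (the helper is real; this body is an educated guess) shows the pattern, and explains why `apply_checkpoint` above now queues the checkpoint hash into `p.override_settings` rather than reloading weights itself: the reload is deferred to the processing pipeline, which already knows how to honor override_settings.

def apply_field(field):
    def fun(p, x, xs):
        # x is this cell's axis value; xs is the full list of values
        setattr(p, field, x)
    return fun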
@@ -211,7 +209,8 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+ AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
@@ -221,7 +220,7 @@ axis_options = [ AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
- AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
+ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@@ -315,7 +314,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend return Processed(p, [])
z_count = len(zs)
- sub_grids = [None] * z_count
+
for i in range(z_count):
start_index = (i * len(xs) * len(ys)) + i
end_index = start_index + len(xs) * len(ys)
@@ -345,7 +344,7 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.vae = opts.sd_vae
self.uni_pc_order = opts.uni_pc_order
-
+
def __exit__(self, exc_type, exc_value, tb):
opts.data["sd_vae"] = self.vae
opts.data["uni_pc_order"] = self.uni_pc_order
@@ -374,16 +373,19 @@ class Script(scripts.Script): with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+ x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
+ y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
+ z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
@@ -395,60 +397,80 @@ class Script(scripts.Script): include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
-
+
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
- def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
- return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
+ def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
+ return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
- xy_swap_args = [x_type, x_values, y_type, y_values]
+ xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
- yz_swap_args = [y_type, y_values, z_type, z_values]
+ yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
- xz_swap_args = [x_type, x_values, z_type, z_values]
+ xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type):
axis = self.current_axis_options[x_type]
- return ", ".join(axis.choices()) if axis.choices else gr.update()
-
- fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
- fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
- fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
-
- def select_axis(x_type):
- return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
-
- x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
- y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
- z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
+ return axis.choices() if axis.choices else gr.update()
+
+ fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
+ fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
+ fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
+
+ def select_axis(axis_type, axis_values_dropdown):
+ choices = self.current_axis_options[axis_type].choices
+ has_choices = choices is not None
+ current_values = axis_values_dropdown
+ if has_choices:
+ choices = choices()
+ if isinstance(current_values, str):
+ current_values = current_values.split(",")
+ current_values = list(filter(lambda x: x in choices, current_values))
+ return gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices), gr.update(choices=choices if has_choices else None, visible=has_choices, value=current_values)
+
+ x_type.change(fn=select_axis, inputs=[x_type, x_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown])
+ y_type.change(fn=select_axis, inputs=[y_type, y_values_dropdown], outputs=[fill_y_button, y_values, y_values_dropdown])
+ z_type.change(fn=select_axis, inputs=[z_type, z_values_dropdown], outputs=[fill_z_button, z_values, z_values_dropdown])
+
+ def get_dropdown_update_from_params(axis, params):
+ val_key = f"{axis} Values"
+ vals = params.get(val_key, "")
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
+ return gr.update(value=valslist)
self.infotext_fields = (
(x_type, "X Type"),
(x_values, "X Values"),
+ (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
(y_type, "Y Type"),
(y_values, "Y Values"),
+ (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
(z_type, "Z Type"),
(z_values, "Z Values"),
+ (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
)
- return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+ return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
- def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+ def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
if not opts.return_grid:
p.batch_size = 1
- def process_axis(opt, vals):
+ def process_axis(opt, vals, vals_dropdown):
if opt.label == 'Nothing':
return [0]
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
+ if opt.choices is not None:
+ valslist = vals_dropdown
+ else:
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.type == int:
valslist_ext = []
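A note on the `csv.reader(StringIO(vals))` idiom above: unlike a plain `split(",")`, it honors quoting, so a single axis value may itself contain commas. A quick illustration (values are hypothetical):

import csv
from io import StringIO
from itertools import chain

vals = '10,20,"a, quoted, value",30'
print([x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x])
# ['10', '20', 'a, quoted, value', '30']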
@@ -466,7 +488,7 @@ class Script(scripts.Script): start = int(mc.group(1))
end = int(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
-
+
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else:
valslist_ext.append(val)
@@ -488,7 +510,7 @@ class Script(scripts.Script): start = float(mc.group(1))
end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
-
+
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
else:
valslist_ext.append(val)
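The `mc` match object above comes from a range pattern defined earlier in xyz_grid.py, outside this hunk; the regex below is an approximation for illustration only. The syntax `start-end [num]` expands to `num` evenly spaced values:

import re
import numpy as np

# Assumed shape of the range-with-count pattern; the real one lives in xyz_grid.py.
re_range_count = re.compile(r"\s*([+-]?\d+(?:\.\d*)?)\s*-\s*([+-]?\d+(?:\.\d*)?)\s*(?:\[(\d+)\])?\s*")

mc = re_range_count.fullmatch("1-5 [3]")
start = float(mc.group(1))
end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
print(np.linspace(start=start, stop=end, num=num).tolist())  # [1.0, 3.0, 5.0]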
@@ -506,13 +528,19 @@ class Script(scripts.Script): return valslist
x_opt = self.current_axis_options[x_type]
- xs = process_axis(x_opt, x_values)
+ if x_opt.choices is not None:
+ x_values = ",".join(x_values_dropdown)
+ xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type]
- ys = process_axis(y_opt, y_values)
+ if y_opt.choices is not None:
+ y_values = ",".join(y_values_dropdown)
+ ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type]
- zs = process_axis(z_opt, z_values)
+ if z_opt.choices is not None:
+ z_values = ",".join(z_values_dropdown)
+ zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
@@ -676,7 +704,7 @@ class Script(scripts.Script): if not include_sub_grids:
# Done with sub-grids, drop all related information:
- for sg in range(z_count):
+ for _ in range(z_count):
del processed.images[1]
del processed.all_prompts[1]
del processed.all_seeds[1]
@@ -125,6 +125,10 @@ div.gradio-html.min{ text-decoration: none;
}
+a{
+ font-weight: bold;
+ cursor: pointer;
+}
/* general styled components */
@@ -246,7 +250,7 @@ button.custom-button{ }
}
-#txt2img_gallery img, #img2img_gallery img{
+#txt2img_gallery img, #img2img_gallery img, #extras_gallery img{
object-fit: scale-down;
}
#txt2img_actions_column, #img2img_actions_column {
@@ -293,7 +297,12 @@ button.custom-button{ margin-left: -0.75em
}
-#txtimg_hr_finalres .resolution{
+#img2img_scale_resolution_preview.block{
+ display: flex;
+ align-items: end;
+}
+
+#txtimg_hr_finalres .resolution, #img2img_scale_resolution_preview .resolution{
font-weight: bold;
}
@@ -312,6 +321,10 @@ div.dimensions-tools{ align-content: center;
}
+div#extras_scale_to_tab div.form{
+ flex-direction: row;
+}
+
#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
height: 480px !important;
max-height: 480px !important;
@@ -333,6 +346,18 @@ div.dimensions-tools{ overflow-wrap: break-word;
}
+#img2img_column_batch{
+ align-self: end;
+ margin-bottom: 0.9em;
+}
+
+#img2img_unused_scale_by_slider{
+ visibility: hidden;
+ width: 0.5em;
+ max-width: 0.5em;
+ min-width: 0.5em;
+}
+
/* settings */
#quicksettings {
width: fit-content;
@@ -376,6 +401,22 @@ div.dimensions-tools{ margin: 0 1.2em;
}
+table.settings-value-table{
+ background: white;
+ border-collapse: collapse;
+ margin: 1em;
+ border: 4px solid white;
+}
+
+table.settings-value-table td{
+ padding: 0.4em;
+ border: 1px solid #ccc;
+ max-width: 36em;
+}
+
+.ui-defaults-none{
+ color: #aaa !important;
+}
/* live preview */
.progressDiv{
@@ -513,6 +554,8 @@ div.dimensions-tools{ #lightboxModal > img.modalImageFullscreen{
object-fit: contain;
height: 100%;
+ width: 100%;
+ min-height: 0;
}
.modalPrev,
@@ -642,6 +685,12 @@ footer { /* extra networks UI */
+.extra-network-cards{
+ height: 725px;
+ overflow: scroll;
+ resize: vertical;
+}
+
.extra-networks > div > [id *= '_extra_']{
margin: 0.3em;
}
diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 37cac4fb..10ab81c9 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
+export TORCH_COMMAND="pip install torch torchvision"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
diff --git a/webui-user.sh b/webui-user.sh index bfa53cb7..49a426ff 100644 --- a/webui-user.sh +++ b/webui-user.sh @@ -43,4 +43,7 @@ # Uncomment to enable accelerated launch
#export ACCELERATE="True"
+# Uncomment to disable TCMalloc
+#export NO_TCMALLOC="True"
+
###########################################
@@ -5,6 +5,9 @@ import importlib import signal
import re
import warnings
+import json
+from threading import Thread
+
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -13,31 +16,34 @@ from packaging import version import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer = timer.Timer()
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
+warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+
+
startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_queued_call, queue_lock
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
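To make the version truncation above concrete, the regex keeps only the leading dotted-numeric prefix (the sample version strings below are hypothetical):

import re

for v in ("2.1.0.dev20230506", "2.0.1+git5f89f30"):
    print(re.search(r'[\d.]+[\d]', v).group(0))  # 2.1.0, then 2.0.1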
@@ -67,11 +73,51 @@ else: server_name = "0.0.0.0" if cmd_opts.listen else None
+def fix_asyncio_event_loop_policy():
+ """
+ The default `asyncio` event loop policy only automatically creates
+ event loops in the main thread. Other threads must create event
+ loops explicitly or `asyncio.get_event_loop` (and therefore
+ `.IOLoop.current`) will fail. Installing this policy allows event
+ loops to be created automatically on any thread, matching the
+ behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
+ """
+
+ import asyncio
+
+ if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
+ # interface for composing policies so pick the right base.
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
+ else:
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
+
+ class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
+ """Event loop policy that allows loop creation on any thread.
+ Usage::
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+ """
+
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
+ try:
+ return super().get_event_loop()
+ except (RuntimeError, AssertionError):
+ # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
+ # and changed to a RuntimeError in 3.4.3.
+ # "There is no current event loop in thread %r"
+ loop = self.new_event_loop()
+ self.set_event_loop(loop)
+ return loop
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+
+
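A minimal reproduction of the failure the policy above works around (not part of the diff; this reflects asyncio behavior on recent Python 3, where `get_event_loop()` refuses to create a loop outside the main thread):

import asyncio
import threading

def worker():
    # Without AnyThreadEventLoopPolicy this raises
    # RuntimeError: There is no current event loop in thread ...
    loop = asyncio.get_event_loop()
    loop.run_until_complete(asyncio.sleep(0))

threading.Thread(target=worker).start()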
def check_versions():
if shared.cmd_opts.skip_version_check:
return
- expected_torch_version = "1.13.1"
+ expected_torch_version = "2.0.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
@@ -84,7 +130,7 @@ there are reports of issues with training tab on the latest version. Use --skip-version-check commandline argument to disable this check.
""".strip())
- expected_xformers_version = "0.0.16rc425"
+ expected_xformers_version = "0.0.17"
if shared.xformers_available:
import xformers
@@ -99,12 +145,27 @@ Use --skip-version-check commandline argument to disable this check. def initialize():
+ fix_asyncio_event_loop_policy()
+
check_versions()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
+ config_state_file = shared.opts.restore_config_state_file
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -120,9 +181,6 @@ def initialize(): gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
- modelloader.list_builtin_upscalers()
- startup_timer.record("list builtin upscalers")
-
modules.scripts.load_scripts()
startup_timer.record("load scripts")
@@ -135,21 +193,14 @@ def initialize(): modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
- try:
- modules.sd_models.load_model()
- except Exception as e:
- errors.display(e, "loading stable diffusion model")
- print("", file=sys.stderr)
- print("Stable diffusion model failed to load, exiting", file=sys.stderr)
- exit(1)
- startup_timer.record("load SD checkpoint")
-
- shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+ # load model in parallel to other startup stuff
+ Thread(target=lambda: shared.sd_model).start()
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
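The `Thread(target=lambda: shared.sd_model).start()` line above only makes sense because `shared.sd_model` is lazy: the first attribute access triggers the checkpoint load. A simplified sketch of that pattern (the class, lock, and loader below are illustrative, not the real module's code):

import threading
import time

def load_model():
    time.sleep(2)  # stand-in for the expensive checkpoint load
    return object()

class Shared:
    _sd_model = None
    _lock = threading.Lock()

    @property
    def sd_model(self):
        with Shared._lock:  # the first accessor performs the load exactly once
            if Shared._sd_model is None:
                Shared._sd_model = load_model()
        return Shared._sd_model

shared = Shared()
threading.Thread(target=lambda: shared.sd_model).start()  # warm in the background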
@@ -212,6 +263,8 @@ def wait_on_server(demo=None): time.sleep(0.5)
demo.close()
time.sleep(0.5)
+
+ modules.script_callbacks.app_reload_callback()
break
@@ -227,7 +280,6 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
def webui():
launch_api = cmd_opts.api
initialize()
@@ -254,12 +306,23 @@ def webui(): for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ # this restores the missing /docs endpoint
+ if launch_api and not hasattr(FastAPI, 'original_setup'):
+ def fastapi_setup(self):
+ self.docs_url = "/docs"
+ self.redoc_url = "/redoc"
+ self.original_setup()
+
+ FastAPI.original_setup = FastAPI.setup
+ FastAPI.setup = fastapi_setup
+
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
+ ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
inbrowser=cmd_opts.autolaunch,
@@ -279,6 +342,7 @@ def webui(): setup_middleware(app)
modules.progress.setup_progress_api(app)
+ modules.ui.setup_ui_api(app)
if launch_api:
create_api(app)
@@ -290,6 +354,11 @@ def webui(): print(f"Startup time: {startup_timer.summary()}.")
+ if cmd_opts.subpath:
+ redirector = FastAPI()
+ redirector.get("/")
+ gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+
wait_on_server(shared.demo)
print('Restarting UI...')
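One caveat on the `--subpath` block above: `redirector.get("/")` only constructs a route decorator and discards it, so no handler is actually registered for the root path. A working root redirect would look roughly like this (hedged sketch; `subpath` stands in for `cmd_opts.subpath`, and `RedirectResponse` is standard FastAPI):

from fastapi import FastAPI
from fastapi.responses import RedirectResponse

subpath = "sdwebui"  # stand-in for cmd_opts.subpath
redirector = FastAPI()

@redirector.get("/")
def redirect_to_app():
    # send bare-root requests into the mounted gradio app
    return RedirectResponse(url=f"/{subpath}/")

gradio.mount_gradio_app would then be called on redirector exactly as in the diff.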
@@ -301,9 +370,21 @@ def webui(): extensions.list_extensions()
startup_timer.record("list extensions")
+ config_state_file = shared.opts.restore_config_state_file
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
localization.list_localizations(cmd_opts.localizations_dir)
- modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
startup_timer.record("load scripts")
@@ -23,7 +23,7 @@ fi # Install directory without trailing slash
if [[ -z "${install_dir}" ]]
then
- install_dir="/home/$(whoami)"
+ install_dir="$(pwd)"
fi
# Name of the subdirectory (defaults to stable-diffusion-webui)
@@ -113,12 +113,13 @@ case "$gpu_info" in printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
printf "\n%s\n" "${delimiter}"
;;
- *)
+ *)
;;
esac
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+ # AMD users will still use torch 1.13 because 2.0 does not seem to work.
+ export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
fi
for preq in "${GIT}" "${python_cmd}"
@@ -152,35 +153,57 @@ else cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
fi
-printf "\n%s\n" "${delimiter}"
-printf "Create and activate python venv"
-printf "\n%s\n" "${delimiter}"
-cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
-if [[ ! -d "${venv_dir}" ]]
-then
- "${python_cmd}" -m venv "${venv_dir}"
- first_launch=1
-fi
-# shellcheck source=/dev/null
-if [[ -f "${venv_dir}"/bin/activate ]]
+if [[ -z "${VIRTUAL_ENV}" ]];
then
- source "${venv_dir}"/bin/activate
+ printf "\n%s\n" "${delimiter}"
+ printf "Create and activate python venv"
+ printf "\n%s\n" "${delimiter}"
+ cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+ if [[ ! -d "${venv_dir}" ]]
+ then
+ "${python_cmd}" -m venv "${venv_dir}"
+ first_launch=1
+ fi
+ # shellcheck source=/dev/null
+ if [[ -f "${venv_dir}"/bin/activate ]]
+ then
+ source "${venv_dir}"/bin/activate
+ else
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+ fi
else
printf "\n%s\n" "${delimiter}"
- printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "python venv already activate: ${VIRTUAL_ENV}"
printf "\n%s\n" "${delimiter}"
- exit 1
fi
+# Try using TCMalloc on Linux
+prepare_tcmalloc() {
+ if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
+ TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ if [[ ! -z "${TCMALLOC}" ]]; then
+ echo "Using TCMalloc: ${TCMALLOC}"
+ export LD_PRELOAD="${TCMALLOC}"
+ else
+ printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
+ fi
+ fi
+}
+
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
then
printf "\n%s\n" "${delimiter}"
printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}"
+ prepare_tcmalloc
exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
- printf "\n%s\n" "${delimiter}"
+ printf "\n%s\n" "${delimiter}"
+ prepare_tcmalloc
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi