aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xmodules/processing.py340
-rw-r--r--modules/processing_scripts/refiner.py16
-rw-r--r--modules/sd_samplers_common.py7
-rw-r--r--scripts/xyz_grid.py124
4 files changed, 282 insertions, 205 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 6ad105d7..007a4e05 100755
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
import json
import logging
import math
import os
import sys
import hashlib
+from dataclasses import dataclass, field
import torch
import numpy as np
@@ -11,7 +13,7 @@ from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List
+from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
@@ -104,97 +106,126 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+@dataclass(repr=False)
class StableDiffusionProcessing:
- """
- The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
- """
+ sd_model: object = None
+ outpath_samples: str = None
+ outpath_grids: str = None
+ prompt: str = ""
+ prompt_for_display: str = None
+ negative_prompt: str = ""
+ styles: list[str] = field(default_factory=list)
+ seed: int = -1
+ subseed: int = -1
+ subseed_strength: float = 0
+ seed_resize_from_h: int = -1
+ seed_resize_from_w: int = -1
+ seed_enable_extras: bool = True
+ sampler_name: str = None
+ batch_size: int = 1
+ n_iter: int = 1
+ steps: int = 50
+ cfg_scale: float = 7.0
+ width: int = 512
+ height: int = 512
+ restore_faces: bool = None
+ tiling: bool = None
+ do_not_save_samples: bool = False
+ do_not_save_grid: bool = False
+ extra_generation_params: dict[str, Any] = None
+ overlay_images: list = None
+ eta: float = None
+ do_not_reload_embeddings: bool = False
+ denoising_strength: float = 0
+ ddim_discretize: str = None
+ s_min_uncond: float = None
+ s_churn: float = None
+ s_tmax: float = None
+ s_tmin: float = None
+ s_noise: float = None
+ override_settings: dict[str, Any] = None
+ override_settings_restore_afterwards: bool = True
+ sampler_index: int = None
+ refiner_checkpoint: str = None
+ refiner_switch_at: float = None
+ token_merging_ratio = 0
+ token_merging_ratio_hr = 0
+ disable_extra_networks: bool = False
+
+ script_args: list = None
+
cached_uc = [None, None]
cached_c = [None, None]
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
- if sampler_index is not None:
+ sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
+ is_using_inpainting_conditioning: bool = field(default=False, init=False)
+ paste_to: tuple | None = field(default=None, init=False)
+
+ is_hr_pass: bool = field(default=False, init=False)
+
+ c: tuple = field(default=None, init=False)
+ uc: tuple = field(default=None, init=False)
+
+ rng: rng.ImageRNG | None = field(default=None, init=False)
+ step_multiplier: int = field(default=1, init=False)
+ color_corrections: list = field(default=None, init=False)
+
+ scripts: list = field(default=None, init=False)
+ all_prompts: list = field(default=None, init=False)
+ all_negative_prompts: list = field(default=None, init=False)
+ all_seeds: list = field(default=None, init=False)
+ all_subseeds: list = field(default=None, init=False)
+ iteration: int = field(default=0, init=False)
+ main_prompt: str = field(default=None, init=False)
+ main_negative_prompt: str = field(default=None, init=False)
+
+ prompts: list = field(default=None, init=False)
+ negative_prompts: list = field(default=None, init=False)
+ seeds: list = field(default=None, init=False)
+ subseeds: list = field(default=None, init=False)
+ extra_network_data: dict = field(default=None, init=False)
+
+ user: str = field(default=None, init=False)
+
+ sd_model_name: str = field(default=None, init=False)
+ sd_model_hash: str = field(default=None, init=False)
+ sd_vae_name: str = field(default=None, init=False)
+ sd_vae_hash: str = field(default=None, init=False)
+
+ def __post_init__(self):
+ if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
- self.outpath_samples: str = outpath_samples
- self.outpath_grids: str = outpath_grids
- self.prompt: str = prompt
- self.prompt_for_display: str = None
- self.negative_prompt: str = (negative_prompt or "")
- self.styles: list = styles or []
- self.seed: int = seed
- self.subseed: int = subseed
- self.subseed_strength: float = subseed_strength
- self.seed_resize_from_h: int = seed_resize_from_h
- self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_name: str = sampler_name
- self.batch_size: int = batch_size
- self.n_iter: int = n_iter
- self.steps: int = steps
- self.cfg_scale: float = cfg_scale
- self.width: int = width
- self.height: int = height
- self.restore_faces: bool = restore_faces
- self.tiling: bool = tiling
- self.do_not_save_samples: bool = do_not_save_samples
- self.do_not_save_grid: bool = do_not_save_grid
- self.extra_generation_params: dict = extra_generation_params or {}
- self.overlay_images = overlay_images
- self.eta = eta
- self.do_not_reload_embeddings = do_not_reload_embeddings
- self.paste_to = None
- self.color_corrections = None
- self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = ddim_discretize or opts.ddim_discretize
- self.s_min_uncond = s_min_uncond or opts.s_min_uncond
- self.s_churn = s_churn or opts.s_churn
- self.s_tmin = s_tmin or opts.s_tmin
- self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
- self.s_noise = s_noise if s_noise is not None else opts.s_noise
- self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
- self.override_settings_restore_afterwards = override_settings_restore_afterwards
- self.is_using_inpainting_conditioning = False
- self.disable_extra_networks = False
- self.token_merging_ratio = 0
- self.token_merging_ratio_hr = 0
-
- if not seed_enable_extras:
+ self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
+ self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
+ self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
+ self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
+ self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
+
+ self.extra_generation_params = self.extra_generation_params or {}
+ self.override_settings = self.override_settings or {}
+ self.script_args = self.script_args or {}
+
+ self.refiner_checkpoint_info = None
+
+ if not self.seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
- self.scripts = None
- self.script_args = script_args
- self.all_prompts = None
- self.all_negative_prompts = None
- self.all_seeds = None
- self.all_subseeds = None
- self.iteration = 0
- self.is_hr_pass = False
- self.sampler = None
- self.main_prompt = None
- self.main_negative_prompt = None
-
- self.prompts = None
- self.negative_prompts = None
- self.extra_network_data = None
- self.seeds = None
- self.subseeds = None
-
- self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c
- self.uc = None
- self.c = None
- self.rng: rng.ImageRNG = None
-
- self.user = None
@property
def sd_model(self):
return shared.sd_model
+ @sd_model.setter
+ def sd_model(self, value):
+ pass
+
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@@ -408,7 +439,10 @@ class Processed:
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
- self.sd_model_hash = shared.sd_model.sd_model_hash
+ self.sd_model_name = p.sd_model_name
+ self.sd_model_hash = p.sd_model_hash
+ self.sd_vae_name = p.sd_vae_name
+ self.sd_vae_hash = p.sd_vae_hash
self.seed_resize_from_w = p.seed_resize_from_w
self.seed_resize_from_h = p.seed_resize_from_h
self.denoising_strength = getattr(p, 'denoising_strength', None)
@@ -459,7 +493,10 @@ class Processed:
"batch_size": self.batch_size,
"restore_faces": self.restore_faces,
"face_restoration_model": self.face_restoration_model,
+ "sd_model_name": self.sd_model_name,
"sd_model_hash": self.sd_model_hash,
+ "sd_vae_name": self.sd_vae_name,
+ "sd_vae_hash": self.sd_vae_hash,
"seed_resize_from_w": self.seed_resize_from_w,
"seed_resize_from_h": self.seed_resize_from_h,
"denoising_strength": self.denoising_strength,
@@ -578,10 +615,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}",
- "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
- "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
+ "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
+ "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+ "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
+ "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -670,8 +707,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling
- p.loaded_vae_name = sd_vae.get_loaded_vae_name()
- p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
+ if p.refiner_checkpoint not in (None, "", "None"):
+ p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
+ if p.refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
+
+ p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
+ p.sd_model_hash = shared.sd_model.sd_model_hash
+ p.sd_vae_name = sd_vae.get_loaded_vae_name()
+ p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
@@ -910,49 +954,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
return width, height
+@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
- sampler = None
+ enable_hr: bool = False
+ denoising_strength: float = 0.75
+ firstphase_width: int = 0
+ firstphase_height: int = 0
+ hr_scale: float = 2.0
+ hr_upscaler: str = None
+ hr_second_pass_steps: int = 0
+ hr_resize_x: int = 0
+ hr_resize_y: int = 0
+ hr_checkpoint_name: str = None
+ hr_sampler_name: str = None
+ hr_prompt: str = ''
+ hr_negative_prompt: str = ''
+
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
- super().__init__(**kwargs)
- self.enable_hr = enable_hr
- self.denoising_strength = denoising_strength
- self.hr_scale = hr_scale
- self.hr_upscaler = hr_upscaler
- self.hr_second_pass_steps = hr_second_pass_steps
- self.hr_resize_x = hr_resize_x
- self.hr_resize_y = hr_resize_y
- self.hr_upscale_to_x = hr_resize_x
- self.hr_upscale_to_y = hr_resize_y
- self.hr_checkpoint_name = hr_checkpoint_name
- self.hr_checkpoint_info = None
- self.hr_sampler_name = hr_sampler_name
- self.hr_prompt = hr_prompt
- self.hr_negative_prompt = hr_negative_prompt
- self.all_hr_prompts = None
- self.all_hr_negative_prompts = None
- self.latent_scale_mode = None
-
- if firstphase_width != 0 or firstphase_height != 0:
+ hr_checkpoint_info: dict = field(default=None, init=False)
+ hr_upscale_to_x: int = field(default=0, init=False)
+ hr_upscale_to_y: int = field(default=0, init=False)
+ truncate_x: int = field(default=0, init=False)
+ truncate_y: int = field(default=0, init=False)
+ applied_old_hires_behavior_to: tuple = field(default=None, init=False)
+ latent_scale_mode: dict = field(default=None, init=False)
+ hr_c: tuple | None = field(default=None, init=False)
+ hr_uc: tuple | None = field(default=None, init=False)
+ all_hr_prompts: list = field(default=None, init=False)
+ all_hr_negative_prompts: list = field(default=None, init=False)
+ hr_prompts: list = field(default=None, init=False)
+ hr_negative_prompts: list = field(default=None, init=False)
+ hr_extra_network_data: list = field(default=None, init=False)
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.firstphase_width != 0 or self.firstphase_height != 0:
self.hr_upscale_to_x = self.width
self.hr_upscale_to_y = self.height
- self.width = firstphase_width
- self.height = firstphase_height
-
- self.truncate_x = 0
- self.truncate_y = 0
- self.applied_old_hires_behavior_to = None
-
- self.hr_prompts = None
- self.hr_negative_prompts = None
- self.hr_extra_network_data = None
+ self.width = self.firstphase_width
+ self.height = self.firstphase_height
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
- self.hr_c = None
- self.hr_uc = None
def calculate_target_resolution(self):
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
@@ -1230,7 +1276,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return super().get_conds()
-
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
@@ -1243,32 +1288,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return res
+@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
- sampler = None
-
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
- super().__init__(**kwargs)
-
- self.init_images = init_images
- self.resize_mode: int = resize_mode
- self.denoising_strength: float = denoising_strength
- self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
- self.init_latent = None
- self.image_mask = mask
- self.latent_mask = None
- self.mask_for_overlay = None
- self.mask_blur_x = mask_blur_x
- self.mask_blur_y = mask_blur_y
- if mask_blur is not None:
- self.mask_blur = mask_blur
- self.inpainting_fill = inpainting_fill
- self.inpaint_full_res = inpaint_full_res
- self.inpaint_full_res_padding = inpaint_full_res_padding
- self.inpainting_mask_invert = inpainting_mask_invert
- self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
+ init_images: list = None
+ resize_mode: int = 0
+ denoising_strength: float = 0.75
+ image_cfg_scale: float = None
+ mask: Any = None
+ mask_blur_x: int = 4
+ mask_blur_y: int = 4
+ mask_blur: int = None
+ inpainting_fill: int = 0
+ inpaint_full_res: bool = True
+ inpaint_full_res_padding: int = 0
+ inpainting_mask_invert: int = 0
+ initial_noise_multiplier: float = None
+ latent_mask: Image = None
+
+ image_mask: Any = field(default=None, init=False)
+
+ nmask: torch.Tensor = field(default=None, init=False)
+ image_conditioning: torch.Tensor = field(default=None, init=False)
+ init_img_hash: str = field(default=None, init=False)
+ mask_for_overlay: Image = field(default=None, init=False)
+ init_latent: torch.Tensor = field(default=None, init=False)
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ self.image_mask = self.mask
self.mask = None
- self.nmask = None
- self.image_conditioning = None
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
@property
def mask_blur(self):
@@ -1278,15 +1328,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
@mask_blur.setter
def mask_blur(self, value):
- self.mask_blur_x = value
- self.mask_blur_y = value
-
- @mask_blur.deleter
- def mask_blur(self):
- del self.mask_blur_x
- del self.mask_blur_y
+ if isinstance(value, int):
+ self.mask_blur_x = value
+ self.mask_blur_y = value
def init(self, all_prompts, all_seeds, all_subseeds):
+ self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
+
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 773ec5d0..7b946d05 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -41,15 +41,9 @@ class ScriptRefiner(scripts.Script):
def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner
- p.refiner_checkpoint_info = None
- p.refiner_switch_at = None
-
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
- return
-
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
-
- p.refiner_checkpoint_info = refiner_checkpoint_info
- p.refiner_switch_at = refiner_switch_at
+ p.refiner_checkpoint_info = None
+ p.refiner_switch_at = None
+ else:
+ p.refiner_checkpoint = refiner_checkpoint
+ p.refiner_switch_at = refiner_switch_at
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 40c7aae0..09d1e11e 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -145,7 +145,7 @@ def apply_refiner(cfg_denoiser):
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
- if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
+ if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
@@ -305,5 +305,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
-
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index d37b428f..da0e48aa 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -175,14 +175,22 @@ def do_nothing(p, x, xs):
def format_nothing(p, opt, x):
return ""
+
def format_remove_path(p, opt, x):
return os.path.basename(x)
+
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
+def list_to_csv_string(data_list):
+ with StringIO() as o:
+ csv.writer(o).writerow(data_list)
+ return o.getvalue().strip()
+
+
class AxisOption:
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
self.label = label
@@ -199,6 +207,7 @@ class AxisOptionImg2Img(AxisOption):
super().__init__(*args, **kwargs)
self.is_img2img = True
+
class AxisOptionTxt2Img(AxisOption):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -286,11 +295,10 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
cell_size = (processed_result.width, processed_result.height)
if processed_result.images[0] is not None:
cell_mode = processed_result.images[0].mode
- #This corrects size in case of batches:
+ # This corrects size in case of batches:
cell_size = processed_result.images[0].size
processed_result.images[idx] = Image.new(cell_mode, cell_size)
-
if first_axes_processed == 'x':
for ix, x in enumerate(xs):
if second_axes_processed == 'y':
@@ -348,9 +356,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
if draw_legend:
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
processed_result.images.insert(0, z_grid)
- #TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
- #processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
- #processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
+ # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
+ # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
+ # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
processed_result.infotexts.insert(0, processed_result.infotexts[0])
return processed_result
@@ -374,8 +382,8 @@ class SharedSettingsStackHelper(object):
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
-re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
-re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
+re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*])?\s*")
class Script(scripts.Script):
@@ -390,19 +398,19 @@ class Script(scripts.Script):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
- x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
+ x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
- y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
+ y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
- z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
+ z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
@@ -414,6 +422,9 @@ class Script(scripts.Script):
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+ with gr.Column():
+ csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
+
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@@ -430,50 +441,71 @@ class Script(scripts.Script):
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
- def fill(x_type):
- axis = self.current_axis_options[x_type]
- return axis.choices() if axis.choices else gr.update()
+ def fill(axis_type, csv_mode):
+ axis = self.current_axis_options[axis_type]
+ if axis.choices:
+ if csv_mode:
+ return list_to_csv_string(axis.choices()), gr.update()
+ else:
+ return gr.update(), axis.choices()
+ else:
+ return gr.update(), gr.update()
- fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
- fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
- fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
+ fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
+ fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
+ fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
- def select_axis(axis_type,axis_values_dropdown):
+ def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
- current_values = axis_values_dropdown
+
+ current_values = axis_values
+ current_dropdown_values = axis_values_dropdown
if has_choices:
choices = choices()
- if isinstance(current_values,str):
- current_values = current_values.split(",")
- current_values = list(filter(lambda x: x in choices, current_values))
- return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
+ if csv_mode:
+ current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
+ current_values = list_to_csv_string(current_dropdown_values)
+ else:
+ current_dropdown_values = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(axis_values)))]
+ current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
+
+ return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=current_values),
+ gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=current_dropdown_values))
+
+ x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
+ y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
+ z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])
+
+ def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
+ _fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
+ _fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
+ _fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
+ return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown
- x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
- y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
- z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
+ csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])
- def get_dropdown_update_from_params(axis,params):
+ def get_dropdown_update_from_params(axis, params):
val_key = f"{axis} Values"
- vals = params.get(val_key,"")
+ vals = params.get(val_key, "")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
- return gr.update(value = valslist)
+ return gr.update(value=valslist)
self.infotext_fields = (
(x_type, "X Type"),
(x_values, "X Values"),
- (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
+ (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
(y_type, "Y Type"),
(y_values, "Y Values"),
- (y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
+ (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
(z_type, "Z Type"),
(z_values, "Z Values"),
- (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
+ (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
)
- return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+ return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
- def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+ def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -484,7 +516,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
- if opt.choices is not None:
+ if opt.choices is not None and not csv_mode:
valslist = vals_dropdown
else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
@@ -503,8 +535,8 @@ class Script(scripts.Script):
valslist_ext += list(range(start, end, step))
elif mc is not None:
start = int(mc.group(1))
- end = int(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
+ end = int(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else:
@@ -525,8 +557,8 @@ class Script(scripts.Script):
valslist_ext += np.arange(start, end + step, step).tolist()
elif mc is not None:
start = float(mc.group(1))
- end = float(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
+ end = float(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
else:
@@ -545,22 +577,22 @@ class Script(scripts.Script):
return valslist
x_opt = self.current_axis_options[x_type]
- if x_opt.choices is not None:
- x_values = ",".join(x_values_dropdown)
+ if x_opt.choices is not None and not csv_mode:
+ x_values = list_to_csv_string(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type]
- if y_opt.choices is not None:
- y_values = ",".join(y_values_dropdown)
+ if y_opt.choices is not None and not csv_mode:
+ y_values = list_to_csv_string(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type]
- if z_opt.choices is not None:
- z_values = ",".join(z_values_dropdown)
+ if z_opt.choices is not None and not csv_mode:
+ z_values = list_to_csv_string(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else
- Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
+ Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
@@ -720,7 +752,7 @@ class Script(scripts.Script):
# Auto-save main and sub-grids:
grid_count = z_count + 1 if z_count > 1 else 1
for g in range(grid_count):
- #TODO: See previous comment about intentional data misalignment.
+ # TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)