aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/generation_parameters_copypaste.py17
-rw-r--r--modules/hypernetworks/hypernetwork.py9
-rw-r--r--modules/processing.py26
-rw-r--r--modules/sd_vae.py20
-rw-r--r--modules/shared.py2
-rw-r--r--modules/sub_quadratic_attention.py15
-rw-r--r--modules/textual_inversion/dataset.py10
-rw-r--r--modules/textual_inversion/textual_inversion.py36
-rw-r--r--modules/ui.py37
9 files changed, 129 insertions, 43 deletions
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 12a9de3d..f7f68b67 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res):
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
+ if shared.opts.use_old_hires_fix_width_height:
+ hires_width = int(res.get("Hires resize-1", None))
+ hires_height = int(res.get("Hires resize-2", None))
+
+ if hires_width is not None and hires_height is not None:
+ res['Size-1'] = hires_width
+ res['Size-2'] = hires_height
+ return
+
if firstpass_width is None or firstpass_height is None:
return
@@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
- # old algorithm for auto-calculating first pass size
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- firstpass_width = math.ceil(scale * width / 64) * 64
- firstpass_height = math.ceil(scale * height / 64) * 64
+ from modules import processing
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b0cfbe71..ea3f1db9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,6 +24,7 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -456,7 +459,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
diff --git a/modules/processing.py b/modules/processing.py
index 1d23b15f..f04a0e1e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return res
+def old_hires_fix_first_pass_dimensions(width, height):
+ """old algorithm for auto-calculating first pass size"""
+
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ width = math.ceil(scale * width / 64) * 64
+ height = math.ceil(scale * height / 64) * 64
+
+ return width, height
+
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
@@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
- print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
- self.hr_scale = self.width / firstphase_width
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
self.width = firstphase_width
self.height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+ self.hr_resize_x = self.width
+ self.hr_resize_y = self.height
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+ self.applied_old_hires_behavior_to = (self.width, self.height)
+
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
self.extra_generation_params["Hires upscale"] = self.hr_scale
self.hr_upscale_to_x = int(self.width * self.hr_scale)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index ac71d62d..0a49daa1 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,9 @@
import torch
+import safetensors.torch
import os
import collections
from collections import namedtuple
-from modules import shared, devices, script_callbacks
+from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy
@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}")
+ # if still not found, try look for ".vae.safetensors" beside model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.safetensors"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
@@ -163,8 +172,9 @@ def load_vae(model, vae_file=None):
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+
+ vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
@@ -195,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
+
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
+
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
diff --git a/modules/shared.py b/modules/shared.py
index a6712dae..aa37c8ce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -398,6 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index fea7aaac..55052815 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,8 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, Protocol, List
+from typing import Optional, NamedTuple, List
+
def narrow_trunc(
input: Tensor,
@@ -25,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
-class SummarizeChunk(Protocol):
+
+class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
@@ -38,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
-class ComputeQueryChunkAttn(Protocol):
+
+class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
@@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
+
def _summarize_chunk(
query: Tensor,
key: Tensor,
@@ -66,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
def _query_chunk_attention(
query: Tensor,
key: Tensor,
@@ -106,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -125,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
+
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
+
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 88d68c76..fa48708e 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,13 +28,11 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.width = width
- self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@@ -47,10 +45,10 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
+ assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
@@ -59,7 +57,9 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB')
+ if not varsize:
+ image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 217fe9eb..5420903f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
import inspect
+from collections import namedtuple
import torch
import tqdm
@@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
@@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert template_file, "Prompt template file is empty"
- assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0, "Max steps must be positive"
@@ -297,10 +313,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
@@ -351,7 +369,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
diff --git a/modules/ui.py b/modules/ui.py
index 99483130..b6079aec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,7 +37,7 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
+from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
@@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
with devices.autocast():
p.init([""], [0], [0])
- return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+ return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
@@ -745,15 +745,20 @@ def create_ui():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- hr_resolution_preview_args = dict(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False
- )
-
for input in hr_resolution_preview_inputs:
- input.change(**hr_resolution_preview_args)
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -1317,6 +1322,9 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
@@ -1340,9 +1348,14 @@ def create_ui():
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
with FormRow():
@@ -1449,6 +1462,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
@@ -1480,6 +1494,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,