From eea8fc40e16664ddc8a9aec77206da704a35dde0 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 07:24:22 -0800 Subject: Add option to save ti settings to file. --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 30 +++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..933cd738 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,6 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), + "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..2bed2ecb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,7 @@ import os import sys import traceback +import inspect import torch import tqdm @@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args + + # Create a list of the argument names to include in the settings string. + names = arg_names[:16] # Include all arguments up until the preview-related ones. + if preview_from_txt2img: + names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + + # Build the settings string. 
+ settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") + def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" @@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + if shared.opts.save_train_settings_to_txt: + save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From 19a81ac2871ec900fc8b7955bbc2554b6c5ac6b1 Mon Sep 17 00:00:00 2001 From: cat Date: Thu, 5 Jan 2023 20:17:39 +0500 Subject: hires-fix: add "nearest-exact" latent upscale mode. 
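The "Latent (nearest-exact)" entry added in the diff below only registers a new label in the latent_upscale_modes table; the mode string and antialias flag it carries are ultimately handed to torch.nn.functional.interpolate when the hires-fix upscales the latent. That call site is not part of these patches, so the helper below is a hypothetical sketch of the mapping rather than repository code:

import torch
import torch.nn.functional as F

# Mirrors the structure of the latent_upscale_modes entries in modules/shared.py.
latent_upscale_modes = {
    "Latent (nearest)": {"mode": "nearest", "antialias": False},
    "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
}

def upscale_latent(samples: torch.Tensor, height: int, width: int, mode_name: str) -> torch.Tensor:
    # samples is an NCHW latent batch; antialias only has an effect for bilinear/bicubic modes.
    opts = latent_upscale_modes[mode_name]
    return F.interpolate(samples, size=(height, width), mode=opts["mode"], antialias=opts["antialias"])

latent = torch.randn(1, 4, 64, 64)                                  # latent for a 512x512 image
hires = upscale_latent(latent, 128, 128, "Latent (nearest-exact)")  # -> (1, 4, 128, 128)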
--- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..b7a3ce5c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -576,6 +576,7 @@ latent_upscale_modes = { "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, + "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From b85c2b5cf4a6809bc871718cf4680d49c3e95e94 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 08:14:38 -0800 Subject: Clean up ti, add same behavior to hypernetwork. --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++++++- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 14 +++++++----- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6a9b1398..d5985263 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -401,7 +401,33 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() +# Note: textual_inversion.py has a nearly identical function of the same name. +def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + # Starting index of preview-related arguments. + border_index = 19 + + # Get a list of the argument names, excluding default argument. + sig = inspect.signature(save_settings_to_file) + arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] + + # Create a list of the argument names to include in the settings string. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. + + # Include preview-related arguments if applicable. + if preview_from_txt2img: + names.extend(arg_names[border_index:]) + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
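The save_settings_to_file implementations above (and their textual-inversion twin) all rest on the same small introspection pattern: read the function's own parameter names with inspect, look the values up in locals(), and append the result to settings.txt in the log directory. A stripped-down sketch of that pattern with a shortened, hypothetical argument list, in case the long real signatures obscure it:

import datetime
import inspect
import os

def demo_save_settings(log_dir, learn_rate, batch_size, steps, preview_from_txt2img, preview_prompt):
    # Parameter names of this very function, in declaration order.
    arg_names = inspect.getfullargspec(demo_save_settings).args
    border_index = 5                      # index where preview-only arguments start
    names = arg_names[:border_index]
    if preview_from_txt2img:
        names.extend(arg_names[border_index:])
    values = locals()                     # maps each argument name to its current value
    settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
    for name in names:
        settings_str += f"{name}: {values[name]}\n"
    with open(os.path.join(log_dir, "settings.txt"), "a+", encoding="utf8") as fout:
        fout.write(settings_str + "\n\n")

demo_save_settings(".", 5e-3, 1, 1000, preview_from_txt2img=False, preview_prompt="")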
@@ -457,7 +483,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) diff --git a/modules/shared.py b/modules/shared.py index 933cd738..10231a75 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), + "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2bed2ecb..68648550 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -230,18 +230,20 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +# Note: hypernetwork.py has a nearly identical function of the same name. 
def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): checkpoint = sd_models.select_checkpoint() model_name = checkpoint.model_name model_hash = '[{}]'.format(checkpoint.hash) - + # Starting index of preview-related arguments. + border_index = 16 # Get a list of the argument names. arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. - names = arg_names[:16] # Include all arguments up until the preview-related ones. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" @@ -329,8 +331,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - if shared.opts.save_train_settings_to_txt: - save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From b6bab2f052b32c0ffebe6aecc1819ccf20cf8c5d Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 09:14:56 -0800 Subject: Include model in log file. Exclude directory. 
--- modules/hypernetworks/hypernetwork.py | 28 +++++++++----------------- modules/textual_inversion/textual_inversion.py | 22 +++++++++----------- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d5985263..3237c37a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -402,30 +402,22 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() # Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 19 - - # Get a list of the argument names, excluding default argument. - sig = inspect.signature(save_settings_to_file) - arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] - + border_index = 21 + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - - # Include preview-related arguments if applicable. if preview_from_txt2img: - names.extend(arg_names[border_index:]) - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. 
with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -485,7 +477,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 68648550..ce7e4f5d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -231,26 +231,22 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) # Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 16 + border_index = 18 # Get a list of the argument names. 
- arg_names = inspect.getfullargspec(save_settings_to_file).args - + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -333,7 +329,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From 81133d4168ae0bae9bf8bf1a1d4983319a589112 Mon Sep 17 00:00:00 2001 From: Faber Date: Fri, 6 Jan 2023 03:38:37 +0700 Subject: allow loading embeddings from subdirectories --- modules/textual_inversion/textual_inversion.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 24b43045..0a059044 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -149,19 +149,20 @@ class EmbeddingDatabase: else: self.skipped_embeddings[name] = embedding - for fn in os.listdir(self.embeddings_dir): - try: - fullfn = os.path.join(self.embeddings_dir, fn) - - if os.stat(fullfn).st_size == 0: + for root, dirs, fns in os.walk(self.embeddings_dir): + for fn in fns: + try: + fullfn = os.path.join(root, fn) + + if os.stat(fullfn).st_size 
== 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) continue - process_file(fullfn, fn) - except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") -- cgit v1.2.3 From 8111b5569d07c7ac3b695e28171aede728b4ae56 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 3 Jan 2023 20:43:05 -0500 Subject: Add support for PyTorch nightly and local builds --- modules/devices.py | 28 +++++++++++++++++++++++----- webui.py | 7 ++++++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/webui.py b/webui.py index 13375e71..ddfaea95 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import threading import time import importlib import signal -import threading +import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -13,6 +13,11 @@ from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path +import torch +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in 
torch.__version__: + torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras -- cgit v1.2.3 From d61a5aa4f623f6630670241aca8fc5c2a6381769 Mon Sep 17 00:00:00 2001 From: acncagua Date: Fri, 6 Jan 2023 10:58:22 +0900 Subject: Add files via upload --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 81d96c5b..030f0685 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -550,6 +550,8 @@ Requested path was: {f} os.startfile(path) elif platform.system() == "Darwin": sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) else: sp.Popen(["xdg-open", path]) -- cgit v1.2.3 From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- README.md | 1 + modules/sd_hijack.py | 15 ++- modules/sd_hijack_optimizations.py | 124 ++++++++++++++++++----- modules/shared.py | 4 + modules/sub_quadratic_attention.py | 201 +++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- 6 files changed, 312 insertions(+), 35 deletions(-) create mode 100644 modules/sub_quadratic_attention.py diff --git a/README.md b/README.md index 556000fb..1913caf3 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https: - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). 
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 690a9ec2..019a6f3f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet -from modules.sd_hijack_optimizations import invokeAI_mps_available - import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel @@ -40,17 +38,16 @@ def apply_optimizations(): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + elif cmd_opts.opt_sub_quad_attention: + print("Applying sub-quadratic cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - else: - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = 
torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
+ + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. 
send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..487a7792 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) +parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 00000000..b11dc1c7 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,201 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, Protocol, List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] + return x[slicing] + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + +class SummarizeChunk(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... 
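The _summarize_chunk and _query_chunk_attention routines that follow implement the O(sqrt(n))-memory trick from the paper cited in the header: each key/value chunk is reduced to (exp_values, exp_weights_sum, max_score), and chunks are merged by rescaling against a running maximum so the softmax stays numerically stable. A standalone sketch of that merging rule over plain tensors (illustrative only, not code from this patch) gives the same result as ordinary attention:

import torch

def chunked_attention(q, k, v, kv_chunk_size):
    # Same result as softmax(q @ k^T * scale) @ v, but the key/value axis is
    # processed kv_chunk_size tokens at a time with a running max for stability.
    scale = q.shape[-1] ** -0.5
    exp_values = exp_weights_sum = max_score = None
    for start in range(0, k.shape[1], kv_chunk_size):
        k_chunk = k[:, start:start + kv_chunk_size]
        v_chunk = v[:, start:start + kv_chunk_size]
        attn = torch.bmm(q, k_chunk.transpose(1, 2)) * scale      # [batch*heads, q_tokens, chunk]
        chunk_max = attn.amax(dim=-1, keepdim=True)
        exp_w = torch.exp(attn - chunk_max)
        chunk_values = torch.bmm(exp_w, v_chunk)
        chunk_weights = exp_w.sum(dim=-1, keepdim=True)
        if exp_values is None:
            exp_values, exp_weights_sum, max_score = chunk_values, chunk_weights, chunk_max
        else:
            new_max = torch.maximum(max_score, chunk_max)
            old_scale, new_scale = torch.exp(max_score - new_max), torch.exp(chunk_max - new_max)
            exp_values = exp_values * old_scale + chunk_values * new_scale
            exp_weights_sum = exp_weights_sum * old_scale + chunk_weights * new_scale
            max_score = new_max
    return exp_values / exp_weights_sum

q, k, v = torch.randn(2, 8, 16), torch.randn(2, 12, 16), torch.randn(2, 12, 16)
full = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * 16 ** -0.5, dim=-1).bmm(v)
assert torch.allclose(chunked_attention(q, k, v, kv_chunk_size=5), full, atol=1e-5)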
+ +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = dynamic_slice( + key, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, v_channels_per_head) + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). 
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( + query, + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res diff --git a/requirements.txt b/requirements.txt index 5bed694e..0dbea322 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,4 @@ inflection GitPython torchsde safetensors -psutil; sys_platform == 'darwin' +psutil -- cgit v1.2.3 From b119815333026164f2bd7d1ca71f3e4f7a9afd0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 5 Jan 2023 04:37:17 -0500 Subject: Use narrow instead of dynamic_slice --- modules/sub_quadratic_attention.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index b11dc1c7..95924d24 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -5,6 +5,7 @@ # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint import math from typing import Optional, NamedTuple, Protocol, List -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int ) -> Tensor: - slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] - return x[slicing] + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) class AttnChunk(NamedTuple): exp_values: Tensor @@ -76,15 +77,17 @@ def _query_chunk_attention( _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> 
AttnChunk: - key_chunk = dynamic_slice( + key_chunk = narrow_trunc( key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) - value_chunk = dynamic_slice( + value_chunk = narrow_trunc( value, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, v_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) return summarize_chunk(query, key_chunk, value_chunk) @@ -161,10 +164,11 @@ def efficient_dot_product_attention( kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) def get_query_chunk(chunk_idx: int) -> Tensor: - return dynamic_slice( + return narrow_trunc( query, - (0, chunk_idx, 0), - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + 1, + chunk_idx, + min(query_chunk_size, q_tokens) ) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) -- cgit v1.2.3 From 683287d87f6401083a8d63eedc00ca7410214ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 08:52:06 +0300 Subject: rework saving training params to file #6372 --- modules/hypernetworks/hypernetwork.py | 28 +++++++------------------- modules/shared.py | 2 +- modules/textual_inversion/logging.py | 24 ++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 23 +++------------------ 4 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 modules/textual_inversion/logging.py diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3237c37a..b0cfbe71 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -13,7 +13,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers -from modules.textual_inversion import textual_inversion +from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ @@ -401,25 +401,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() -# Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 21 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. 
- value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -477,7 +459,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + saved_params = dict( + model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} + ) + logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/shared.py b/modules/shared.py index f0e10b35..57e489d0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. 
Training of embedding or HN can be resumed with the matching optim file."), - "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 00000000..8b1981d5 --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,24 @@ +import datetime +import json +import os + +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} +saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e9cf432f..f9f5e8cd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay) +from modules.textual_inversion.logging import save_settings_to_file + class Embedding: def __init__(self, vec, name, step=None): @@ -231,25 +233,6 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) -# Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. 
- border_index = 18 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" @@ -330,7 +313,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From b95a4c0ce5ab9c414e0494193bfff665f45e9e65 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:01:51 -0500 Subject: Change sub-quad chunk threshold to use percentage --- modules/sd_hijack_optimizations.py | 18 +++++++++--------- modules/shared.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f5c153e8..b416e9ac 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, 
chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) @@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape _, k_tokens, _ = k.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) - - if chunk_threshold_bytes is None: - chunk_threshold_bytes = available_vram - elif chunk_threshold_bytes == 0: + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - if kv_chunk_size_min is None: + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) elif kv_chunk_size_min == 0: kv_chunk_size_min = None @@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index cb1dc312..d7a81db1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) -parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -- cgit v1.2.3 From 5deb2a19ccea57a50252e8fcb07b4d17c6599def Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:33:15 -0500 Subject: Allow Doggettx's cross attention opt without CUDA --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ef25dadb..bd101e5b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -50,7 +50,7 @@ def apply_optimizations(): print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): print("Applying cross attention optimization (InvokeAI).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI optimization_method = 'InvokeAI' -- cgit v1.2.3 From c9bded39ee05bd0507ccd27d2b674d86d6c0c8e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 12:32:44 +0300 Subject: sort extensions by date and add an option to sort by other columns --- modules/ui_extensions.py | 44 ++++++++++++++++++++++++++++++++------------ style.css | 11 ++++++++++- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index eec9586f..742e745e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url, hide_tags): +def install_extension_from_index(url, hide_tags, sort_column): ext_table, message = install_extension_from_url(None, url) - code, _ = refresh_available_extensions_from_data(hide_tags) + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, ext_table, message -def refresh_available_extensions(url, hide_tags): +def refresh_available_extensions(url, hide_tags, sort_column): global available_extensions import urllib.request @@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags): available_extensions = json.loads(text) - code, tags = refresh_available_extensions_from_data(hide_tags) + code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) return url, code, gr.CheckboxGroup.update(choices=tags), '' -def refresh_available_extensions_for_tags(hide_tags): - code, _ = refresh_available_extensions_from_data(hide_tags) +def refresh_available_extensions_for_tags(hide_tags, sort_column): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, '' -def refresh_available_extensions_from_data(hide_tags): +sort_ordering = [ + # (reverse, order_by_function) + (True, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('name', 'z')), + (True, lambda x: x.get('name', 'z')), + (False, lambda x: 'z'), +] + + +def 
refresh_available_extensions_from_data(hide_tags, sort_column): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} @@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags): """ - for ext in extlist: + sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] + + for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags): code += f""" {html.escape(name)}
{tags_text} - {html.escape(description)} + {html.escape(description)}

Added: {html.escape(added)}

{install_code} @@ -291,25 +304,32 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), - inputs=[available_extensions_index, hide_tags], + inputs=[available_extensions_index, hide_tags, sort_column], outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install, hide_tags], + inputs=[extension_to_install, hide_tags, sort_column], outputs=[available_extensions_table, extensions_table, install_result], ) hide_tags.change( fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags], + inputs=[hide_tags, sort_column], + outputs=[available_extensions_table, install_result] + ) + + sort_column.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column], outputs=[available_extensions_table, install_result] ) diff --git a/style.css b/style.css index ee74d79e..f1b23b53 100644 --- a/style.css +++ b/style.css @@ -555,7 +555,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h /* Extensions */ -#tab_extensions table{ +#tab_extensions table``{ border-collapse: collapse; } @@ -581,6 +581,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h font-size: 95%; } +#available_extensions .info{ + margin: 0; +} + +#available_extensions .date_added{ + opacity: 0.85; + font-size: 90%; +} + #image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ min-width: auto; padding-left: 0.5em; -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is - 
passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 848605fb654a55ee6947335d7df6e13366606fad Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 06:58:49 -0500 Subject: Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c1944d33..fea6cb35 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) -- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). 
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot -- cgit v1.2.3 From 5e6566324bba20554bcc04f3dda798e560397f38 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 07:06:26 -0500 Subject: Always end version number with a digit --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 733a06b5..8737e593 100644 --- a/webui.py +++ b/webui.py @@ -16,7 +16,7 @@ from modules.paths import script_path import torch # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer -- cgit v1.2.3 From 3246a2d6b898da6a98fe9df4dc67944635a41bd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 16:03:43 +0300 Subject: remove restriction for saving dropdowns to ui-config.json --- modules/scripts.py | 1 - modules/ui.py | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 0c44f191..35164093 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -290,7 +290,6 @@ class ScriptRunner: script.group = group dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True inputs[0] = dropdown for script in self.selectable_scripts: diff --git a/modules/ui.py b/modules/ui.py index 030f0685..b79d24ee 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -435,11 +435,9 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=1, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style.save_to_config = True with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style2.save_to_config = True return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button @@ -638,7 +636,6 @@ def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - sampler_index.save_to_config = True steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): @@ -1794,7 +1791,7 @@ def create_ui(): if init_field is not None: init_field(saved_value) - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: apply_field(x, 'visible') if type(x) == 
gr.Slider: @@ -1815,11 +1812,8 @@ def create_ui(): if type(x) == gr.Number: apply_field(x, 'value') - # Since there are many dropdowns that shouldn't be saved, - # we only mark dropdowns that should be saved. - if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + if type(x) == gr.Dropdown: apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - apply_field(x, 'visible') visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") -- cgit v1.2.3 From 50194de93ffc9db763d9b08fcc9c3bde1aa86151 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:12:45 +0100 Subject: typo UI fixes #6391 --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 57e489d0..865c3c07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,7 +430,7 @@ options_templates.update(options_section(('ui', "User interface"), { "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) -- cgit v1.2.3 From 3992ecbe6e46a465062508c677964534e7397f72 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:02:46 +0100 Subject: Added UI elements Added a new row to hires fix that shows the new resolution after scaling --- modules/ui.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..20f7d2a2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -255,6 +255,12 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] +def calc_resolution_hires(x, y, scale): + #final res can only be a multiple of 8 + scaled_x = int(x * scale // 8) * 8 + scaled_y = int(y * scale // 8) * 8 + + return "

Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -718,6 +724,12 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) elif category == "batch": if not opts.dimensions_and_batch_together: -- cgit v1.2.3 From 991368c8d54404d8e13d4c6e76a0f32644e65ad4 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:24:29 +0100 Subject: remove camelcase --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 20f7d2a2..6fc8b7d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "

Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" + return "

Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) -- cgit v1.2.3 From c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 16:42:47 -0500 Subject: Added license --- html/licenses.html | 29 ++++++++++++++++++++++++++++- modules/sd_hijack_optimizations.py | 1 + modules/sub_quadratic_attention.py | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/html/licenses.html b/html/licenses.html index 9eeaa072..570630eb 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -184,7 +184,7 @@ SOFTWARE.

SwinIR

-Code added by contirubtors, most likely copied from this repository. +Code added by contributors, most likely copied from this repository.
                                  Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
    limitations under the License.
 
+

Memory Efficient Attention

+The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that. +
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b416e9ac..cdc63ed7 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface def sub_quad_attention_forward(self, x, context=None, mask=None): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 95924d24..fea7aaac 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -1,7 +1,7 @@ # original source: # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py # license: -# unspecified +# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) -- cgit v1.2.3 From 82c1f10b144f733460feead0bdc37a861489dc57 Mon Sep 17 00:00:00 2001 From: Dean Hopkins Date: Fri, 6 Jan 2023 22:00:12 +0000 Subject: increase upscale api validation limit --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index f77951fc..22b88c59 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") -- cgit v1.2.3 From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 01:45:28 +0300 Subject: CLIP hijack rework --- modules/sd_hijack.py | 6 +- modules/sd_hijack_clip.py | 348 ++++++++++++------------- modules/sd_hijack_clip_old.py | 81 ++++++ modules/textual_inversion/textual_inversion.py | 1 - modules/ui.py | 2 +- 5 files changed, 256 insertions(+), 182 deletions(-) create mode 100644 modules/sd_hijack_clip_old.py diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa2cd4bb..71cc145a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -150,10 +150,10 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] - def tokenize(self, text): - _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + def get_prompt_lengths(self, text): + _, token_count = self.clip.process_texts([text]) - return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) + return token_count, self.clip.get_target_prompt_token_count(token_count) class EmbeddingsWithFixes(torch.nn.Module): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ca92b142..ac3020d7 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -1,12 +1,28 @@ import math +from collections import namedtuple import torch from modules import prompt_parser, devices from modules.shared import opts -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 + +class PromptChunk: + """ + This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. + If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. + Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, + so just 75 tokens from prompt. + """ + + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) +"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): @@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): super().__init__() self.wrapped = wrapped self.hijack = hijack + self.chunk_length = 75 + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" + + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length def tokenize(self, texts): + """Converts a batch of texts into a batch of token ids""" + raise NotImplementedError def encode_with_transformers(self, tokens): + """ + converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + All python lists with tokens are assumed to have same length, usually 77. 
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on + model - can be 768 and 1024 + """ + raise NotImplementedError def encode_embedding_init_text(self, init_text, nvpt): + """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through + transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" + raise NotImplementedError - def tokenize_line(self, line, used_custom_terms, hijack_comments): + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. + """ + if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) else: @@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): tokenized = self.tokenize([text for text, _ in parsed]) - fixes = [] - remade_tokens = [] - multipliers = [] + chunks = [] + chunk = PromptChunk() + token_count = 0 last_comma = -1 - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] + def next_chunk(): + """puts current chunk into the list of results and produces the next one - empty""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + token_count += len(chunk.tokens) + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + position = 0 + while position < len(tokens): + token = tokens[position] if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. 
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [self.id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [self.id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [self.id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vec.shape[0]) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if len(chunk.tokens) > 0: + next_chunk() + + return chunks, token_count + + def process_texts(self, texts): + """ + Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum + length, in tokens, of all texts. 
+ """ + token_count = 0 cache = {} - batch_multipliers = [] + batch_chunks = [] for line in texts: if line in cache: - remade_tokens, fixes, multipliers = cache[line] + chunks = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + chunks, current_token_count = self.tokenize_line(line) token_count = max(current_token_count, token_count) - cache[line] = (remade_tokens, fixes, multipliers) + cache[line] = chunks - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) + batch_chunks.append(chunks) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + return batch_chunks, token_count - def process_text_old(self, texts): - id_start = self.id_start - id_end = self.id_end - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + An example shape returned by this function can be: (2, 77, 768). + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ - cache = {} - batch_tokens = self.tokenize(texts) - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = 
multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.id_end] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 + batch_chunks, token_count = self.process_texts(texts) - return z + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + for fixes in self.hijack.fixes: + for position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + + if len(used_embeddings) > 0: + embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) + self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + """ + sends one single prompt chunk to be encoded by transformers neural network. + remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually + there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. + Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier + corresponds to one token. 
+ """ tokens = torch.asarray(remade_batch_tokens).to(devices.device) + # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. if self.id_end != self.id_pad: for batch_pos in range(len(remade_batch_tokens)): index = remade_batch_tokens[batch_pos].index(self.id_end) @@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.encode_with_transformers(tokens) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py new file mode 100644 index 00000000..6d9fbbe6 --- /dev/null +++ b/modules/sd_hijack_clip_old.py @@ -0,0 +1,81 @@ +from modules import sd_hijack_clip +from modules import shared + + +def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + +def forward_old(self: 
sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f9f5e8cd..45882ed6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -79,7 +79,6 @@ class EmbeddingDatabase: self.word_embeddings[embedding.name] = embedding - # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..5d2f5bad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -368,7 +368,7 @@ def update_token_counter(text, steps): flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] - tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" -- cgit v1.2.3 From f94cfc563bbedd923d5e95563a5e8d93c8516ac3 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Sat, 7 Jan 2023 01:15:22 +0100 Subject: Changed HTML to textbox instead Using HTML caused an issue where the row would expand for a frame when changing the sliders because of the loading animation. This solution also doesn't use any additional HTML padding --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 6fc8b7d7..6ea1b5d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "

Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" + return str(scaled_x)+"x"+str(scaled_y) def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -726,7 +726,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), + elem_id="txtimg_hr_finalres", + label="Upscaled resolution", + interactive=False) hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) -- cgit v1.2.3 From 08066676a47b560235d4c085dd3cfcb470b80997 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:22:07 +0300 Subject: make it not break on empty inputs; thank you tarded, we are --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ac3020d7..16aef76a 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -147,7 +147,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0: + if len(chunk.tokens) > 0 or len(chunks) == 0: next_chunk() return chunks, token_count -- cgit v1.2.3 From 1740c33547b62f692834c95914a2b295d51684c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:48:44 +0300 Subject: more comments --- modules/sd_hijack_clip.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16aef76a..5520c9b2 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ from collections import namedtuple import torch -from modules import prompt_parser, devices +from modules import prompt_parser, devices, sd_hijack from modules.shared import opts @@ -22,14 +22,24 @@ class PromptChunk: PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) -"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" +"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt +chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. 
+ """ + def __init__(self, wrapped, hijack): super().__init__() + self.wrapped = wrapped - self.hijack = hijack + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 def empty_chunk(self): @@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; All python lists with tokens are assumed to have same length, usually 77. if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on - model - can be 768 and 1024 + model - can be 768 and 1024. + Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). """ raise NotImplementedError @@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): last_comma = len(chunk.tokens) # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack - # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 -- cgit v1.2.3 From de9738044571877450d1038e18f1ecce93d24af3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 08:53:53 +0300 Subject: this breaks on default config because width, height, hr_scale are None at that point. 
--- modules/ui.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f946382d..a18b9007 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -725,14 +725,8 @@ def create_ui(): hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), - elem_id="txtimg_hr_finalres", - label="Upscaled resolution", - interactive=False) - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -744,6 +738,10 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) -- cgit v1.2.3 From 1a5b86ad65fd738eadea1ad72f4abad3a4aabf17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 09:56:37 +0300 Subject: rework hires fix preview for #6437: movie it to where it takes less place, make it actually account for all relevant sliders and calculate dimensions correctly --- modules/processing.py | 1 - modules/ui.py | 40 +++++++++++++++++++++++++++------------- modules/ui_components.py | 8 ++++++++ style.css | 17 +++++++++++++++++ 4 files changed, 52 insertions(+), 14 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index a408d622..82157bc9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -711,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = 0 self.truncate_y = 0 - def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if self.hr_resize_x == 0 and self.hr_resize_y == 0: diff --git a/modules/ui.py b/modules/ui.py index a18b9007..6c765262 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import 
script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -255,12 +255,20 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] -def calc_resolution_hires(x, y, scale): - #final res can only be a multiple of 8 - scaled_x = int(x * scale // 8) * 8 - scaled_y = int(y * scale // 8) * 8 - - return str(scaled_x)+"x"+str(scaled_y) + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -712,6 +720,7 @@ def create_ui(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "hires_fix": with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: @@ -724,9 +733,6 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -738,9 +744,16 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + hr_resolution_preview_args = dict( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False + ) + + for input in hr_resolution_preview_inputs: + input.change(**hr_resolution_preview_args) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -803,6 +816,7 @@ def create_ui(): fn=lambda x: gr_show(x), inputs=[enable_hr], outputs=[hr_options], + show_progress = False, ) txt2img_paste_fields = [ diff --git a/modules/ui_components.py 
b/modules/ui_components.py index 91eb0e3d..cac001dc 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent): def get_block_name(self): return "group" + + +class FormHTML(gr.HTML, gr.components.FormComponent): + """Same as gr.HTML but fits inside gradio forms""" + + def get_block_name(self): + return "html" + diff --git a/style.css b/style.css index f1b23b53..76721756 100644 --- a/style.css +++ b/style.css @@ -642,6 +642,23 @@ footer { opacity: 0.85; } +#txtimg_hr_finalres{ + min-height: 0 !important; + padding: .625rem .75rem; + margin-left: -0.75em + +} + +#txtimg_hr_finalres .resolution{ + font-weight: bold; +} + +#txt2img_checkboxes > div > div{ + flex: 0; + white-space: nowrap; + min-width: auto; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From fdfce4711076c2ebac1089bac8169d043eb7978f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 13:29:47 +0300 Subject: add "from" resolution for hires fix to be less confusing. --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 6c765262..99483130 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): -- cgit v1.2.3 From 151233399c4b79934bdbb7c12a97eeb6499572fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 13:30:06 +0300 Subject: new screenshot --- README.md | 9 +++------ screenshot.png | Bin 525075 -> 420577 bytes txt2img_Screenshot.png | Bin 337094 -> 0 bytes 3 files changed, 3 insertions(+), 6 deletions(-) delete mode 100644 txt2img_Screenshot.png diff --git a/README.md b/README.md index fea6cb35..d783fdf0 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,7 @@ # Stable Diffusion web UI A browser interface based on Gradio library for Stable Diffusion. -![](txt2img_Screenshot.png) - -Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users. +![](screenshot.png) ## Features [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features): @@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab): 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" 2. Install [git](https://git-scm.com/download/win). 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`. -4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). -5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). -6. 
Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. +4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). +5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. ### Automatic Installation on Linux 1. Install the dependencies: diff --git a/screenshot.png b/screenshot.png index 86c3209f..47a1be4e 100644 Binary files a/screenshot.png and b/screenshot.png differ diff --git a/txt2img_Screenshot.png b/txt2img_Screenshot.png deleted file mode 100644 index 6e2759a4..00000000 Binary files a/txt2img_Screenshot.png and /dev/null differ -- cgit v1.2.3
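Returning to the sd_hijack_clip changes earlier in the series: the comments added there describe how FrozenCLIPEmbedderWithCustomWordsBase splits a prompt into chunks of self.chunk_length = 75 tokens, how opts.comma_padding_backtrack lets the text after a nearby comma spill into the next chunk instead of being cut mid-phrase, and how the "make it not break on empty inputs" fix guarantees that even an empty prompt yields one chunk. The standalone sketch below illustrates only that chunking behaviour under simplifying assumptions: split_into_chunks, CHUNK_LENGTH, the comma_token argument and the default backtrack of 20 are illustrative, it operates on a pre-tokenized list of token ids, and it ignores the weights, textual-inversion embeddings and padding that the real implementation in sd_hijack_clip.py also handles.

CHUNK_LENGTH = 75  # same chunk size the wrapper uses for CLIP


def split_into_chunks(tokens, comma_token, comma_padding_backtrack=20):
    """Split token ids into chunks of at most CHUNK_LENGTH, preferring to break after a comma."""
    chunks = []
    chunk = []
    last_comma = -1

    def next_chunk():
        nonlocal chunk, last_comma
        chunks.append(chunk)
        chunk = []
        last_comma = -1

    for token in tokens:
        if token == comma_token:
            last_comma = len(chunk)

        # The chunk is full and the current token is not a comma: if a comma sits
        # close enough behind us, move everything after it into the next chunk
        # (the comma_padding_backtrack behaviour described in the comments).
        elif (comma_padding_backtrack != 0 and len(chunk) == CHUNK_LENGTH
                and last_comma != -1
                and len(chunk) - last_comma <= comma_padding_backtrack):
            break_location = last_comma + 1
            relocated = chunk[break_location:]
            chunk = chunk[:break_location]
            next_chunk()
            chunk = relocated

        if len(chunk) == CHUNK_LENGTH:
            next_chunk()

        chunk.append(token)

    # The "empty inputs" fix: always emit at least one chunk, even for an empty prompt.
    if len(chunk) > 0 or len(chunks) == 0:
        next_chunk()

    return chunks

With these assumptions, split_into_chunks(list(range(200)), comma_token=-1) produces chunks of 75, 75 and 50 tokens, while split_into_chunks([], comma_token=-1) still produces a single empty chunk, which is the property the one-line change above restores for empty prompts.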