diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 2 | ||||
-rw-r--r-- | modules/api/models.py | 3 | ||||
-rw-r--r-- | modules/hashes.py | 84 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 44 | ||||
-rw-r--r-- | modules/processing.py | 2 | ||||
-rw-r--r-- | modules/sd_models.py | 121 | ||||
-rw-r--r-- | modules/sd_samplers.py | 15 | ||||
-rw-r--r-- | modules/sd_vae.py | 194 | ||||
-rw-r--r-- | modules/shared.py | 21 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 6 | ||||
-rw-r--r-- | modules/ui.py | 44 | ||||
-rw-r--r-- | modules/ui_progress.py | 2 |
12 files changed, 339 insertions, 199 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 5767ba90..9814bbc2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -371,7 +371,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index c78095ca..1eb1fcf1 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -224,7 +224,8 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") - hash: str = Field(title="Hash") + hash: Optional[str] = Field(title="Short hash") + sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") config: str = Field(title="Config file") diff --git a/modules/hashes.py b/modules/hashes.py new file mode 100644 index 00000000..14231771 --- /dev/null +++ b/modules/hashes.py @@ -0,0 +1,84 @@ +import hashlib
+import json
+import os.path
+
+import filelock
+
+
+cache_filename = "cache.json"
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256_from_cache(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": os.path.getmtime(filename),
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 83cbb4f0..3aebefa8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -225,7 +225,7 @@ class Hypernetwork: torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['hash'] = self.shorthash()
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
@@ -237,32 +237,33 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- print(self.layer_structure)
- optional_info = state_dict.get('optional_info', None)
- if optional_info is not None:
- print(f"INFO:\n {optional_info}\n")
- self.optional_info = optional_info
+ self.optional_info = state_dict.get('optional_info', None)
self.activation_func = state_dict.get('activation_func', None)
- print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
- print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
- print(f"Layer norm is set to {self.add_layer_norm}")
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
- print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
if self.dropout_structure is None:
- print("Using previous dropout structure")
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- print(f"Dropout structure is set to {self.dropout_structure}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
- if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
@@ -289,6 +290,11 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10]
+
def list_hypernetworks(path):
res = {}
@@ -296,7 +302,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
@@ -509,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -737,7 +743,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
diff --git a/modules/processing.py b/modules/processing.py index ae04cab7..849f6b19 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,58 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
+checkpoint_alisases = {}
checkpoints_loaded = collections.OrderedDict()
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.filename = filename
+ abspath = os.path.abspath(filename)
+
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
+ else:
+ name = os.path.basename(filename)
+
+ if name.startswith("\\") or name.startswith("/"):
+ name = name[1:]
+
+ self.title = name
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ self.hash = model_hash(filename)
+
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for id in self.ids:
+ checkpoint_alisases[id] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10]
+
+ if self.shorthash not in self.ids:
+ self.ids += [self.shorthash, self.sha256]
+ self.register()
+
+ return self.shorthash
+
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +84,14 @@ def setup_model(): enable_midas_autodownload()
-def checkpoint_tiles():
- convert = lambda name: int(name) if name.isdigit() else name.lower()
- alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def checkpoint_tiles():
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +107,38 @@ def find_checkpoint_config(info): def list_models():
checkpoints_list.clear()
+ checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
- def modeltitle(path, shorthash):
- abspath = os.path.abspath(path)
-
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
- return f'{name} [{shorthash}]', shortname
-
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
- h = model_hash(cmd_ckpt)
- title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.data['sd_model_checkpoint'] = title
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
+ checkpoint_info.register()
+
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
for filename in model_list:
- h = model_hash(filename)
- title, short_model_name = modeltitle(filename, h)
+ checkpoint_info = CheckpointInfo(filename)
+ checkpoint_info.register()
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+def get_closet_checkpoint_match(search_string):
+ checkpoint_info = checkpoint_alisases.get(search_string, None)
+ if checkpoint_info is not None:
+ return checkpoint_info
+
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
-def get_closet_checkpoint_match(searchString):
- applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable) > 0:
- return applicable[0]
return None
def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +154,7 @@ def model_hash(filename): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +224,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
- checkpoint_file = checkpoint_info.filename
- sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +235,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- sd = read_state_dict(checkpoint_file)
+ sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,15 +269,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- sd_vae.load_vae(model, vae_file)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
def enable_midas_autodownload():
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 01221b89..7616fded 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None): def store_latent(decoded):
state.current_latent = decoded
- if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
shared.state.current_image = sample_to_image(decoded)
@@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
@@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-
+
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ if opts.live_preview_content == "Prompt":
+ store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ store_latent(x_out[-uncond.shape[0]:])
+
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
@@ -423,7 +427,8 @@ class KDiffusionSampler: def callback_state(self, d):
step = d['i']
latent = d["denoised"]
- store_latent(latent)
+ if opts.live_preview_content == "Combined":
+ store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..add5cecf 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + for filepath in candidates: name = get_filename(filepath) - res[name] = filepath - vae_list.clear() - vae_list.extend(default_vae_list) - vae_list.extend(list(res.keys())) - vae_dict.clear() - vae_dict.update(res) - vae_dict.update(default_vae_dict) - return vae_list - - -def get_vae_from_settings(vae_file="auto"): - # else, we load from settings, if not set to be default - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print(f"Selected VAE doesn't exist: {vae_file}") - return vae_file - - -def resolve_vae(checkpoint_file=None, vae_file="auto"): - global first_load, vae_dict, vae_list - - # if vae_file argument is provided, it takes priority, but not saved - if vae_file and vae_file not in default_vae_list: - if not os.path.isfile(vae_file): - print(f"VAE provided as function argument doesn't exist: {vae_file}") - vae_file = "auto" - # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported - if first_load and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - shared.opts.data['sd_vae'] = get_filename(vae_file) - else: - print(f"VAE provided as command line argument doesn't exist: {vae_file}") - # fallback to selector in settings, if vae selector not set to act as default fallback - if not shared.opts.sd_vae_as_default: - vae_file = get_vae_from_settings(vae_file) - # vae-path cmd arg takes priority for auto - if vae_file == "auto" and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - print(f"Using VAE provided as command line argument: {vae_file}") - # if still not found, try look for ".vae.pt" beside model - model_path = os.path.splitext(checkpoint_file)[0] - if vae_file == "auto": - vae_file_try = model_path + ".vae.pt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.ckpt" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.ckpt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.safetensors" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.safetensors" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # No more fallbacks for auto - if vae_file == "auto": - vae_file = None - # Last check, just because - if vae_file and not os.path.exists(vae_file): - vae_file = None - - return vae_file - - -def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if shared.opts.sd_vae != "Automatic": + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -165,12 +120,12 @@ def load_vae(model, vae_file=None): if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache - print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) _load_vae_dict(model, checkpoints_loaded[vae_file]) else: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) @@ -191,14 +146,12 @@ def load_vae(model, vae_file=None): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file - vae_list.append(vae_opt) + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file - first_load = False - # don't call this from outside def _load_vae_dict(model, vae_dict_1): @@ -211,7 +164,10 @@ def clear_loaded_vae(): loaded_vae_file = None -def reload_vae_weights(sd_model=None, vae_file="auto"): +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack if not sd_model: @@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_info = sd_model.sd_checkpoint_info checkpoint_file = checkpoint_info.filename - vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" if loaded_vae_file == vae_file: return @@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file) + load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) @@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("VAE Weights loaded.") + print("VAE weights loaded.") return sd_model diff --git a/modules/shared.py b/modules/shared.py index b90ded52..9756adea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -176,7 +176,7 @@ class State: self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
@@ -224,7 +224,7 @@ class State: if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
self.do_set_current_image()
def do_set_current_image(self):
@@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -382,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
@@ -422,13 +423,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -443,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), { 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
+options_templates.update(options_section(('ui', "Live previews"), {
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+}))
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -457,6 +463,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6939efcc..63935878 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
@@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
- embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint = checkpoint.shorthash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
diff --git a/modules/ui.py b/modules/ui.py index e86a624b..2425c66f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,19 +795,39 @@ def create_ui(): with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
+ copy_image_buttons = []
+ copy_image_destinations = {}
+
+ def add_copy_image_controls(tab_name, elem):
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ if name == tab_name:
+ gr.Button(title, interactive=False)
+ copy_image_destinations[name] = elem
+ continue
+
+ button = gr.Button(title)
+ copy_image_buttons.append((button, name, elem))
+
with gr.Tabs(elem_id="mode_img2img"):
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
@@ -824,10 +844,29 @@ def create_ui(): with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
+ gr.HTML(f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
+ inputs=[],
+ outputs=[],
+ )
+
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
@@ -856,6 +895,7 @@ def create_ui(): outputs=[inpaint_controls, mask_alpha],
)
+
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
@@ -1841,4 +1881,6 @@ xformers: {xformers_version} gradio: {gr.__version__}
•
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+ •
+checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 592fda55..7cd312e4 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -52,7 +52,7 @@ def check_progress_call(id_part): image = gr.update(visible=False)
preview_visibility = gr.update(visible=False)
- if opts.show_progress_every_n_steps != 0:
+ if opts.live_previews_enable:
shared.state.set_current_image()
image = shared.state.current_image
|