diff options
-rw-r--r-- | modules/images.py | 3 | ||||
-rw-r--r-- | modules/memmon.py | 3 | ||||
-rw-r--r-- | modules/processing.py | 7 | ||||
-rw-r--r-- | modules/sd_hijack_inpainting.py | 2 | ||||
-rw-r--r-- | modules/sd_vae.py | 31 | ||||
-rw-r--r-- | modules/shared.py | 3 |
6 files changed, 38 insertions, 11 deletions
diff --git a/modules/images.py b/modules/images.py index 31d4528d..962a955d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -525,6 +525,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image_to_save.mode == 'RGBA':
+ image_to_save = image_to_save.convert("RGB")
+
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
diff --git a/modules/memmon.py b/modules/memmon.py index 9fb9b687..a7060f58 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread): def read(self): if not self.disabled: free, total = torch.cuda.mem_get_info() + self.data["free"] = free self.data["total"] = total torch_stats = torch.cuda.memory_stats(self.device) + self.data["active"] = torch_stats["active.all.current"] self.data["active_peak"] = torch_stats["active_bytes.all.peak"] + self.data["reserved"] = torch_stats["reserved_bytes.all.current"] self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] self.data["system_peak"] = total - self.data["min_free"] diff --git a/modules/processing.py b/modules/processing.py index 4a406084..0a9a8f95 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -338,13 +338,14 @@ def slerp(val, low, high): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -384,8 +385,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
- if opts.eta_noise_seed_delta > 0:
- torch.manual_seed(seed + opts.eta_noise_seed_delta)
+ if eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index bb5499b3..06b75772 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -178,7 +178,7 @@ def sample_plms(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') + # print(f'Data shape for PLMS sampling is {size}') # remove unnecessary message samples, intermediates = self.plms_sampling(conditioning, size, callback=callback, diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 3856418e..ac71d62d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,6 @@ import torch import os +import collections from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path @@ -30,6 +31,7 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -149,13 +151,30 @@ def load_vae(model, vae_file=None): global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + if vae_file: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") - store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - _load_vae_dict(model, vae_dict_1) + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" + print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..c494a3b9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,7 +168,7 @@ class State: def dict(self):
obj = {
"skipped": self.skipped,
- "interrupted": self.skipped,
+ "interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_no": self.job_no,
@@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|