From cc92dc1f8d73dd4d574c4c8ccab78b7fc61e440b Mon Sep 17 00:00:00 2001 From: ssysm Date: Sun, 9 Oct 2022 23:17:29 -0400 Subject: add vae path args --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..b6979432 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") -- cgit v1.2.3 From 7349088d32b080f64058b6e5de5f0380a71ecd09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 16:11:14 +0300 Subject: --no-half-vae --- modules/devices.py | 6 +++++- modules/processing.py | 11 +++++++++-- modules/sd_models.py | 3 +++ modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 + 5 files changed, 20 insertions(+), 5 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/devices.py b/modules/devices.py index 0158b11f..03ef58f1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32") device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 +dtype_vae = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. @@ -59,9 +60,12 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) -def autocast(): +def autocast(disable=False): from modules import shared + if disable: + return contextlib.nullcontext() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 94d2dd62..ec8651ae 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def decode_first_stage(model, x): + with devices.autocast(disable=x.dtype == devices.dtype_vae): + x = model.decode_first_stage(x) + + return x + + def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) @@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: samples_ddim = samples_ddim.to(devices.dtype) - x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) + x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim @@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scale_latent: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") else: - decoded_samples = self.sd_model.decode_first_stage(samples) + decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") diff --git a/modules/sd_models.py b/modules/sd_models.py index e63d3c29..2cdcd84f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info): model.half() devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): @@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6e743f7e..d168b938 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser +from modules import prompt_parser, devices, processing from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None): def sample_to_image(samples): - x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] + x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/modules/shared.py b/modules/shared.py index 1995a99a..5dfc344c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -- cgit v1.2.3 From af62ad4d25dcd0454944368f4925d83101cdedbc Mon Sep 17 00:00:00 2001 From: ssysm Date: Mon, 10 Oct 2022 13:25:28 -0400 Subject: change vae loading method --- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b0e1d8bd..7a42d924 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -150,9 +150,16 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Found VAE Weights: {vae_file}") + elif shared.cmd_opts.vae_path != None: + vae_file = shared.cmd_opts.vae_path + print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') + else: + print("No VAE found for inside the model folder. Passing.") + if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} -- cgit v1.2.3 From 727e4d108674dc2813507e2a973a733ef21e8d53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:46:55 +0300 Subject: no to different messages plus fix using != to compare to None --- modules/sd_models.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4c06051e..0a55b4c3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -152,15 +152,12 @@ def load_model_weights(model, checkpoint_info): devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - if os.path.exists(vae_file): - print(f"Found VAE Weights: {vae_file}") - elif shared.cmd_opts.vae_path != None: + + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: vae_file = shared.cmd_opts.vae_path - print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') - else: - print("No VAE found for inside the model folder. Passing.") if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} -- cgit v1.2.3