From b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:24:54 +0300 Subject: add XL support for live previews: approx and TAESD --- modules/sd_vae_approx.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) (limited to 'modules/sd_vae_approx.py') diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f00468..b348f3ae 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch from torch import nn -from modules import devices, paths +from modules import devices, paths, shared -sd_vae_approx_model = None +sd_vae_approx_models = {} class VAEApprox(nn.Module): @@ -31,19 +31,34 @@ class VAEApprox(nn.Module): return x +def download_model(model_path, model_url): + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading VAEApprox model to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + def model(): - global sd_vae_approx_model + model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) - if sd_vae_approx_model is None: - model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") - sd_vae_approx_model = VAEApprox() + if loaded_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) if not os.path.exists(model_path): - model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") - sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - sd_vae_approx_model.eval() - sd_vae_approx_model.to(devices.device, devices.dtype) + model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name) + + if not os.path.exists(model_path): + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) + download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) + + loaded_model = VAEApprox() + loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_approx_models[model_name] = loaded_model - return sd_vae_approx_model + return loaded_model def cheap_approximation(sample): -- cgit v1.2.3