diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-13 14:24:54 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-13 14:24:54 +0000 |
commit | b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b (patch) | |
tree | 4006aa0ac4d6629edd8b381b664dd43d6665c4ea /modules/sd_vae_taesd.py | |
parent | 6f23da603d3cbba82262a3c62cc44c8d5cb9e6db (diff) | |
download | stable-diffusion-webui-gfx803-b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b.tar.gz stable-diffusion-webui-gfx803-b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b.tar.bz2 stable-diffusion-webui-gfx803-b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b.zip |
add XL support for live previews: approx and TAESD
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e8..5bf7c76e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder |