diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-05-17 06:24:01 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-17 06:24:01 +0000 |
commit | 56a2672831751480f94a018f861f0143a8234ae8 (patch) | |
tree | b7b4a37178c6a4945a748b9c94b81c259e4315b8 /modules/sd_vae_taesd.py | |
parent | b217ebc49000b41baab3094dbc8caaf33eaf5579 (diff) | |
download | stable-diffusion-webui-gfx803-56a2672831751480f94a018f861f0143a8234ae8.tar.gz stable-diffusion-webui-gfx803-56a2672831751480f94a018f861f0143a8234ae8.tar.bz2 stable-diffusion-webui-gfx803-56a2672831751480f94a018f861f0143a8234ae8.zip |
return live preview defaults to how they were
only download TAESD model when it's needed
return calculations in single_sample_to_image to just if/elif/elif blocks
keep taesd model in its own directory
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 927a7298..d23812ef 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -61,16 +61,28 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def decode(): +def download_model(model_path): + model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' + + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading TAESD decoder to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + +def model(): global sd_vae_taesd if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") + download_model(model_path) + if os.path.exists(model_path): sd_vae_taesd = TAESD(model_path) sd_vae_taesd.eval() sd_vae_taesd.to(devices.device, devices.dtype) else: - raise FileNotFoundError('Tiny AE model not found') + raise FileNotFoundError('TAESD model not found') return sd_vae_taesd.decoder |