diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-25 05:18:02 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-25 05:18:02 +0000 |
commit | a3ddf464a2ed24c999f67ddfef7969f8291567be (patch) | |
tree | cf70006b4d1d6df1f42ea944416b1034ae32a92b /modules/sd_vae_taesd.py | |
parent | f865d3e11647dfd6c7b2cdf90dde24680e58acd8 (diff) | |
parent | 2c11e9009ea18bab4ce2963d44db0c6fd3227370 (diff) | |
download | stable-diffusion-webui-gfx803-a3ddf464a2ed24c999f67ddfef7969f8291567be.tar.gz stable-diffusion-webui-gfx803-a3ddf464a2ed24c999f67ddfef7969f8291567be.tar.bz2 stable-diffusion-webui-gfx803-a3ddf464a2ed24c999f67ddfef7969f8291567be.zip |
Merge branch 'release_candidate'
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e8..5bf7c76e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder |