diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-24 08:09:04 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-24 08:09:04 +0000 |
commit | 189229bbf9276fb73e48c783856b02fc57ab5c9b (patch) | |
tree | 728b1ab97fec6d18a1ec687ba552ca83b0dcf109 /modules/sd_vae_taesd.py | |
parent | 31f2be3dcedf85c036c5f784c640208d122b62ed (diff) | |
parent | b6c02174050b2c5dd98bf24c797e85ff269516f5 (diff) | |
download | stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.tar.gz stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.tar.bz2 stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.zip |
Merge branch 'dev' into release_candidate
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 52 |
1 files changed, 44 insertions, 8 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5bf7c76e..808eb362 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -44,7 +44,17 @@ def decoder(): ) -class TAESD(nn.Module): +def encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + + +class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 @@ -55,21 +65,28 @@ class TAESD(nn.Module): self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - @staticmethod - def unscale_latents(x): - """[0, 1] -> raw latents""" - return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + +class TAESDEncoder(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = encoder() + self.encoder.load_state_dict( + torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) - print(f'Downloading TAESD decoder to: {model_path}') + print(f'Downloading TAESD model to: {model_path}') torch.hub.download_url_to_file(model_url, model_path) -def model(): +def decoder_model(): model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name) @@ -78,7 +95,7 @@ def model(): download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - loaded_model = TAESD(model_path) + loaded_model = TAESDDecoder(model_path) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) sd_vae_taesd_models[model_name] = loaded_model @@ -86,3 +103,22 @@ def model(): raise FileNotFoundError('TAESD model not found') return loaded_model.decoder + + +def encoder_model(): + model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) + + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) + + if os.path.exists(model_path): + loaded_model = TAESDEncoder(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model + else: + raise FileNotFoundError('TAESD model not found') + + return loaded_model.encoder |