diff options
author | Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> | 2023-05-14 04:42:44 +0000 |
---|---|---|
committer | Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> | 2023-05-14 06:06:01 +0000 |
commit | e14b586d0494d6c5cc3cbc45b5fa00c03d052443 (patch) | |
tree | 807b3e771ef465654b672956d09d94af525d14ab /modules/sd_vae_taesd.py | |
parent | b08500cec8a791ef20082628b49b17df833f5dda (diff) | |
download | stable-diffusion-webui-gfx803-e14b586d0494d6c5cc3cbc45b5fa00c03d052443.tar.gz stable-diffusion-webui-gfx803-e14b586d0494d6c5cc3cbc45b5fa00c03d052443.tar.bz2 stable-diffusion-webui-gfx803-e14b586d0494d6c5cc3cbc45b5fa00c03d052443.zip |
Add Tiny AE live preview
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py new file mode 100644 index 00000000..ccc97959 --- /dev/null +++ b/modules/sd_vae_taesd.py @@ -0,0 +1,76 @@ +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) + +https://github.com/madebyollin/taesd +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal + +sd_vae_taesd = None + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + @staticmethod + def forward(x): + return torch.tanh(x / 3) * 3 + + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +def decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): + latent_magnitude = 2 + latent_shift = 0.5 + + def __init__(self, decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.decoder = decoder() + self.decoder.load_state_dict( + torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + +def decode(): + global sd_vae_taesd + + if sd_vae_taesd is None: + model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + if os.path.exists(model_path): + sd_vae_taesd = TAESD(model_path) + sd_vae_taesd.eval() + sd_vae_taesd.to(devices.device, devices.dtype) + else: + raise FileNotFoundError('Tiny AE mdoel not found') + + return sd_vae_taesd.decoder |