path: root/modules/sd_vae_taesd.py
author     Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>  2023-05-14 04:42:44 +0000
committer  Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>  2023-05-14 06:06:01 +0000
commit     e14b586d0494d6c5cc3cbc45b5fa00c03d052443 (patch)
tree       807b3e771ef465654b672956d09d94af525d14ab /modules/sd_vae_taesd.py
parent     b08500cec8a791ef20082628b49b17df833f5dda (diff)
Add Tiny AE live preview
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r--  modules/sd_vae_taesd.py  76
1 file changed, 76 insertions, 0 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..ccc97959
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,76 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+ @staticmethod
+ def forward(x):
+ return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.fuse = nn.ReLU()
+
+ def forward(self, x):
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+ return nn.Sequential(
+ Clamp(), conv(4, 64), nn.ReLU(),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+
+class TAESD(nn.Module):
+ latent_magnitude = 2
+ latent_shift = 0.5
+
+ def __init__(self, decoder_path="taesd_decoder.pth"):
+        """Initialize pretrained TAESD decoder from the given checkpoint."""
+ super().__init__()
+ self.decoder = decoder()
+ self.decoder.load_state_dict(
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+ @staticmethod
+ def unscale_latents(x):
+ """[0, 1] -> raw latents"""
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def decode():
+ global sd_vae_taesd
+
+ if sd_vae_taesd is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
+ if os.path.exists(model_path):
+ sd_vae_taesd = TAESD(model_path)
+ sd_vae_taesd.eval()
+ sd_vae_taesd.to(devices.device, devices.dtype)
+ else:
+            raise FileNotFoundError('Tiny AE model not found')
+
+ return sd_vae_taesd.decoder
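
For context, a minimal usage sketch of the new module (not part of the commit): it assumes the webui environment is importable and that models/VAE-approx/taesd_decoder.pth exists; the random `latent` tensor, the numpy/PIL conversion, and the "preview.png" filename are illustrative stand-ins. The decoder returned by decode() maps a 4-channel SD latent to a 3-channel image whose values land roughly in [0, 1].

import numpy as np
import torch
from PIL import Image

from modules import devices, sd_vae_taesd

# Illustrative stand-in for a sampler latent of a 512x512 image
# (4 channels at 1/8 resolution).
latent = torch.randn(1, 4, 64, 64)

with torch.no_grad():
    # decode() lazily loads taesd_decoder.pth on first call and returns the decoder network.
    decoder = sd_vae_taesd.decode()
    sample = decoder(latent.to(devices.device, devices.dtype))[0]

# Clamp to [0, 1] and convert CHW float -> HWC uint8 for a quick preview image.
sample = sample.clamp(0, 1).float().cpu().numpy()
Image.fromarray((255.0 * np.moveaxis(sample, 0, 2)).astype(np.uint8)).save("preview.png")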