author | Roman Beltiukov <maybe.hello.world@gmail.com> | 2023-05-25 22:10:10 +0000
committer | GitHub <noreply@github.com> | 2023-05-25 22:10:10 +0000
commit | b2530c965c2afd5512c5f9020251fd4be8f067e5 (patch)
tree | 0c1620e00ac4eddea514706a5c3bf3e03bd46c70 /modules/sd_vae_taesd.py
parent | 09d9c3d287ee4543d285e0fde8b81603c9751a7e (diff)
parent | a6e653be26cc05f4438145fa0082816e9fbbf5fc (diff)
Merge branch 'dev' into master
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r-- | modules/sd_vae_taesd.py | 88
1 file changed, 88 insertions, 0 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..5e8496e8
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,88 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+    @staticmethod
+    def forward(x):
+        return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+    def __init__(self, n_in, n_out):
+        super().__init__()
+        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+        self.fuse = nn.ReLU()
+
+    def forward(self, x):
+        return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+    return nn.Sequential(
+        Clamp(), conv(4, 64), nn.ReLU(),
+        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+        Block(64, 64), conv(64, 3),
+    )
+
+
+class TAESD(nn.Module):
+    latent_magnitude = 3
+    latent_shift = 0.5
+
+    def __init__(self, decoder_path="taesd_decoder.pth"):
+        """Initialize pretrained TAESD on the given device from the given checkpoints."""
+        super().__init__()
+        self.decoder = decoder()
+        self.decoder.load_state_dict(
+            torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+    @staticmethod
+    def unscale_latents(x):
+        """[0, 1] -> raw latents"""
+        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def download_model(model_path):
+    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+
+    if not os.path.exists(model_path):
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+        print(f'Downloading TAESD decoder to: {model_path}')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
+def model():
+    global sd_vae_taesd
+
+    if sd_vae_taesd is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
+        download_model(model_path)
+
+        if os.path.exists(model_path):
+            sd_vae_taesd = TAESD(model_path)
+            sd_vae_taesd.eval()
+            sd_vae_taesd.to(devices.device, devices.dtype)
+        else:
+            raise FileNotFoundError('TAESD model not found')
+
+    return sd_vae_taesd.decoder
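
For context (not part of the commit): a minimal sketch of how the module() helper added above could be called to decode a latent into a preview image. The random latent, its shape, and the absence of any extra latent scaling are illustrative assumptions here, not a reproduction of the webui's live-preview code.

import torch
from modules import devices, sd_vae_taesd

# Illustrative stand-in for a real SD latent from the sampler: (batch, 4, H/8, W/8).
latent = torch.randn(1, 4, 64, 64, device=devices.device, dtype=devices.dtype)

decoder = sd_vae_taesd.model()      # downloads taesd_decoder.pth on first call, then caches it
with torch.no_grad():
    image = decoder(latent)         # (1, 3, 512, 512); upstream taesd decodes to values roughly in [0, 1]
image = image.clamp(0, 1)           # clamp before converting to a preview image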