diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-04 14:40:19 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-04 14:40:19 +0000 |
commit | da5c1e8a732c173ed8ccda9fa32f9a194ff91ab6 (patch) | |
tree | a2eec9c47e820e7ab351337f73c99d874b4b904f /modules/sd_vae_approx.py | |
parent | cffc240a7327ae60671ff533469fc4ed4bf605de (diff) | |
parent | 47df0849019abac6722c49512f4dd2285bff5b7d (diff) | |
download | stable-diffusion-webui-gfx803-da5c1e8a732c173ed8ccda9fa32f9a194ff91ab6.tar.gz stable-diffusion-webui-gfx803-da5c1e8a732c173ed8ccda9fa32f9a194ff91ab6.tar.bz2 stable-diffusion-webui-gfx803-da5c1e8a732c173ed8ccda9fa32f9a194ff91ab6.zip |
Merge branch 'master' into inpaint_textual_inversion
Diffstat (limited to 'modules/sd_vae_approx.py')
-rw-r--r-- | modules/sd_vae_approx.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py new file mode 100644 index 00000000..0a58542d --- /dev/null +++ b/modules/sd_vae_approx.py @@ -0,0 +1,58 @@ +import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+ def __init__(self):
+ super(VAEApprox, self).__init__()
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+ def forward(self, x):
+ extra = 11
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+ x = layer(x)
+ x = nn.functional.leaky_relu(x, 0.1)
+
+ return x
+
+
+def model():
+ global sd_vae_approx_model
+
+ if sd_vae_approx_model is None:
+ sd_vae_approx_model = VAEApprox()
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.eval()
+ sd_vae_approx_model.to(devices.device, devices.dtype)
+
+ return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+ coefs = torch.tensor([
+ [0.298, 0.207, 0.208],
+ [0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]).to(sample.device)
+
+ x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+ return x_sample
|