diff options
-rw-r--r-- | modules/sd_samplers_common.py | 21 | ||||
-rw-r--r-- | modules/sd_vae_taesd.py | 76 | ||||
-rw-r--r-- | modules/shared.py | 2 | ||||
-rw-r--r-- | webui.py | 11 |
4 files changed, 101 insertions, 9 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index bc074238..d3dc130c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
- if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
- elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ if approximation == 1:
+ x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
+ x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ if approximation == 3:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 2:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py new file mode 100644 index 00000000..ccc97959 --- /dev/null +++ b/modules/sd_vae_taesd.py @@ -0,0 +1,76 @@ +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) + +https://github.com/madebyollin/taesd +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal + +sd_vae_taesd = None + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + @staticmethod + def forward(x): + return torch.tanh(x / 3) * 3 + + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +def decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): + latent_magnitude = 2 + latent_shift = 0.5 + + def __init__(self, decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.decoder = decoder() + self.decoder.load_state_dict( + torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + +def decode(): + global sd_vae_taesd + + if sd_vae_taesd is None: + model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + if os.path.exists(model_path): + sd_vae_taesd = TAESD(model_path) + sd_vae_taesd.eval() + sd_vae_taesd.to(devices.device, devices.dtype) + else: + raise FileNotFoundError('Tiny AE mdoel not found') + + return sd_vae_taesd.decoder diff --git a/modules/shared.py b/modules/shared.py index 4631965b..6760a900 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
}))
@@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check. """.strip())
+def check_taesd():
+ from modules.paths_internal import models_path
+
+ model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+ model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth")
+ if not os.path.exists(model_path):
+ print('download taesd model')
+ torch.hub.download_url_to_file(model_url, os.path.dirname(model_path))
+
+
def initialize():
fix_asyncio_event_loop_policy()
check_versions()
+ check_taesd()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
|