From e14b586d0494d6c5cc3cbc45b5fa00c03d052443 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 14 May 2023 12:42:44 +0800 Subject: Add Tiny AE live preview --- modules/sd_vae_taesd.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 modules/sd_vae_taesd.py (limited to 'modules/sd_vae_taesd.py') diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py new file mode 100644 index 00000000..ccc97959 --- /dev/null +++ b/modules/sd_vae_taesd.py @@ -0,0 +1,76 @@ +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) + +https://github.com/madebyollin/taesd +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal + +sd_vae_taesd = None + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + @staticmethod + def forward(x): + return torch.tanh(x / 3) * 3 + + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +def decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): + latent_magnitude = 2 + latent_shift = 0.5 + + def __init__(self, decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.decoder = decoder() + self.decoder.load_state_dict( + torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + +def decode(): + global sd_vae_taesd + + if sd_vae_taesd is None: + model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + if os.path.exists(model_path): + sd_vae_taesd = TAESD(model_path) + sd_vae_taesd.eval() + sd_vae_taesd.to(devices.device, devices.dtype) + else: + raise FileNotFoundError('Tiny AE mdoel not found') + + return sd_vae_taesd.decoder -- cgit v1.2.3 From 742da3193290f5692901c4c614c98bec291163f2 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Mon, 15 May 2023 03:04:34 +0800 Subject: Minor changes --- modules/sd_vae_taesd.py | 2 +- webui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_vae_taesd.py') diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index ccc97959..927a7298 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -71,6 +71,6 @@ def decode(): sd_vae_taesd.eval() sd_vae_taesd.to(devices.device, devices.dtype) else: - raise FileNotFoundError('Tiny AE mdoel not found') + raise FileNotFoundError('Tiny AE model not found') return sd_vae_taesd.decoder diff --git a/webui.py b/webui.py index 0a928434..0d0816bc 100644 --- a/webui.py +++ b/webui.py @@ -151,7 +151,7 @@ def check_taesd(): model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") if not os.path.exists(model_path): print('download taesd model') - torch.hub.download_url_to_file(model_url, os.path.dirname(model_path)) + torch.hub.download_url_to_file(model_url, model_path) def initialize(): -- cgit v1.2.3 From 56a2672831751480f94a018f861f0143a8234ae8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 17 May 2023 09:24:01 +0300 Subject: return live preview defaults to how they were only download TAESD model when it's needed return calculations in single_sample_to_image to just if/elif/elif blocks keep taesd model in its own directory --- modules/sd_samplers_common.py | 29 +++++++++++++++-------------- modules/sd_vae_taesd.py | 18 +++++++++++++++--- modules/shared.py | 2 +- webui.py | 11 ----------- 4 files changed, 31 insertions(+), 29 deletions(-) (limited to 'modules/sd_vae_taesd.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index b1e8a780..20a9af20 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -22,28 +22,29 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3} +approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} def single_sample_to_image(sample, approximation=None): - if approximation is None or approximation not in approximation_indexes.keys(): - approximation = approximation_indexes.get(opts.show_progress_type, 1) - if approximation == 1: - x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) - x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1) + if approximation is None: + approximation = approximation_indexes.get(opts.show_progress_type, 0) + + if approximation == 2: + x_sample = sd_vae_approx.cheap_approximation(sample) + elif approximation == 1: + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + elif approximation == 3: + x_sample = sd_vae_taesd.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) # returns value in [-2, 2] + x_sample = x_sample * 0.5 else: - if approximation == 3: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 2: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) + return Image.fromarray(x_sample) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 927a7298..d23812ef 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -61,16 +61,28 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def decode(): +def download_model(model_path): + model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' + + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading TAESD decoder to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + +def model(): global sd_vae_taesd if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") + download_model(model_path) + if os.path.exists(model_path): sd_vae_taesd = TAESD(model_path) sd_vae_taesd.eval() sd_vae_taesd.to(devices.device, devices.dtype) else: - raise FileNotFoundError('Tiny AE model not found') + raise FileNotFoundError('TAESD model not found') return sd_vae_taesd.decoder diff --git a/modules/shared.py b/modules/shared.py index 6760a900..96036d38 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") })) diff --git a/webui.py b/webui.py index 0aa03ea8..727ebd31 100644 --- a/webui.py +++ b/webui.py @@ -144,21 +144,10 @@ Use --skip-version-check commandline argument to disable this check. """.strip()) -def check_taesd(): - from modules.paths_internal import models_path - - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") - if not os.path.exists(model_path): - print('From taesd repo download decoder model') - torch.hub.download_url_to_file(model_url, model_path) - - def initialize(): fix_asyncio_event_loop_policy() check_versions() - check_taesd() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) -- cgit v1.2.3 From 7a13a3f4ba86dc44fcf7d9944b179018744862f5 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Wed, 17 May 2023 17:39:07 +0800 Subject: TAESD fix --- modules/sd_samplers_common.py | 9 +++++---- modules/sd_vae_taesd.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'modules/sd_vae_taesd.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index ceda6a35..d99c933d 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -35,13 +35,14 @@ def single_sample_to_image(sample, approximation=None): elif approximation == 1: x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() elif approximation == 3: - x_sample = sd_vae_taesd.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) # returns value in [-2, 2] - x_sample = x_sample * 0.5 + x_sample = sample * 1.5 + x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() else: x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + if approximation != 3: + x_sample = (x_sample + 1.0) / 2.0 + x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index d23812ef..5e8496e8 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -45,7 +45,7 @@ def decoder(): class TAESD(nn.Module): - latent_magnitude = 2 + latent_magnitude = 3 latent_shift = 0.5 def __init__(self, decoder_path="taesd_decoder.pth"): -- cgit v1.2.3