diff options
Diffstat (limited to 'extensions-builtin/LDSR')
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 13 | ||||
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 29 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_autoencoder.py | 33 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 70 | ||||
-rw-r--r-- | extensions-builtin/LDSR/vqvae_quantize.py | 147 |
5 files changed, 216 insertions, 76 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index bc11cc6e..7f450086 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -88,7 +88,7 @@ class LDSR: x_t = None logs = None - for n in range(n_runs): + for _ in range(n_runs): if custom_shape is not None: x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) @@ -110,7 +110,6 @@ class LDSR: diffusion_steps = int(steps) eta = 1.0 - down_sample_method = 'Lanczos' gc.collect() if torch.cuda.is_available: @@ -131,11 +130,11 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] @@ -158,7 +157,7 @@ class LDSR: def get_cond(selected_path): - example = dict() + example = {} up_f = 4 c = selected_path.convert('RGB') c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) @@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s @torch.no_grad() def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() + log = {} z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, @@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: + except Exception: pass log["sample"] = x_sample diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index da19cff1..bd78dece 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,13 +1,11 @@ import os -import sys -import traceback - -from basicsr.utils.download_util import load_file_from_url +from modules.modelloader import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR -from modules import shared, script_callbacks -import sd_hijack_autoencoder, sd_hijack_ddpm_v1 +from modules import shared, script_callbacks, errors +import sd_hijack_autoencoder # noqa: F401 +import sd_hijack_ddpm_v1 # noqa: F401 class UpscalerLDSR(Upscaler): @@ -44,22 +42,17 @@ class UpscalerLDSR(Upscaler): if local_safetensors_path is not None and os.path.exists(local_safetensors_path): model = local_safetensors_path else: - model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True) - - yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True) + model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt") - try: - return LDSR(model, yaml) + yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - return None + return LDSR(model, yaml) def do_upscale(self, img, path): - ldsr = self.load_model(path) - if ldsr is None: - print("NO LDSR!") + try: + ldsr = self.load_model(path) + except Exception: + errors.report(f"Failed loading LDSR model {path}", exc_info=True) return img ddim_steps = shared.opts.ldsr_steps return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8e03c7f8..c29d274d 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma +from vqvae_quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -76,18 +81,19 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): @@ -165,7 +171,7 @@ class VQModel(pl.LightningModule): def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): @@ -232,7 +238,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -249,7 +255,8 @@ class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -264,7 +271,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): @@ -282,5 +289,5 @@ class VQModelInterface(VQModel): dec = self.decoder(quant) return dec -setattr(ldm.models.autoencoder, "VQModel", VQModel) -setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) +ldm.models.autoencoder.VQModel = VQModel +ldm.models.autoencoder.VQModelInterface = VQModelInterface diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 5c0488e5..04adc5eb 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) @@ -182,22 +182,22 @@ class DDPMV1(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): @@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) @@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, @@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): use_ddim = ddim_steps is not None - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. @@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] @@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): logs['bbox_image'] = cond_img return logs -setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) -setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) -setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) -setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) +ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1 +ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1 +ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1 +ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1 diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py new file mode 100644 index 00000000..dd14b8fd --- /dev/null +++ b/extensions-builtin/LDSR/vqvae_quantize.py @@ -0,0 +1,147 @@ +# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py, +# where the license is as follows: +# +# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE./ + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q |