diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-22 04:15:34 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-22 04:15:34 +0000 |
commit | 8137bdba61fd57cc1ddae801f6080d51e13d70c5 (patch) | |
tree | c5a02e9f9ae57c9f0ff8499379c6cc61a97c094e /extensions-builtin/LDSR/sd_hijack_autoencoder.py | |
parent | a862428902c4aecde8852761c3a4d95c196885cb (diff) | |
parent | 3366e494a1147e570d8527eea19da88edb3a1e0c (diff) | |
download | stable-diffusion-webui-gfx803-8137bdba61fd57cc1ddae801f6080d51e13d70c5.tar.gz stable-diffusion-webui-gfx803-8137bdba61fd57cc1ddae801f6080d51e13d70c5.tar.bz2 stable-diffusion-webui-gfx803-8137bdba61fd57cc1ddae801f6080d51e13d70c5.zip |
Merge branch 'dev' into text-drag-fix
Diffstat (limited to 'extensions-builtin/LDSR/sd_hijack_autoencoder.py')
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_autoencoder.py | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8e03c7f8..81c5101b 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -76,11 +81,11 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -165,7 +170,7 @@ class VQModel(pl.LightningModule): def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): @@ -232,7 +237,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -249,7 +254,8 @@ class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -264,7 +270,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): @@ -282,5 +288,5 @@ class VQModelInterface(VQModel): dec = self.decoder(quant) return dec -setattr(ldm.models.autoencoder, "VQModel", VQModel) -setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) +ldm.models.autoencoder.VQModel = VQModel +ldm.models.autoencoder.VQModelInterface = VQModelInterface |