aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/LDSR/sd_hijack_autoencoder.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-10 08:19:16 +0000
committerAUTOMATIC <16777216c@gmail.com>2023-05-10 08:19:16 +0000
commit550256db1ce18778a9d56ff343d844c61b9f9b83 (patch)
treea17e8fd9cb475381c361844970ba2d9111938b6d /extensions-builtin/LDSR/sd_hijack_autoencoder.py
parent028d3f6425d85f122027c127fba8bcbf4f66ee75 (diff)
downloadstable-diffusion-webui-gfx803-550256db1ce18778a9d56ff343d844c61b9f9b83.tar.gz
stable-diffusion-webui-gfx803-550256db1ce18778a9d56ff343d844c61b9f9b83.tar.bz2
stable-diffusion-webui-gfx803-550256db1ce18778a9d56ff343d844c61b9f9b83.zip
ruff manual fixes
Diffstat (limited to 'extensions-builtin/LDSR/sd_hijack_autoencoder.py')
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index f457ca93..8cc82d54 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list()):
+ def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):