aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-02-07 05:05:54 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-02-08 03:53:45 +0000
commit2016733814433ca2b69d10764bfa0ab4c7088782 (patch)
treece7bc91d6f8d19ca403d6f4ed633887c6d5e4132
parent4738486d8f528a98a525970ac06a109431fd7344 (diff)
downloadstable-diffusion-webui-gfx803-2016733814433ca2b69d10764bfa0ab4c7088782.tar.gz
stable-diffusion-webui-gfx803-2016733814433ca2b69d10764bfa0ab4c7088782.tar.bz2
stable-diffusion-webui-gfx803-2016733814433ca2b69d10764bfa0ab4c7088782.zip
Apply hijacks in ddpm_edit for upcast sampling
To avoid import errors, ddpm_edit hijacks are done after an instruct pix2pix model is loaded.
-rw-r--r--modules/sd_hijack.py3
-rw-r--r--modules/sd_hijack_unet.py11
2 files changed, 14 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 8fdc5990..fca418cd 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -104,6 +104,9 @@ class StableDiffusionModelHijack:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+ if m.cond_stage_key == "edit":
+ sd_hijack_unet.hijack_ddpm_edit()
+
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 45cf2b18..843ab66c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
@@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
else:
return torch.nn.GELU.forward(self, x)
+
+ddpm_edit_hijack = None
+def hijack_ddpm_edit():
+ global ddpm_edit_hijack
+ if not ddpm_edit_hijack:
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+
+
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)