diff options
author | InvincibleDude <81354513+InvincibleDude@users.noreply.github.com> | 2023-03-03 16:49:24 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-03 16:49:24 +0000 |
commit | e97b83bdbb852fd2c06ed2ec6c0f92d458e82245 (patch) | |
tree | 0f0b3779120d602d1d833a50e33c9f73b218f684 /modules/sd_hijack_unet.py | |
parent | 51f81efb02876d24c9e6d844e8c0cbd2384f6514 (diff) | |
parent | 0cc0ee1bcb4c24a8c9715f66cede06601bfc00c8 (diff) | |
download | stable-diffusion-webui-gfx803-e97b83bdbb852fd2c06ed2ec6c0f92d458e82245.tar.gz stable-diffusion-webui-gfx803-e97b83bdbb852fd2c06ed2ec6c0f92d458e82245.tar.bz2 stable-diffusion-webui-gfx803-e97b83bdbb852fd2c06ed2ec6c0f92d458e82245.zip |
Merge branch 'master' into improved-hr-conflict-test
Diffstat (limited to 'modules/sd_hijack_unet.py')
-rw-r--r-- | modules/sd_hijack_unet.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 45cf2b18..843ab66c 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
@@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): else:
return torch.nn.GELU.forward(self, x)
+
+ddpm_edit_hijack = None
+def hijack_ddpm_edit():
+ global ddpm_edit_hijack
+ if not ddpm_edit_hijack:
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+
+
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|