diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-28 15:44:36 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-28 15:44:36 +0000 |
commit | fecb990debdae2cc99b64808d22ba902e34e575b (patch) | |
tree | 5531204920675c2bbb9189e41cf331554b92f062 /modules/sd_hijack_unet.py | |
parent | 41e76d12093dd78315925f923bb1cc0f541e5f54 (diff) | |
parent | f9edd578e9e29d160e6d56038bb368dc49895d64 (diff) | |
download | stable-diffusion-webui-gfx803-fecb990debdae2cc99b64808d22ba902e34e575b.tar.gz stable-diffusion-webui-gfx803-fecb990debdae2cc99b64808d22ba902e34e575b.tar.bz2 stable-diffusion-webui-gfx803-fecb990debdae2cc99b64808d22ba902e34e575b.zip |
Merge pull request #7309 from brkirch/fix-embeddings
Fix embeddings, upscalers, and refactor `--upcast-sampling`
Diffstat (limited to 'modules/sd_hijack_unet.py')
-rw-r--r-- | modules/sd_hijack_unet.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index a6ee577c..45cf2b18 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+
+first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
+first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|