diff options
author | brkirch <brkirch@users.noreply.github.com> | 2023-01-25 04:51:45 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-25 06:13:02 +0000 |
commit | 84d9ce30cb427759547bc7876ed80ab91787d175 (patch) | |
tree | a87ca1a7094ca9b7af4e573a211b1dcf8146af67 /modules/sd_hijack_unet.py | |
parent | 48a15821de768fea76e66f26df83df3fddf18f4b (diff) | |
download | stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.tar.gz stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.tar.bz2 stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.zip |
Add option for float32 sampling with float16 UNet
This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half().
Diffstat (limited to 'modules/sd_hijack_unet.py')
-rw-r--r-- | modules/sd_hijack_unet.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
class TorchHijackForUnet:
@@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet()
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ with devices.autocast():
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
+class GELUHijack(torch.nn.GELU, torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ torch.nn.GELU.__init__(self, *args, **kwargs)
+ def forward(self, x):
+ if devices.unet_needs_upcast:
+ return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
+ else:
+ return torch.nn.GELU.forward(self, x)
+
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+if version.parse(torch.__version__) <= version.parse("1.13.1"):
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|