author     AUTOMATIC1111 <16777216c@gmail.com>   2023-01-25 16:12:29 +0000
committer  GitHub <noreply@github.com>           2023-01-25 16:12:29 +0000
commit     1574e967297586d013e4cfbb6628eae595c9fba2 (patch)
tree       99374009f63cf73cadc713b02db5fbba0701e516 /modules/sd_hijack_unet.py
parent     1982ef68900fe3c5eee704dfbda5416c1bb5470b (diff)
parent     e3b53fd295aca784253dfc8668ec87b537a72f43 (diff)
Merge pull request #6510 from brkirch/unet16-upcast-precision
Add upcast options, full precision sampling from float16 UNet and upcasting attention for inference using SD 2.1 models without --no-half
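The change keeps the UNet weights in float16 while letting the sampler work in float32: tensors are cast down to the UNet's dtype at its boundary, and the result is cast back up for the sampler. A minimal sketch of that boundary cast, assuming float16 UNet weights (the standalone function name is illustrative; the actual patch routes this through `devices.autocast()` and `devices.dtype_unet`, as shown in the diff below):

```python
import torch

# Illustrative boundary cast: the sampler keeps float32 tensors, the UNet
# keeps float16 weights; inputs are cast down just before the forward pass
# and the output is cast back up for the sampler.
def unet_forward_upcast(unet, x_noisy, t, dtype_unet=torch.float16):
    out = unet(x_noisy.to(dtype_unet), t.to(dtype_unet))
    return out.float()  # hand float32 back to the sampler
```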
Diffstat (limited to 'modules/sd_hijack_unet.py')
-rw-r--r--  modules/sd_hijack_unet.py  29
1 file changed, 29 insertions, 0 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 18daf8c1..88c94e54 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -1,4 +1,8 @@
 import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
 
 
 class TorchHijackForUnet:
@@ -28,3 +32,28 @@ class TorchHijackForUnet:
 
 
 th = TorchHijackForUnet()
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ with devices.autocast():
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
+class GELUHijack(torch.nn.GELU, torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ torch.nn.GELU.__init__(self, *args, **kwargs)
+ def forward(self, x):
+ if devices.unet_needs_upcast:
+ return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
+ else:
+ return torch.nn.GELU.forward(self, x)
+
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+if version.parse(torch.__version__) <= version.parse("1.13.1"):
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
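Each `CondFunc(path, sub_func, cond_func)` call above conditionally monkey-patches the callable at a dotted import path: when `cond_func` returns true, the call is routed through `sub_func`, which receives the original function as its first argument; otherwise the original runs untouched. A minimal standalone sketch of that pattern, assuming the patch target is reachable as an attribute of an object you already hold (the real helper in `modules/sd_hijack_utils.py` also resolves the dotted path itself):

```python
# Minimal sketch of the conditional monkey-patch pattern, not the actual
# CondFunc implementation. `owner` is the module or class holding `name`.
def patch_if(owner, name, sub_func, cond_func):
    orig_func = getattr(owner, name)

    def wrapper(*args, **kwargs):
        if cond_func(orig_func, *args, **kwargs):
            # Condition holds: call the replacement, handing it the
            # original function so it can delegate (as apply_model does).
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)  # condition false: untouched path

    setattr(owner, name, wrapper)
```

Gating every patch on `unet_needs_upcast` means the hijacks cost nothing when the UNet already runs in the sampling dtype, and the `version.parse(torch.__version__) <= version.parse("1.13.1")` guard restricts the GroupNorm32/GEGLU/GELU upcasts to older PyTorch builds.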