| author | Billy Cao <aliencaocao@gmail.com> | 2022-11-23 10:11:24 +0000 |
|---|---|---|
| committer | Billy Cao <aliencaocao@gmail.com> | 2022-11-23 10:11:24 +0000 |
| commit | adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 (patch) | |
| tree | 164da7276d0dcb00d3f6871c9099604a05151277 /modules/sd_hijack_optimizations.py | |
| parent | 828438b4a190759807f9054932cae3a8b880ddf1 (diff) | |
Patch UNet Forward to support resolutions that are not multiples of 64
Also modified the UI so it no longer steps in increments of 64
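Background for the 64-pixel constraint: SD v1 encodes images to a latent at 1/8 of the pixel resolution, and the UNet then halves the spatial size three more times, so any pixel dimension not divisible by 64 produces odd latent sizes whose skip connections no longer match on the way back up. A minimal sketch of the arithmetic (the 560-pixel height is an arbitrary example, not from the commit):

```python
import math

# A hypothetical 560-pixel-tall image in SD v1: the VAE encodes at 1/8
# scale, then the UNet halves the spatial size three times (stride-2,
# padding-1 convs round odd sizes up).
sizes = [560 // 8]                        # latent height: 70
for _ in range(3):
    sizes.append(math.ceil(sizes[-1] / 2))
print(sizes)                              # [70, 35, 18, 9]

# Each decoder upsample doubles the size exactly: 9 -> 18 matches the
# stored encoder feature, but 18 -> 36 then meets a feature of height 35.
# Only heights divisible by 64 keep every level even, which is why the
# UI previously stepped in increments of 64.
```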
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
| -rw-r--r-- | modules/sd_hijack_optimizations.py | 31 |

1 file changed, 31 insertions, 0 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..8cd4c954 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,7 @@ import importlib
 
 import torch
 from torch import einsum
+import torch.nn.functional as F
 
 from ldm.util import default
 from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork
+from ldm.modules.diffusionmodules.util import timestep_embedding
+
 
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def patched_unet_forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+    assert (y is not None) == (
+        self.num_classes is not None
+    ), "must specify y if and only if the model is class-conditional"
+    hs = []
+    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+    emb = self.time_embed(t_emb)
+
+    if self.num_classes is not None:
+        assert y.shape == (x.shape[0],)
+        emb = emb + self.label_emb(y)
+
+    h = x.type(self.dtype)
+    for module in self.input_blocks:
+        h = module(h, emb, context)
+        hs.append(h)
+    h = self.middle_block(h, emb, context)
+    for module in self.output_blocks:
+        if h.shape[-2:] != hs[-1].shape[-2:]:
+            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
+        h = torch.cat([h, hs.pop()], dim=1)
+        h = module(h, emb, context)
+    h = h.type(x.dtype)
+    if self.predict_codebook_ids:
+        return self.id_predictor(h)
+    else:
+        return self.out(h)
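The added `patched_unet_forward` is otherwise a transcription of ldm's `UNetModel.forward`; the one new step is the `F.interpolate` guard that nearest-neighbor resizes the decoder feature to the stored skip connection's spatial size before concatenation. A standalone sketch of just that step, with hypothetical tensor shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: a decoder feature after a 2x upsample (9 -> 18)
# meets an encoder skip feature that was stored at height 17.
h = torch.randn(1, 1280, 18, 18)
skip = torch.randn(1, 1280, 17, 18)

# The guard added by this commit: snap h to the skip's spatial size
# before concatenating along the channel dimension.
if h.shape[-2:] != skip.shape[-2:]:
    h = F.interpolate(h, skip.shape[-2:], mode="nearest")
h = torch.cat([h, skip], dim=1)
print(h.shape)  # torch.Size([1, 2560, 17, 18])
```

Since the mismatch is at most one pixel per level, nearest-neighbor resizing is effectively free and avoids blending feature values across positions.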