author     AUTOMATIC <16777216c@gmail.com>	2022-12-10 06:14:30 +0000
committer  AUTOMATIC <16777216c@gmail.com>	2022-12-10 06:14:45 +0000
commit     7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d (patch)
tree       1fbb4d81ced27863380f8fab80305519b8014e59
parent     2641d1b83bca6961ab1367fd9b67ebd8a94af35b (diff)
do not replace entire unet for the resolution hack
-rw-r--r--  modules/sd_hijack.py               |  5 +++--
-rw-r--r--  modules/sd_hijack_optimizations.py | 28 ----------------------------
-rw-r--r--  modules/sd_hijack_unet.py          | 30 ++++++++++++++++++++++++++++++
3 files changed, 33 insertions(+), 30 deletions(-)
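Context for the change: the "resolution hack" exists because the UNet's stride-2 downsampling rounds odd spatial sizes up, while its 2x upsampling doubles exactly, so for latents whose sides are not multiples of 8 (images not multiples of 64) the decoder features can come out larger than the encoder skip connections they are concatenated with. A minimal sketch of that size drift; the three-level layout and the 600-pixel example are illustrative, not taken from this commit:

import math

size = 75                       # latent side for a 600-pixel image (600 / 8)
encoder_sizes = [size]
for _ in range(3):              # three stride-2 downsamples in the SD UNet
    size = math.ceil(size / 2)  # stride-2 conv with padding rounds odd sizes up
    encoder_sizes.append(size)
print(encoder_sizes)            # [75, 38, 19, 10]

decoder_size = encoder_sizes[-1]
for skip in reversed(encoder_sizes[:-1]):
    decoder_size *= 2           # 2x nearest-neighbour upsample doubles exactly
    print(decoder_size, skip)   # 20 vs 19, 40 vs 38, 80 vs 75 -- mismatched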
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 92874a79..47dbc1b7 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
 from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
 ldm.modules.attention.print = lambda *args: None
 ldm.modules.diffusionmodules.model.print = lambda *args: None
 
+
 def apply_optimizations():
     undo_optimizations()
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
-    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
+    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 8cd4c954..85909eb9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape == (x.shape[0],)
-        emb = emb + self.label_emb(y)
-
-    h = x.type(self.dtype)
-    for module in self.input_blocks:
-        h = module(h, emb, context)
-        hs.append(h)
-    h = self.middle_block(h, emb, context)
-    for module in self.output_blocks:
-        if h.shape[-2:] != hs[-1].shape[-2:]:
-            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
-        h = torch.cat([h, hs.pop()], dim=1)
-        h = module(h, emb, context)
-    h = h.type(x.dtype)
-    if self.predict_codebook_ids:
-        return self.id_predictor(h)
-    else:
-        return self.out(h)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
new file mode 100644
index 00000000..1b9d7757
--- /dev/null
+++ b/modules/sd_hijack_unet.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+    """
+    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+    """
+
+    def __getattr__(self, item):
+        if item == 'cat':
+            return self.cat
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+    def cat(self, tensors, *args, **kwargs):
+        if len(tensors) == 2:
+            a, b = tensors
+            if a.shape[-2:] != b.shape[-2:]:
+                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+            tensors = (a, b)
+
+        return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
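For reference, a quick sketch of the new cat in action, assuming the module above is importable as modules.sd_hijack_unet; the shapes are illustrative:

import torch
from modules.sd_hijack_unet import th

a = torch.randn(1, 4, 20, 20)  # decoder feature after a 2x upsample
b = torch.randn(1, 4, 19, 19)  # stored encoder skip connection
out = th.cat((a, b), dim=1)    # a is resized to 19x19 before concatenation
print(out.shape)               # torch.Size([1, 8, 19, 19])

x = th.ones(2, 2)              # any other attribute falls through to torch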