From adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 Mon Sep 17 00:00:00 2001
From: Billy Cao
Date: Wed, 23 Nov 2022 18:11:24 +0800
Subject: Patch UNet Forward to support resolutions that are not multiples of 64

Also modified the UI to no longer step in 64
---
 modules/sd_hijack_optimizations.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..8cd4c954 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,7 @@ import importlib
 
 import torch
 from torch import einsum
+import torch.nn.functional as F
 
 from ldm.util import default
 from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
+from ldm.modules.diffusionmodules.util import timestep_embedding
+
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+    assert (y is not None) == (
+        self.num_classes is not None
+    ), "must specify y if and only if the model is class-conditional"
+    hs = []
+    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+    emb = self.time_embed(t_emb)
+
+    if self.num_classes is not None:
+        assert y.shape == (x.shape[0],)
+        emb = emb + self.label_emb(y)
+
+    h = x.type(self.dtype)
+    for module in self.input_blocks:
+        h = module(h, emb, context)
+        hs.append(h)
+    h = self.middle_block(h, emb, context)
+    for module in self.output_blocks:
+        if h.shape[-2:] != hs[-1].shape[-2:]:
+            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
+        h = torch.cat([h, hs.pop()], dim=1)
+        h = module(h, emb, context)
+    h = h.type(x.dtype)
+    if self.predict_codebook_ids:
+        return self.id_predictor(h)
+    else:
+        return self.out(h)
-- cgit v1.2.3

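The key change in the patch above is the shape check before each skip-connection concat: when the requested resolution is not a multiple of 64, the UNet's strided downsampling rounds odd feature-map sizes, so after upsampling the decoder tensor can end up one pixel larger than the stored encoder tensor it is concatenated with. The following self-contained sketch (toy shapes and channel counts, not taken from the patch) shows that resize-then-concat step in isolation:

import torch
import torch.nn.functional as F

# Stand-ins for UNet feature maps at one decoder level.
h = torch.randn(1, 320, 18, 18)     # decoder output after a 2x upsample (9 -> 18)
skip = torch.randn(1, 320, 17, 17)  # stored encoder activation at the same level

# Without the patch, torch.cat would raise because 18x18 != 17x17.
if h.shape[-2:] != skip.shape[-2:]:
    h = F.interpolate(h, size=skip.shape[-2:], mode="nearest")

h = torch.cat([h, skip], dim=1)
print(h.shape)  # torch.Size([1, 640, 17, 17])
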
From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 09:14:30 +0300
Subject: do not replace entire unet for the resolution hack

---
 modules/sd_hijack.py               |  5 +++--
 modules/sd_hijack_optimizations.py | 28 ----------------------------
 modules/sd_hijack_unet.py          | 30 ++++++++++++++++++++++++++++++
 3 files changed, 33 insertions(+), 30 deletions(-)
 create mode 100644 modules/sd_hijack_unet.py

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 92874a79..47dbc1b7 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
 
 from modules.sd_hijack_optimizations import invokeAI_mps_available
 
@@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
 ldm.modules.attention.print = lambda *args: None
 ldm.modules.diffusionmodules.model.print = lambda *args: None
 
+
 def apply_optimizations():
     undo_optimizations()
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
-    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
+    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 8cd4c954..85909eb9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape == (x.shape[0],)
-        emb = emb + self.label_emb(y)
-
-    h = x.type(self.dtype)
-    for module in self.input_blocks:
-        h = module(h, emb, context)
-        hs.append(h)
-    h = self.middle_block(h, emb, context)
-    for module in self.output_blocks:
-        if h.shape[-2:] != hs[-1].shape[-2:]:
-            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
-        h = torch.cat([h, hs.pop()], dim=1)
-        h = module(h, emb, context)
-    h = h.type(x.dtype)
-    if self.predict_codebook_ids:
-        return self.id_predictor(h)
-    else:
-        return self.out(h)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
new file mode 100644
index 00000000..1b9d7757
--- /dev/null
+++ b/modules/sd_hijack_unet.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+    """
+    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+    """
+
+    def __getattr__(self, item):
+        if item == 'cat':
+            return self.cat
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+    def cat(self, tensors, *args, **kwargs):
+        if len(tensors) == 2:
+            a, b = tensors
+            if a.shape[-2:] != b.shape[-2:]:
+                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+            tensors = (a, b)
+
+        return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
-- cgit v1.2.3

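Rather than overriding UNetModel.forward wholesale, this commit swaps out the th alias that ldm.modules.diffusionmodules.openaimodel uses for torch, so only th.cat changes behaviour and the stock forward pass is kept. A short usage sketch of TorchHijackForUnet (assuming the webui's modules package is importable, e.g. when run from the repository root; the tensor shapes are made up):

import torch
from modules.sd_hijack_unet import th

a = torch.randn(1, 320, 18, 18)  # decoder features, one pixel too large
b = torch.randn(1, 320, 17, 17)  # stored skip connection

# For everything except cat, th defers to torch.
print(th.float32 is torch.float32)  # True

# th.cat resizes the first tensor to match the second before concatenating,
# which is exactly the case hit by the UNet's skip connections.
out = th.cat([a, b], dim=1)
print(out.shape)  # torch.Size([1, 640, 17, 17])
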
From 505ec7e4d960e7bea579182509050fafb10bd00c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 09:17:39 +0300
Subject: cleanup some unneeded imports for hijack files

---
 modules/sd_hijack.py               | 10 ++--------
 modules/sd_hijack_optimizations.py |  3 ---
 2 files changed, 2 insertions(+), 11 deletions(-)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 47dbc1b7..690a9ec2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,16 +1,10 @@
-import math
-import os
-import sys
-import traceback
 import torch
-import numpy as np
-from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
-from modules.shared import opts, device, cmd_opts
+from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
 
 from modules.sd_hijack_optimizations import invokeAI_mps_available
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 85909eb9..98123fbf 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,7 +5,6 @@ import importlib
 
 import torch
 from torch import einsum
-import torch.nn.functional as F
 
 from ldm.util import default
 from einops import rearrange
@@ -13,8 +12,6 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
-from ldm.modules.diffusionmodules.util import timestep_embedding
-
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
-- cgit v1.2.3

From 35b1775b32a07f1b7c9dccad61f7aa77027a00fa Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 19 Dec 2022 17:25:14 -0500
Subject: Use other MPS optimization for large q.shape[0] * q.shape[1]

Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory
usage MPS optimization if it is. This should prevent most crashes that were
occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048).

Also included is a change to check slice_size and prevent it from being
divisible by 4096 which also results in a crash. Otherwise a crash can occur
at 1024x512 or 512x1024 resolution.
---
 modules/sd_hijack_optimizations.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..02c87f40 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -127,7 +127,7 @@ def check_for_psutil():
 
 invokeAI_mps_available = check_for_psutil()
 
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
 if invokeAI_mps_available:
     import psutil
     mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
     return r
 
 def einsum_op_mps_v1(q, k, v):
-    if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
         return einsum_op_compvis(q, k, v)
     else:
         slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        if slice_size % 4096 == 0:
+            slice_size -= 1
         return einsum_op_slice_1(q, k, v, slice_size)
 
 def einsum_op_mps_v2(q, k, v):
-    if mem_total_gb > 8 and q.shape[1] <= 4096:
+    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
         return einsum_op_compvis(q, k, v)
     else:
         return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
         return einsum_op_cuda(q, k, v)
 
     if q.device.type == 'mps':
-        if mem_total_gb >= 32:
+        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
             return einsum_op_mps_v1(q, k, v)
         return einsum_op_mps_v2(q, k, v)
 
-- cgit v1.2.3
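The new thresholds follow from the size of the attention intermediate: slicing q along its token dimension keeps each einsum product to roughly 2**30 elements, and the extra check keeps slice_size from being a multiple of 4096, which the commit message reports as a separate crash case. A standalone sketch of that arithmetic (the helper name and example shapes are illustrative, not from the repository; rows stands for q.shape[0] and tokens for q.shape[1], e.g. 4096 tokens for a 512x512 image):

import math

def choose_slice_size(rows, tokens):
    # Mirrors the slicing logic added to einsum_op_mps_v1 above.
    slice_size = math.floor(2**30 / (rows * tokens))
    if slice_size % 4096 == 0:
        slice_size -= 1
    return slice_size

for rows, tokens in [(16, 4096), (8, 16384), (16, 16384)]:
    if rows * tokens <= 2**16:
        print(rows, tokens, "-> einsum_op_compvis")  # small enough for the one-shot path
    else:
        print(rows, tokens, "-> slice_size", choose_slice_size(rows, tokens))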