From adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 Mon Sep 17 00:00:00 2001
From: Billy Cao
Date: Wed, 23 Nov 2022 18:11:24 +0800
Subject: Patch UNet Forward to support resolutions that are not multiples of 64

Also modified the UI to no longer step in 64
---
 modules/sd_hijack_optimizations.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..8cd4c954 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,7 @@ import importlib
 
 import torch
 from torch import einsum
+import torch.nn.functional as F
 
 from ldm.util import default
 from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
+from ldm.modules.diffusionmodules.util import timestep_embedding
+
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+    assert (y is not None) == (
+        self.num_classes is not None
+    ), "must specify y if and only if the model is class-conditional"
+    hs = []
+    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+    emb = self.time_embed(t_emb)
+
+    if self.num_classes is not None:
+        assert y.shape == (x.shape[0],)
+        emb = emb + self.label_emb(y)
+
+    h = x.type(self.dtype)
+    for module in self.input_blocks:
+        h = module(h, emb, context)
+        hs.append(h)
+    h = self.middle_block(h, emb, context)
+    for module in self.output_blocks:
+        if h.shape[-2:] != hs[-1].shape[-2:]:
+            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
+        h = torch.cat([h, hs.pop()], dim=1)
+        h = module(h, emb, context)
+    h = h.type(x.dtype)
+    if self.predict_codebook_ids:
+        return self.id_predictor(h)
+    else:
+        return self.out(h)
--
cgit v1.2.3

From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 09:14:30 +0300
Subject: do not replace entire unet for the resolution hack

---
 modules/sd_hijack_optimizations.py | 28 ----------------------------
 1 file changed, 28 deletions(-)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 8cd4c954..85909eb9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape == (x.shape[0],)
-        emb = emb + self.label_emb(y)
-
-    h = x.type(self.dtype)
-    for module in self.input_blocks:
-        h = module(h, emb, context)
-        hs.append(h)
-    h = self.middle_block(h, emb, context)
-    for module in self.output_blocks:
-        if h.shape[-2:] != hs[-1].shape[-2:]:
-            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
-        h = torch.cat([h, hs.pop()], dim=1)
-        h = module(h, emb, context)
-    h = h.type(x.dtype)
-    if self.predict_codebook_ids:
-        return self.id_predictor(h)
-    else:
-        return self.out(h)
--
cgit v1.2.3
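
A note on the resolution hack above: the UNet halves the latent's spatial size several times on the way down, so when the requested width or height is not a multiple of 64 the decoder's upsampled feature maps can come back a pixel or two larger than the stored skip connections, and the torch.cat in the forward pass fails. The patched forward snaps the decoder output onto the skip connection's spatial size with nearest-neighbour interpolation before concatenating; the second commit then drops this wholesale replacement of the forward from this file, per its subject, so the resolution hack no longer requires swapping out the entire unet. Below is a minimal standalone sketch of the mismatch and the fix; the sizes (a hypothetical 520-pixel image side gives a 65-pixel latent, which downsamples to 33, 17, 9, so the first decoder upsample yields 18 against a saved skip of 17) and the 1280-channel count are illustrative values, not taken from the patch.

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes: a 520-pixel image side gives a 65-pixel latent, which
    # the encoder downsamples to 33 -> 17 -> 9; the first decoder upsample then
    # yields 18 while the saved skip connection is still 17.
    h = torch.randn(1, 1280, 18, 18)      # decoder feature map after upsampling
    skip = torch.randn(1, 1280, 17, 17)   # stored encoder activation (hs[-1])

    # Without the fix, torch.cat([h, skip], dim=1) raises a size-mismatch error.
    # The patched forward snaps h onto the skip's spatial size first:
    if h.shape[-2:] != skip.shape[-2:]:
        h = F.interpolate(h, skip.shape[-2:], mode="nearest")

    h = torch.cat([h, skip], dim=1)
    print(h.shape)                        # torch.Size([1, 2560, 17, 17])

At resolutions that are already multiples of 64 the two shapes always match, so the interpolation branch never runs and behaviour is unchanged.
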
From 505ec7e4d960e7bea579182509050fafb10bd00c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 09:17:39 +0300
Subject: cleanup some unneeded imports for hijack files

---
 modules/sd_hijack_optimizations.py | 3 ---
 1 file changed, 3 deletions(-)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 85909eb9..98123fbf 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,7 +5,6 @@ import importlib
 
 import torch
 from torch import einsum
-import torch.nn.functional as F
 
 from ldm.util import default
 from einops import rearrange
@@ -13,8 +12,6 @@ from einops import rearrange
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
-from ldm.modules.diffusionmodules.util import timestep_embedding
-
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
--
cgit v1.2.3

From 35b1775b32a07f1b7c9dccad61f7aa77027a00fa Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 19 Dec 2022 17:25:14 -0500
Subject: Use other MPS optimization for large q.shape[0] * q.shape[1]

Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory
usage MPS optimization if it is. This should prevent most crashes that were
occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048).

Also included is a change to check slice_size and prevent it from being
divisible by 4096, which also results in a crash. Otherwise a crash can occur
at 1024x512 or 512x1024 resolution.
---
 modules/sd_hijack_optimizations.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

(limited to 'modules/sd_hijack_optimizations.py')

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..02c87f40 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -127,7 +127,7 @@ def check_for_psutil():
 
 invokeAI_mps_available = check_for_psutil()
 
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
 if invokeAI_mps_available:
     import psutil
     mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
     return r
 
 def einsum_op_mps_v1(q, k, v):
-    if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
         return einsum_op_compvis(q, k, v)
     else:
         slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        if slice_size % 4096 == 0:
+            slice_size -= 1
         return einsum_op_slice_1(q, k, v, slice_size)
 
 def einsum_op_mps_v2(q, k, v):
-    if mem_total_gb > 8 and q.shape[1] <= 4096:
+    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
         return einsum_op_compvis(q, k, v)
     else:
         return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
         return einsum_op_cuda(q, k, v)
 
     if q.device.type == 'mps':
-        if mem_total_gb >= 32:
+        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
             return einsum_op_mps_v1(q, k, v)
         return einsum_op_mps_v2(q, k, v)
 
--
cgit v1.2.3
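
A note on the thresholds in the MPS patch above: in the InvokeAI-derived attention path the heads are folded into the batch dimension, so q.shape[0] is roughly batch * heads and q.shape[1] is the number of latent pixels at that attention level, and their product grows with resolution. The sketch below re-traces the patched routing for a few resolutions. It assumes an SD 1.x model with 8 attention heads and a batched cond/uncond pair, i.e. q.shape[0] == 16 at the largest self-attention level, plus a machine with at least 32 GB of memory; those are illustrative assumptions, not values stated in the diff.

    import math

    def mps_attention_path(rows, tokens, mem_total_gb=64):
        # rows ~ q.shape[0] (batch * heads), tokens ~ q.shape[1] (latent pixels).
        # Re-traces the MPS branch of the patched einsum_op / einsum_op_mps_v1
        # above, for illustration only.
        if mem_total_gb >= 32 and rows % 32 != 0 and rows * tokens < 2**18:
            if rows * tokens <= 2**16:
                return "plain CompVis einsum (no slicing)"
            slice_size = math.floor(2**30 / (rows * tokens))
            if slice_size % 4096 == 0:   # the new check; avoids the 1024x512 / 512x1024 crash
                slice_size -= 1
            return f"sliced einsum, slice_size={slice_size}"
        return "einsum_op_mps_v2 (lower memory usage)"

    # Assumed: SD 1.x, 8 heads, batched cond+uncond, so q.shape[0] == 16.
    print(mps_attention_path(16, 64 * 64))    # 512x512   -> plain CompVis einsum
    print(mps_attention_path(16, 128 * 64))   # 1024x512  -> sliced einsum, slice_size=8191
    print(mps_attention_path(16, 128 * 128))  # 1024x1024 -> lower memory usage path

Under those assumptions 1024x1024 lands exactly on 2**18 and is diverted to the lower-memory einsum_op_mps_v2, while 1024x512 produces a raw slice_size of exactly 8192, which the new divisibility check nudges down to 8191.
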