From adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Wed, 23 Nov 2022 18:11:24 +0800 Subject: Patch UNet Forward to support resolutions that are not multiples of 64 Also modifed the UI to no longer step in 64 --- modules/sd_hijack_optimizations.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 98123fbf..8cd4c954 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,7 @@ import importlib import torch from torch import einsum +import torch.nn.functional as F from ldm.util import default from einops import rearrange @@ -12,6 +13,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from ldm.modules.diffusionmodules.util import timestep_embedding + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + if h.shape[-2:] != hs[-1].shape[-2:]: + h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest") + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) -- cgit v1.2.3 From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:14:30 +0300 Subject: do not replace entire unet for the resolution hack --- modules/sd_hijack_optimizations.py | 28 ---------------------------- 1 file changed, 28 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 8cd4c954..85909eb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) - -def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs): - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - if h.shape[-2:] != hs[-1].shape[-2:]: - h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest") - h = 
torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) -- cgit v1.2.3 From 505ec7e4d960e7bea579182509050fafb10bd00c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:17:39 +0300 Subject: cleanup some unneeded imports for hijack files --- modules/sd_hijack_optimizations.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 85909eb9..98123fbf 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,7 +5,6 @@ import importlib import torch from torch import einsum -import torch.nn.functional as F from ldm.util import default from einops import rearrange @@ -13,8 +12,6 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork -from ldm.modules.diffusionmodules.util import timestep_embedding - if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: -- cgit v1.2.3 From 35b1775b32a07f1b7c9dccad61f7aa77027a00fa Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 19 Dec 2022 17:25:14 -0500 Subject: Use other MPS optimization for large q.shape[0] * q.shape[1] Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory usage MPS optimization if it is. This should prevent most crashes that were occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048). Also included is a change to check slice_size and prevent it from being divisible by 4096 which also results in a crash. Otherwise a crash can occur at 1024x512 or 512x1024 resolution. --- modules/sd_hijack_optimizations.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 98123fbf..02c87f40 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -127,7 +127,7 @@ def check_for_psutil(): invokeAI_mps_available = check_for_psutil() -# -- Taken from https://github.com/invoke-ai/InvokeAI -- +# -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- if invokeAI_mps_available: import psutil mem_total_gb = psutil.virtual_memory().total // (1 << 30) @@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size): return r def einsum_op_mps_v1(q, k, v): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 return einsum_op_compvis(q, k, v) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + if slice_size % 4096 == 0: + slice_size -= 1 return einsum_op_slice_1(q, k, v, slice_size) def einsum_op_mps_v2(q, k, v): - if mem_total_gb > 8 and q.shape[1] <= 4096: + if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: return einsum_op_compvis(q, k, v) else: return einsum_op_slice_0(q, k, v, 1) @@ -188,7 +190,7 @@ def einsum_op(q, k, v): return einsum_op_cuda(q, k, v) if q.device.type == 'mps': - if mem_total_gb >= 32: + if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: return einsum_op_mps_v1(q, k, v) return einsum_op_mps_v2(q, k, v) -- cgit v1.2.3 From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation 
--- modules/sd_hijack_optimizations.py | 124 +++++++++++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 25 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
+ + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. 
send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out -- cgit v1.2.3 From b95a4c0ce5ab9c414e0494193bfff665f45e9e65 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:01:51 -0500 Subject: Change sub-quad chunk threshold to use percentage --- modules/sd_hijack_optimizations.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f5c153e8..b416e9ac 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) @@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape _, k_tokens, _ = k.shape qk_matmul_size_bytes = 
batch_x_heads * bytes_per_token * q_tokens * k_tokens - available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) - - if chunk_threshold_bytes is None: - chunk_threshold_bytes = available_vram - elif chunk_threshold_bytes == 0: + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - if kv_chunk_size_min is None: + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) elif kv_chunk_size_min == 0: kv_chunk_size_min = None @@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out -- cgit v1.2.3 From c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 16:42:47 -0500 Subject: Added license --- modules/sd_hijack_optimizations.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b416e9ac..cdc63ed7 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface def sub_quad_attention_forward(self, x, context=None, mask=None): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." -- cgit v1.2.3
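Note: the chunk-threshold handling that the sub-quadratic attention patches above converge on can be summarized with the following standalone sketch. This is not the web UI code itself: the function names (resolve_chunk_threshold_bytes, fits_in_one_chunk) and the available_bytes/is_mps parameters are illustrative stand-ins for get_available_vram() and the q.device.type == 'mps' check inside sub_quad_attention(). Per the final diff, leaving shared.cmd_opts.sub_quad_chunk_threshold unset falls back to 90% of available memory on MPS (70% elsewhere), a value of 0 disables the unchunked fast path entirely, and any other value is read as a percentage of currently available memory (VRAM on CUDA, system memory otherwise).

from typing import Optional

import torch


def resolve_chunk_threshold_bytes(chunk_threshold: Optional[int],
                                  available_bytes: int,
                                  is_mps: bool) -> Optional[int]:
    # None  -> default to a fraction of available memory (90% on MPS, 70% elsewhere)
    # 0     -> no byte budget at all, i.e. always take the chunked path
    # other -> that percentage of currently available memory
    if chunk_threshold is None:
        return int(available_bytes * (0.9 if is_mps else 0.7))
    if chunk_threshold == 0:
        return None
    return int(0.01 * chunk_threshold * available_bytes)


def fits_in_one_chunk(q: torch.Tensor, k: torch.Tensor,
                      chunk_threshold_bytes: Optional[int]) -> bool:
    # Size in bytes of the full attention-score matrix q @ k^T for this batch of heads;
    # if it fits under the budget, the attention can run unchunked in a single pass.
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    return chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes


if __name__ == "__main__":
    # Hypothetical numbers: fp16 attention with 4096 query/key tokens and 16 batch*heads,
    # 8 GiB reported as available, and a threshold setting of 40 (percent).
    q = torch.zeros(16, 4096, 40, dtype=torch.float16)
    k = torch.zeros(16, 4096, 40, dtype=torch.float16)
    budget = resolve_chunk_threshold_bytes(chunk_threshold=40,
                                           available_bytes=8 * 1024 ** 3,
                                           is_mps=False)
    print(budget, fits_in_one_chunk(q, k, budget))  # the ~512 MiB score matrix fits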