aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-25 05:23:10 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-01-25 06:13:04 +0000
commite3b53fd295aca784253dfc8668ec87b537a72f43 (patch)
tree6fb26afd730c0561a2506ead2d2c8295d326de40 /modules/sd_hijack_optimizations.py
parent84d9ce30cb427759547bc7876ed80ab91787d175 (diff)
downloadstable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.tar.gz
stable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.tar.bz2
stable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.zip
Add UI setting for upcasting attention to float32
Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also.
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py159
1 files changed, 99 insertions, 60 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 74452709..c02d954c 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors
+from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
- s2 = s1.softmax(dim=-1)
- del s1
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ del q, k, v
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- k_in *= self.scale
-
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- mem_free_total = get_available_vram()
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+ dtype = q_in.dtype
+ if shared.opts.upcast_attn:
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k_in = k_in * self.scale
+
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ mem_free_total = get_available_vram()
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k) * self.scale
+ k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r = einsum_op(q, k, v)
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k = k * self.scale
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ x = x.to(dtype)
+
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
- return efficient_dot_product_attention(
- q,
- k,
- v,
- query_chunk_size=q_chunk_size,
- kv_chunk_size=kv_chunk_size,
- kv_chunk_size_min = kv_chunk_size_min,
- use_checkpoint=use_checkpoint,
- )
+ with devices.without_autocast(disable=q.dtype == v.dtype):
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
def get_xformers_flash_attention_op(q, k, v):
@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out