author    | brkirch <brkirch@users.noreply.github.com> | 2023-01-06 06:01:51 +0000
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-06 06:01:51 +0000
commit    | b95a4c0ce5ab9c414e0494193bfff665f45e9e65 (patch)
tree      | 64709ed9c735f1e89da0d0d5ddcad6a32005a5f8 /modules/sd_hijack_optimizations.py
parent    | b119815333026164f2bd7d1ca71f3e4f7a9afd0d (diff)
download  | stable-diffusion-webui-gfx803-b95a4c0ce5ab9c414e0494193bfff665f45e9e65.tar.gz (also available as .tar.bz2 and .zip)
Change sub-quad chunk threshold to use percentage
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 18
1 file changed, 9 insertions, 9 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f5c153e8..b416e9ac 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
- x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
@@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
return x
-def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
- available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
-
- if chunk_threshold_bytes is None:
- chunk_threshold_bytes = available_vram
- elif chunk_threshold_bytes == 0:
+ if chunk_threshold is None:
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ elif chunk_threshold == 0:
chunk_threshold_bytes = None
+ else:
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
- if kv_chunk_size_min is None:
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
elif kv_chunk_size_min == 0:
kv_chunk_size_min = None
@@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
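
With this change, `shared.cmd_opts.sub_quad_chunk_threshold` is interpreted as a percentage of available VRAM rather than a byte count: `None` falls back to 90% (on MPS) or 70% of available VRAM, `0` disables the byte threshold entirely, and any other value N resolves to N% of available VRAM. A minimal sketch of that resolution step is below; the helper name is hypothetical and `available_vram` is passed in directly instead of calling `get_available_vram()`:

```python
def resolve_chunk_threshold_bytes(chunk_threshold, available_vram, device_type="cuda"):
    """Hypothetical helper (not part of the webui codebase) mirroring the
    threshold resolution in sub_quad_attention after this commit.

    chunk_threshold is treated as a percentage of available VRAM:
      None -> default to 90% (MPS) or 70% of available VRAM
      0    -> disable the byte threshold (returns None)
      N    -> N percent of available VRAM
    """
    if chunk_threshold is None:
        # Default: leave some headroom below the available VRAM.
        return int(available_vram * 0.9) if device_type == "mps" else int(available_vram * 0.7)
    elif chunk_threshold == 0:
        # 0 disables chunking based on a byte threshold.
        return None
    # Any other value is a percentage of available VRAM.
    return int(0.01 * chunk_threshold * available_vram)


# Example with illustrative numbers: a threshold of 60 on a GPU reporting
# 8 GiB of available VRAM resolves to roughly 5.15 GB.
print(resolve_chunk_threshold_bytes(60, available_vram=8 * 1024**3))
```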