From f174fb29228a04955fb951b32b0bab79e33ec2b8 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:21:49 +0300 Subject: add xformers attention --- modules/sd_hijack_optimizations.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..da1b76e1 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,9 @@ import math import torch from torch import einsum - +import xformers.ops +import functorch +xformers._is_functorch_available=True from ldm.util import default from einops import rearrange @@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + +def xformers_attention_forward(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + self._maybe_init(q) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) -- cgit v1.2.3 From c9cc65b201679ea43c763b0d85e749d40bbc5433 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:09:18 +0300 Subject: switch to the proper way of calling xformers --- modules/sd_hijack_optimizations.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index da1b76e1..7fb4a45e 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -94,39 +94,17 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def _maybe_init(self, x): - """ - Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x - : B, Head, Length - """ - if self.attention_op is not None: - return - _, M, K = x.shape - try: - self.attention_op = xformers.ops.AttentionOpDispatch( - dtype=x.dtype, - device=x.device, - k=K, - attn_bias_type=type(None), - has_dropout=False, - kv_len=M, - q_len=M, - ).op - except NotImplementedError as err: - raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") - def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) v_in = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - self._maybe_init(q) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) def cross_attention_attnblock_forward(self, x): -- cgit v1.2.3 From f2055cb1d4ce45d7aaacc49d8ab5bec7791a8f47 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 8 Oct 2022 01:47:02 -0400 Subject: Add hypernetwork support to split cross attention v1 * Add hypernetwork support to split_cross_attention_forward_v1 * Fix device check in esrgan_model.py to use devices.device_esrgan instead of shared.device --- modules/sd_hijack_optimizations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d9cca485..3351c740 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -12,13 +12,22 @@ from modules import shared def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) for i in range(0, q.shape[0], 2): @@ -31,6 +40,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) del s2 + del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 -- cgit v1.2.3 From 5d54f35c583bd5a3b0ee271a862827f1ca81ef09 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:02 +0300 Subject: add xformers attnblock and hypernetwork support --- modules/sd_hijack_optimizations.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7fb4a45e..c78d5838 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -98,8 +98,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) - v_in = self.to_v(context) + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) @@ -169,3 +175,13 @@ def cross_attention_attnblock_forward(self, x): h3 += x return h3 + + def xformers_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_).contiguous() + k1 = self.k(h_).contiguous() + v = self.v(h_).contiguous() + out = xformers.ops.memory_efficient_attention(q1, k1, v) + out = self.proj_out(out) + return x+out -- cgit v1.2.3 From 76a616fa6b814c681eaf6edc87eb3001b8c2b6be Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:38 +0300 Subject: Update sd_hijack_optimizations.py --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c78d5838..ee58c7e4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -176,7 +176,7 @@ def cross_attention_attnblock_forward(self, x): return h3 - def xformers_attnblock_forward(self, x): +def xformers_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) q1 = self.q(h_).contiguous() -- cgit v1.2.3 From 69d0053583757ce2942d62de81e8b89e6be07840 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:21:40 +0300 Subject: update sd_hijack_opt to respect new env variables --- modules/sd_hijack_optimizations.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ee58c7e4..be09ec8f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,9 +1,14 @@ import math import torch from torch import einsum -import xformers.ops -import functorch -xformers._is_functorch_available=True +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except: + print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') + continue from ldm.util import default from einops import rearrange -- cgit v1.2.3 From 7ff1170a2e11b6f00f587407326db0b9f8f51adf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 16:33:39 +0300 Subject: emergency fix for xformers (continue + shared) --- modules/sd_hijack_optimizations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e43e2c7a..05023b6f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,19 +1,19 @@ import math import torch from torch import einsum -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except: - print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') - continue + from ldm.util import default from einops import rearrange from modules import shared +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except Exception: + print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): -- cgit v1.2.3 From dc1117233ef8f9b25ff1ac40b158f20b70ba2fcb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:02:18 +0300 Subject: simplify xfrmers options: --xformers to enable and that's it --- modules/sd_hijack_optimizations.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 05023b6f..d23d733b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,4 +1,7 @@ import math +import sys +import traceback + import torch from torch import einsum @@ -7,13 +10,16 @@ from einops import rearrange from modules import shared -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except Exception: - print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') +if shared.cmd_opts.xformers: + try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True + except Exception: + print("Cannot import xformers", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): -- cgit v1.2.3