From 746783f7a47f38f728f221cc26fe04035d3ca66b Mon Sep 17 00:00:00 2001 From: Nuullll Date: Wed, 6 Dec 2023 20:55:42 +0800 Subject: [IPEX] Fix embedding Cast `torch.bmm` args into same `dtype`. Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ``` --- modules/xpu_specific.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790..ec1ad100 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) -- cgit v1.2.3 From 59429793440fb3cb1624ddcc702c6f9807373203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:09:45 +0800 Subject: Fix ControlNet --- modules/xpu_specific.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index ec1ad100..9bb0a561 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -51,3 +51,9 @@ if has_xpu: CondFunc('torch.bmm', lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file -- cgit v1.2.3 From 049d5642e58d572ee8657ac754e72d019eea0e6c Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:11:26 +0800 Subject: Fix format --- modules/xpu_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 9bb0a561..d8da94a0 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -56,4 +56,4 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) -- cgit v1.2.3 From e4b4a9c4acf0ca375a8603f7f52fde8467b2d266 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 18:00:01 +0800 Subject: [IPEX] Slice SDPA into smaller chunks --- modules/xpu_specific.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d8da94a0..0ebdd596 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -27,6 +27,68 @@ def torch_xpu_gc(): has_xpu = check_for_xpu() + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. +ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + if has_xpu: # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', @@ -55,5 +117,5 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) -- cgit v1.2.3 From f586f4973a0f715e30b42242bb0e6b3f88c37d90 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 19:44:52 +0800 Subject: Fix device id --- modules/xpu_specific.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd596..f7687a66 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -33,7 +33,7 @@ has_xpu = check_for_xpu() # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. -ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func( -- cgit v1.2.3 From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 16:32:18 +0800 Subject: [IPEX] Fix SDPA attn_mask dtype --- modules/xpu_specific.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66..4e11125b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length -- cgit v1.2.3 From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:09:56 +0800 Subject: [IPEX] Fix torch.Generator hijack --- modules/xpu_specific.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b..1137891a 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', -- cgit v1.2.3 From 818d6a11e709bf07d48606bdccab944c46a5f4b0 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:14:06 +0800 Subject: Fix format --- modules/xpu_specific.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 1137891a..2971dbc3 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -106,8 +106,8 @@ if has_xpu: try: # torch.Generator supports "xpu" device since 2.1 torch.Generator("xpu") - except: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + except RuntimeError: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1) CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), lambda orig_func, device=None: is_xpu_device(device)) -- cgit v1.2.3