From 8b40f475a31109cc6ecbdc0d14a0cee9e0303291 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Fri, 10 Nov 2023 11:06:26 +0800 Subject: Initial IPEX support --- modules/xpu_specific.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 modules/xpu_specific.py (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py new file mode 100644 index 00000000..6417dd2d --- /dev/null +++ b/modules/xpu_specific.py @@ -0,0 +1,42 @@ +import contextlib +from modules import shared +from modules.sd_hijack_utils import CondFunc + +has_ipex = False +try: + import torch + import intel_extension_for_pytorch as ipex + has_ipex = True +except Exception: + pass + +def check_for_xpu(): + if not has_ipex: + return False + + return hasattr(torch, 'xpu') and torch.xpu.is_available() + +has_xpu = check_for_xpu() + +def get_xpu_device_string(): + if shared.cmd_opts.device_id is not None: + return f"xpu:{shared.cmd_opts.device_id}" + return "xpu" + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return contextlib.nullcontext() + +if has_xpu: + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) -- cgit v1.2.3 From 7499148ad4dbd3444215c843d02453f68c459707 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 2 Dec 2023 14:00:46 +0800 Subject: Disable ipex autocast due to its bad perf --- modules/cmd_args.py | 1 + modules/devices.py | 20 +++++++++++++------- modules/xpu_specific.py | 28 ++++++++++++++++++---------- webui-ipex-user.bat | 19 +++++++++++++++++++ 4 files changed, 51 insertions(+), 17 deletions(-) create mode 100644 webui-ipex-user.bat (limited to 'modules/xpu_specific.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index a9fb9bfa..da93eb26 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) diff --git a/modules/devices.py b/modules/devices.py index be599736..37ecca78 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,11 +3,18 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared, xpu_specific +from modules import errors, shared if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -30,7 +37,7 @@ def get_optimal_device_name(): if has_mps(): return "mps" - if xpu_specific.has_ipex: + if has_xpu(): return xpu_specific.get_xpu_device_string() return "cpu" @@ -57,6 +64,9 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): @@ -103,15 +113,11 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() - if xpu_specific.has_xpu: - return torch.autocast("xpu") - return torch.autocast("cuda") def without_autocast(disable=False): - device_type = "xpu" if xpu_specific.has_xpu else "cuda" - return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 6417dd2d..2df68665 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -1,4 +1,3 @@ -import contextlib from modules import shared from modules.sd_hijack_utils import CondFunc @@ -10,33 +9,42 @@ try: except Exception: pass -def check_for_xpu(): - if not has_ipex: - return False - return hasattr(torch, 'xpu') and torch.xpu.is_available() +def check_for_xpu(): + return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() -has_xpu = check_for_xpu() def get_xpu_device_string(): if shared.cmd_opts.device_id is not None: return f"xpu:{shared.cmd_opts.device_id}" return "xpu" -def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + +def torch_xpu_gc(): + with torch.xpu.device(get_xpu_device_string()): + torch.xpu.empty_cache() + + +has_xpu = check_for_xpu() if has_xpu: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + lambda orig_func, device=None: device is not None and device.type == "xpu") + # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) diff --git a/webui-ipex-user.bat b/webui-ipex-user.bat new file mode 100644 index 00000000..ab25a040 --- /dev/null +++ b/webui-ipex-user.bat @@ -0,0 +1,19 @@ +@echo off + +set PYTHON= +@REM The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main +@REM This is NOT an Intel official release so please use it at your own risk!! +@REM See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. +@REM +@REM Strengths (over official IPEX 2.0.110 windows release): +@REM - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 +@REM - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. +@REM - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 +@REM Limitation: +@REM - Only works for python 3.10 +set "TORCH_COMMAND=pip install https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl" +set GIT= +set VENV_DIR= +set "COMMANDLINE_ARGS=--use-ipex --skip-torch-cuda-test --skip-version-check --opt-sdp-attention" + +call webui.bat -- cgit v1.2.3 From 87cd07b3af74c447b02570bf3963ba83ade2e203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 2 Dec 2023 15:54:25 +0800 Subject: Fix fp64 --- modules/sd_samplers_timesteps_impl.py | 4 ++-- modules/xpu_specific.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index a72daafd..930a64af 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) @@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) extra_args = {} if extra_args is None else extra_args diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 2df68665..d933c790 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -4,7 +4,7 @@ from modules.sd_hijack_utils import CondFunc has_ipex = False try: import torch - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # noqa: F401 has_ipex = True except Exception: pass -- cgit v1.2.3 From 746783f7a47f38f728f221cc26fe04035d3ca66b Mon Sep 17 00:00:00 2001 From: Nuullll Date: Wed, 6 Dec 2023 20:55:42 +0800 Subject: [IPEX] Fix embedding Cast `torch.bmm` args into same `dtype`. Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ``` --- modules/xpu_specific.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790..ec1ad100 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) -- cgit v1.2.3 From 59429793440fb3cb1624ddcc702c6f9807373203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:09:45 +0800 Subject: Fix ControlNet --- modules/xpu_specific.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index ec1ad100..9bb0a561 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -51,3 +51,9 @@ if has_xpu: CondFunc('torch.bmm', lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file -- cgit v1.2.3 From 049d5642e58d572ee8657ac754e72d019eea0e6c Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:11:26 +0800 Subject: Fix format --- modules/xpu_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 9bb0a561..d8da94a0 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -56,4 +56,4 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) -- cgit v1.2.3 From e4b4a9c4acf0ca375a8603f7f52fde8467b2d266 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 18:00:01 +0800 Subject: [IPEX] Slice SDPA into smaller chunks --- modules/xpu_specific.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d8da94a0..0ebdd596 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -27,6 +27,68 @@ def torch_xpu_gc(): has_xpu = check_for_xpu() + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. +ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + if has_xpu: # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', @@ -55,5 +117,5 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) -- cgit v1.2.3 From f586f4973a0f715e30b42242bb0e6b3f88c37d90 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 19:44:52 +0800 Subject: Fix device id --- modules/xpu_specific.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd596..f7687a66 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -33,7 +33,7 @@ has_xpu = check_for_xpu() # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. -ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func( -- cgit v1.2.3