From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: Add CPU fp8 support Since norm layer need fp32, I only convert the linear operation layer(conv2d/linear) And TE have some pytorch function not support bf16 amp in CPU. I add a condition to indicate if the autocast is for unet. --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..0cd2b55d 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,6 +71,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -93,10 +94,13 @@ def cond_cast_float(input): nv_rng = None -def autocast(disable=False): +def autocast(disable=False, unet=False): if disable: return contextlib.nullcontext() + if unet and fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.3 From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: ManualCast for 10/16 series gpu --- modules/devices.py | 57 +++++++++++++++++++++++++++++++++++++++++++++------ modules/processing.py | 2 +- modules/sd_models.py | 21 +++++++++++-------- 3 files changed, 64 insertions(+), 16 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d..c05f2b35 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None - - -def autocast(disable=False, unet=False): +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea..40598f5c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2..31bcb913 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.3 From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Sat, 28 Oct 2023 16:52:35 +0800 Subject: Add MPS manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index c05f2b35..d7c905c2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_autocast(): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.3 From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: Use options instead of cmd_args --- modules/cmd_args.py | 2 -- modules/devices.py | 25 ++++++++++--------- modules/initialize_util.py | 1 + modules/sd_models.py | 61 ++++++++++++++++++++++++---------------------- modules/shared_options.py | 1 + scripts/xyz_grid.py | 1 + 6 files changed, 49 insertions(+), 42 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 088d5dea..a9fb9bfa 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) -parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) -parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/devices.py b/modules/devices.py index d7c905c2..03e7bdb7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d89..1b11ead6 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,7 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa..eb491434 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ class SkipWritingToConfig: SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/shared_options.py b/modules/shared_options.py index f1003f21..d27f35e9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc..b2250c04 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ axis_options = [ AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] -- cgit v1.2.3 From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:31 +0800 Subject: Better naming --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 03e7bdb7..c19a7f40 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs): @contextlib.contextmanager -def manual_autocast(): +def manual_cast(): for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward @@ -148,10 +148,10 @@ def autocast(disable=False): return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_autocast() + return manual_cast() if has_mps() and shared.cmd_opts.precision != "full": - return manual_autocast() + return manual_cast() if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.3 From 3cd6e1d0a0877e6f1ac931c8253e6eee09da3805 Mon Sep 17 00:00:00 2001 From: obsol <33932119+read-0nly@users.noreply.github.com> Date: Mon, 27 Nov 2023 19:21:43 -0500 Subject: Update devices.py fixes issue where "--use-cpu" all properly makes SD run on CPU but leaves ControlNet (and other extensions, I presume) pointed at GPU, causing a crash in ControlNet caused by a mismatch between devices between SD and CN https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14097 --- modules/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index c01f0602..65efcf1e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -38,7 +38,7 @@ def get_optimal_device(): def get_device_for(task): - if task in shared.cmd_opts.use_cpu: + if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu: return cpu return get_optimal_device() -- cgit v1.2.3 From 8b40f475a31109cc6ecbdc0d14a0cee9e0303291 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Fri, 10 Nov 2023 11:06:26 +0800 Subject: Initial IPEX support --- modules/devices.py | 11 +++++++++-- modules/xpu_specific.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 modules/xpu_specific.py (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..be599736 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared +from modules import errors, shared, xpu_specific if sys.platform == "darwin": from modules import mac_specific @@ -30,6 +30,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if xpu_specific.has_ipex: + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -100,11 +103,15 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if xpu_specific.has_xpu: + return torch.autocast("xpu") + return torch.autocast("cuda") def without_autocast(disable=False): - return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + device_type = "xpu" if xpu_specific.has_xpu else "cuda" + return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py new file mode 100644 index 00000000..6417dd2d --- /dev/null +++ b/modules/xpu_specific.py @@ -0,0 +1,42 @@ +import contextlib +from modules import shared +from modules.sd_hijack_utils import CondFunc + +has_ipex = False +try: + import torch + import intel_extension_for_pytorch as ipex + has_ipex = True +except Exception: + pass + +def check_for_xpu(): + if not has_ipex: + return False + + return hasattr(torch, 'xpu') and torch.xpu.is_available() + +has_xpu = check_for_xpu() + +def get_xpu_device_string(): + if shared.cmd_opts.device_id is not None: + return f"xpu:{shared.cmd_opts.device_id}" + return "xpu" + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return contextlib.nullcontext() + +if has_xpu: + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) -- cgit v1.2.3 From 7499148ad4dbd3444215c843d02453f68c459707 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 2 Dec 2023 14:00:46 +0800 Subject: Disable ipex autocast due to its bad perf --- modules/cmd_args.py | 1 + modules/devices.py | 20 +++++++++++++------- modules/xpu_specific.py | 28 ++++++++++++++++++---------- webui-ipex-user.bat | 19 +++++++++++++++++++ 4 files changed, 51 insertions(+), 17 deletions(-) create mode 100644 webui-ipex-user.bat (limited to 'modules/devices.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index a9fb9bfa..da93eb26 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) diff --git a/modules/devices.py b/modules/devices.py index be599736..37ecca78 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,11 +3,18 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared, xpu_specific +from modules import errors, shared if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -30,7 +37,7 @@ def get_optimal_device_name(): if has_mps(): return "mps" - if xpu_specific.has_ipex: + if has_xpu(): return xpu_specific.get_xpu_device_string() return "cpu" @@ -57,6 +64,9 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): @@ -103,15 +113,11 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() - if xpu_specific.has_xpu: - return torch.autocast("xpu") - return torch.autocast("cuda") def without_autocast(disable=False): - device_type = "xpu" if xpu_specific.has_xpu else "cuda" - return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 6417dd2d..2df68665 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -1,4 +1,3 @@ -import contextlib from modules import shared from modules.sd_hijack_utils import CondFunc @@ -10,33 +9,42 @@ try: except Exception: pass -def check_for_xpu(): - if not has_ipex: - return False - return hasattr(torch, 'xpu') and torch.xpu.is_available() +def check_for_xpu(): + return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() -has_xpu = check_for_xpu() def get_xpu_device_string(): if shared.cmd_opts.device_id is not None: return f"xpu:{shared.cmd_opts.device_id}" return "xpu" -def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + +def torch_xpu_gc(): + with torch.xpu.device(get_xpu_device_string()): + torch.xpu.empty_cache() + + +has_xpu = check_for_xpu() if has_xpu: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + lambda orig_func, device=None: device is not None and device.type == "xpu") + # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) diff --git a/webui-ipex-user.bat b/webui-ipex-user.bat new file mode 100644 index 00000000..ab25a040 --- /dev/null +++ b/webui-ipex-user.bat @@ -0,0 +1,19 @@ +@echo off + +set PYTHON= +@REM The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main +@REM This is NOT an Intel official release so please use it at your own risk!! +@REM See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. +@REM +@REM Strengths (over official IPEX 2.0.110 windows release): +@REM - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 +@REM - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. +@REM - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 +@REM Limitation: +@REM - Only works for python 3.10 +set "TORCH_COMMAND=pip install https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl" +set GIT= +set VENV_DIR= +set "COMMANDLINE_ARGS=--use-ipex --skip-torch-cuda-test --skip-version-check --opt-sdp-attention" + +call webui.bat -- cgit v1.2.3 From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: Add utility to inspect a model's parameters (to get dtype/device) --- modules/devices.py | 3 ++- modules/interrogate.py | 3 ++- modules/sd_models_xl.py | 3 ++- modules/torch_utils.py | 17 +++++++++++++++++ modules/upscaler_utils.py | 5 +++-- modules/xlmr.py | 5 ++++- modules/xlmr_m18.py | 5 ++++- test/test_torch_utils.py | 19 +++++++++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 modules/torch_utils.py create mode 100644 test/test_torch_utils.py (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index c956207f..bd6bd579 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,6 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared +from modules.torch_utils import get_param if sys.platform == "darwin": from modules import mac_specific @@ -131,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype + org_dtype = get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d..5be5a10f 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.torch_utils import get_param blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +132,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1de31b0d..c3602a7e 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules.torch_utils import get_param def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 00000000..e5b52393 --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854..c60e3beb 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ import tqdm from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3ca..6e000a56 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e865..e3e81961 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 00000000..f1aec832 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules.torch_utils import get_param + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu -- cgit v1.2.3 From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: change import statements for #14478 --- modules/devices.py | 4 ++-- modules/interrogate.py | 5 ++--- modules/sd_models_xl.py | 4 ++-- modules/upscaler_utils.py | 5 ++--- modules/xlmr.py | 4 ++-- modules/xlmr_m18.py | 5 ++--- test/test_torch_utils.py | 4 ++-- 7 files changed, 14 insertions(+), 17 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index bd6bd579..ff279ac5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,7 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared -from modules.torch_utils import get_param +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific @@ -132,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = get_param(self).dtype + org_dtype = torch_utils.get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 5be5a10f..35a627ca 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,8 +10,7 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.torch_utils import get_param +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -132,7 +131,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = get_param(self.clip_model).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index c3602a7e..0de17af3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,7 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser -from modules.torch_utils import get_param +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = get_param(model.model.diffusion_model).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb..f5cb92d5 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56..319771b7 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index e3e81961..f6055504 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,8 +4,7 @@ import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional - -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py index f1aec832..23ccb93a 100644 --- a/test/test_torch_utils.py +++ b/test/test_torch_utils.py @@ -3,7 +3,7 @@ import types import pytest import torch -from modules.torch_utils import get_param +from modules import torch_utils @pytest.mark.parametrize("wrapped", [True, False]) @@ -14,6 +14,6 @@ def test_get_param(wrapped): if wrapped: # more or less how spandrel wraps a thing mod = types.SimpleNamespace(model=mod) - p = get_param(mod) + p = torch_utils.get_param(mod) assert p.dtype == torch.float16 assert p.device == cpu -- cgit v1.2.3