diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/cmd_args.py | 2 | ||||
-rw-r--r-- | modules/devices.py | 57 | ||||
-rw-r--r-- | modules/launch_utils.py | 4 | ||||
-rw-r--r-- | modules/sd_models.py | 32 | ||||
-rw-r--r-- | modules/sd_models_xl.py | 2 |
5 files changed, 92 insertions, 5 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84..20bfb2c4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,3 +118,5 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
+parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
+parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..d7c905c2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -71,6 +87,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -91,12 +108,48 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward def autocast(disable=False): if disable: return contextlib.nullcontext() + if fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa5..636da679 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,8 +308,8 @@ def requirements_met(requirements_file): def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea1..31bcb913 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,6 +392,38 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16
timer.record("apply half()")
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.cmd_opts.opt_unet_fp8_storage:
+ enable_fp8 = True
+ elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+
+ if enable_fp8:
+ devices.fp8 = True
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+
+ if devices.device == devices.cpu:
+ for module in model.model.diffusion_model.modules():
+ if isinstance(module, torch.nn.Conv2d):
+ module.to(torch.float8_e4m3fn)
+ elif isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+ else:
+ model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 01123321..11259a36 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
|