diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-28 07:24:26 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-28 07:24:26 +0000 |
commit | d4d3134f6d2d232c7bcfa80900a362921e644976 (patch) | |
tree | a592066be90aecdf86a76270cca78001cce5d3eb | |
parent | 0beb131c7ffae6f756a6339206da311232a36970 (diff) | |
download | stable-diffusion-webui-gfx803-d4d3134f6d2d232c7bcfa80900a362921e644976.tar.gz stable-diffusion-webui-gfx803-d4d3134f6d2d232c7bcfa80900a362921e644976.tar.bz2 stable-diffusion-webui-gfx803-d4d3134f6d2d232c7bcfa80900a362921e644976.zip |
ManualCast for 10/16 series gpu
-rw-r--r-- | modules/devices.py | 57 | ||||
-rw-r--r-- | modules/processing.py | 2 | ||||
-rw-r--r-- | modules/sd_models.py | 21 |
3 files changed, 64 insertions, 16 deletions
diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d..c05f2b35 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None - - -def autocast(disable=False, unet=False): +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea..40598f5c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
+ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if getattr(samples_ddim, 'already_decoded', False):
diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2..31bcb913 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8:
devices.fp8 = True
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+
if devices.device == devices.cpu:
for module in model.model.diffusion_model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
- timer.record("apply fp8 unet for cpu")
else:
- if model.is_sdxl:
- cond_stage = model.conditioner
- else:
- cond_stage = model.cond_stage_model
- for module in cond_stage.modules():
- if isinstance(module, torch.nn.Linear):
- module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
- timer.record("apply fp8 unet")
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|