diff options
-rw-r--r-- | extensions-builtin/Lora/network.py | 2 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_full.py | 4 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_glora.py | 10 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_hada.py | 12 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_ia3.py | 2 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_lokr.py | 18 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_lora.py | 6 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_norm.py | 4 | ||||
-rw-r--r-- | extensions-builtin/Lora/networks.py | 6 | ||||
-rw-r--r-- | modules/cmd_args.py | 2 | ||||
-rw-r--r-- | modules/devices.py | 57 | ||||
-rw-r--r-- | modules/launch_utils.py | 4 | ||||
-rw-r--r-- | modules/sd_models.py | 32 | ||||
-rw-r--r-- | modules/sd_models_xl.py | 2 |
14 files changed, 124 insertions, 37 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8d..a62e5eff 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e9..f221c95f 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d4870..efe5c681 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695..d95a0fd1 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249..96faeaf3 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab..fcdaeafd 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c..4cc40295 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158..d25afcbb 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4..8ea4ea60 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84..20bfb2c4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,3 +118,5 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
+parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
+parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..d7c905c2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -71,6 +87,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -91,12 +108,48 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward def autocast(disable=False): if disable: return contextlib.nullcontext() + if fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa5..636da679 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,8 +308,8 @@ def requirements_met(requirements_file): def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea1..31bcb913 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,6 +392,38 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16
timer.record("apply half()")
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.cmd_opts.opt_unet_fp8_storage:
+ enable_fp8 = True
+ elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+
+ if enable_fp8:
+ devices.fp8 = True
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+
+ if devices.device == devices.cpu:
+ for module in model.model.diffusion_model.modules():
+ if isinstance(module, torch.nn.Conv2d):
+ module.to(torch.float8_e4m3fn)
+ elif isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
+ else:
+ model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 01123321..11259a36 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
|