author     Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>    2023-11-19 07:50:06 +0000
committer  Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>    2023-11-19 07:50:06 +0000
commit     598da5cd4928618b166886d3485ce30ce3a43490 (patch)
tree       48bdd2fcf47bd88b9283d13fd53ecc39e5f5ff27
parent     b60e1088db2497e945d36c7500dcbf03afceedf2 (diff)
Use options instead of cmd_args
-rw-r--r--  modules/cmd_args.py          2
-rw-r--r--  modules/devices.py          25
-rw-r--r--  modules/initialize_util.py   1
-rw-r--r--  modules/sd_models.py        61
-rw-r--r--  modules/shared_options.py    1
-rw-r--r--  scripts/xyz_grid.py          1
6 files changed, 49 insertions, 42 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 088d5dea..a9fb9bfa 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
-parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
-parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index d7c905c2..03e7bdb7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
     if device_id is None:
         device_id = get_cuda_device_id()
     return (
-        torch.cuda.get_device_capability(device_id) == (7, 5)
+        torch.cuda.get_device_capability(device_id) == (7, 5)
         and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
     )
 
 
 def get_cuda_device_id():
     return (
-        int(shared.cmd_opts.device_id)
-        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
         else 0
     ) or torch.cuda.current_device()
@@ -116,16 +116,19 @@ patch_module_list = [
     torch.nn.LayerNorm,
 ]
 
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
 @contextlib.contextmanager
 def manual_autocast():
-    def manual_cast_forward(self, *args, **kwargs):
-        org_dtype = next(self.parameters()).dtype
-        self.to(dtype)
-        args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-        kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-        result = self.org_forward(*args, **kwargs)
-        self.to(org_dtype)
-        return result
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..1b11ead6 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,7 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
     startup_timer.record("opts onchange")
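The devices.py change above hoists `manual_cast_forward` out of `manual_autocast` so the patched `forward` is a plain module-level function. Below is a minimal, self-contained sketch of that patching pattern; `target_dtype`, the module list, and the demo at the bottom are illustrative stand-ins, not code from this commit.

```python
# Sketch of the forward-patching pattern from modules/devices.py (simplified).
# `target_dtype` stands in for the module-global `dtype` the real code uses.
import contextlib
import torch

target_dtype = torch.bfloat16  # bfloat16 so the demo also runs on CPU
patch_module_list = [torch.nn.Linear, torch.nn.Conv2d, torch.nn.LayerNorm]


def manual_cast_forward(self, *args, **kwargs):
    # Run the original forward in target_dtype, then restore the module's own
    # dtype so the stored weights are left untouched.
    org_dtype = next(self.parameters()).dtype
    self.to(target_dtype)
    args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
    kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    result = self.org_forward(*args, **kwargs)
    self.to(org_dtype)
    return result


@contextlib.contextmanager
def manual_autocast():
    # Swap each module type's forward for the casting version, restore on exit.
    for module_type in patch_module_list:
        module_type.org_forward = module_type.forward
        module_type.forward = manual_cast_forward
    try:
        yield
    finally:
        for module_type in patch_module_list:
            module_type.forward = module_type.org_forward


if __name__ == "__main__":
    layer = torch.nn.Linear(4, 4)          # weights stored in float32
    with manual_autocast():
        out = layer(torch.randn(2, 4))     # computed in target_dtype
    print(out.dtype, layer.weight.dtype)   # torch.bfloat16 torch.float32
```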
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a6c8b2fa..eb491434 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -339,10 +339,28 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if not check_fp8(model) and devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if devices.get_optimal_device_name() == "mps":
-        enable_fp8 = False
-    elif shared.cmd_opts.opt_unet_fp8_storage:
-        enable_fp8 = True
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        enable_fp8 = True
-    else:
-        enable_fp8 = False
-
-    if enable_fp8:
+    if check_fp8(model):
         devices.fp8 = True
-        if model.is_sdxl:
-            cond_stage = model.conditioner
-        else:
-            cond_stage = model.cond_stage_model
-
-        for module in cond_stage.modules():
-            if isinstance(module, torch.nn.Linear):
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
                 module.to(torch.float8_e4m3fn)
-
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-        else:
-            model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+            elif isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
         devices.fp8 = False
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:
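The rewritten block above stores Linear/Conv2d weights in `torch.float8_e4m3fn` across the whole model while keeping the first-stage VAE out of the conversion, and `reload_model_weights` now forces a reload whenever the requested FP8 state no longer matches `devices.fp8`. A rough standalone sketch of that conversion pass is below, assuming a PyTorch build (>= 2.1.0) that exposes the float8 dtype; the toy model and the `fp8_storage`/`is_sdxl`/`device_name` arguments are stand-ins for the webui's real option and model objects.

```python
# Standalone sketch of the fp8 weight-storage pass, assuming torch >= 2.1.0
# (which provides torch.float8_e4m3fn). The toy model below is only a stand-in
# for the real sd_model; the keyword arguments mimic the new option.
import torch


def check_fp8(model, fp8_storage="Disable", is_sdxl=False, device_name="cuda"):
    # Mirrors the decision in the commit: never on MPS, always when "Enable",
    # and only for SDXL checkpoints when "Enable for SDXL".
    if model is None:
        return None
    if device_name == "mps":
        return False
    if fp8_storage == "Enable":
        return True
    return is_sdxl and fp8_storage == "Enable for SDXL"


def apply_fp8_storage(model, first_stage_attr="first_stage_model"):
    # Detach the first-stage (VAE) submodule so it keeps its full precision,
    # downcast every Linear/Conv2d weight to float8_e4m3fn, then reattach it.
    first_stage = getattr(model, first_stage_attr, None)
    if first_stage is not None:
        setattr(model, first_stage_attr, None)
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            module.to(torch.float8_e4m3fn)
    if first_stage is not None:
        setattr(model, first_stage_attr, first_stage)
    return model


if __name__ == "__main__" and hasattr(torch, "float8_e4m3fn"):
    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
    if check_fp8(toy, fp8_storage="Enable"):
        apply_fp8_storage(toy)
    print(toy[0].weight.dtype)  # torch.float8_e4m3fn
```

Weights stored this way are not directly usable for matmuls; at run time the manual-cast path shown earlier upcasts them to the working dtype before the forward pass, which is how the two changes fit together.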
diff --git a/modules/shared_options.py b/modules/shared_options.py
index f1003f21..d27f35e9 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
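The new option's help text notes that FP8 weight storage requires pytorch>=2.1.0. A small, hypothetical guard (not part of this commit) is one way to confirm the running build actually exposes the dtype before offering anything other than "Disable":

```python
# Hypothetical capability probe, not part of this commit: torch.float8_e4m3fn
# only exists in PyTorch >= 2.1.0, so check for it before enabling FP8 weight.
import torch


def fp8_available() -> bool:
    return hasattr(torch, "float8_e4m3fn")


fp8_choices = ["Disable", "Enable for SDXL", "Enable"] if fp8_available() else ["Disable"]
```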
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc..b2250c04 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -270,6 +270,7 @@ axis_options = [
     AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
     AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+    AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
 ]