aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-11-21 11:59:34 +0000
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-11-21 11:59:34 +0000
commit370a77f8e78e65a8a1339289d684cb43df142f70 (patch)
tree001c59f3d41682a36b7cc8816f5a6584713ab7c1 /modules
parentb2e039d07bed76350120ff448964c907a3b5e4a3 (diff)
downloadstable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.tar.gz
stable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.tar.bz2
stable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.zip
Option for using fp16 weight when apply lora
Diffstat (limited to 'modules')
-rw-r--r--modules/initialize_util.py1
-rw-r--r--modules/sd_models.py14
-rw-r--r--modules/shared_options.py1
3 files changed, 13 insertions, 3 deletions
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 1b11ead6..7fb1d8d5 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -178,6 +178,7 @@ def configure_opts_onchange():
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
startup_timer.record("opts onchange")
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eb491434..0a7777f1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
- if isinstance(module, torch.nn.Conv2d):
- module.to(torch.float8_e4m3fn)
- elif isinstance(module, torch.nn.Linear):
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.clone().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.clone().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
diff --git a/modules/shared_options.py b/modules/shared_options.py
index d27f35e9..eaa9f135 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -201,6 +201,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+ "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {