diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-21 11:59:34 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-21 11:59:34 +0000 |
commit | 370a77f8e78e65a8a1339289d684cb43df142f70 (patch) | |
tree | 001c59f3d41682a36b7cc8816f5a6584713ab7c1 /modules/sd_models.py | |
parent | b2e039d07bed76350120ff448964c907a3b5e4a3 (diff) | |
download | stable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.tar.gz stable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.tar.bz2 stable-diffusion-webui-gfx803-370a77f8e78e65a8a1339289d684cb43df142f70.zip |
Option for using fp16 weight when apply lora
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index eb491434..0a7777f1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
- if isinstance(module, torch.nn.Conv2d):
- module.to(torch.float8_e4m3fn)
- elif isinstance(module, torch.nn.Linear):
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.clone().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.clone().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
|