diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-25 16:12:29 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-25 16:12:29 +0000 |
commit | 1574e967297586d013e4cfbb6628eae595c9fba2 (patch) | |
tree | 99374009f63cf73cadc713b02db5fbba0701e516 /modules/sd_models.py | |
parent | 1982ef68900fe3c5eee704dfbda5416c1bb5470b (diff) | |
parent | e3b53fd295aca784253dfc8668ec87b537a72f43 (diff) | |
download | stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.tar.gz stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.tar.bz2 stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.zip |
Merge pull request #6510 from brkirch/unet16-upcast-precision
Add upcast options, full precision sampling from float16 UNet and upcasting attention for inference using SD 2.1 models without --no-half
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 10 |
1 file changed, 10 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index cddc2343..7072eb2e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -258,16 +258,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half:
vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
model.half()
model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -382,6 +390,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer()
|