diff options
author | brkirch <brkirch@users.noreply.github.com> | 2023-01-25 04:51:45 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-25 06:13:02 +0000 |
commit | 84d9ce30cb427759547bc7876ed80ab91787d175 (patch) | |
tree | a87ca1a7094ca9b7af4e573a211b1dcf8146af67 /modules/sd_models.py | |
parent | 48a15821de768fea76e66f26df83df3fddf18f4b (diff) | |
download | stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.tar.gz stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.tar.bz2 stable-diffusion-webui-gfx803-84d9ce30cb427759547bc7876ed80ab91787d175.zip |
Add option for float32 sampling with float16 UNet
This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, and some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS), so in sd_models.py model.depth_model is temporarily removed before calling model.half() and restored afterward.
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 10 |
1 file changed, 10 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half:
vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
model.half()
model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer()
|