about | summary | refs | log | tree | commit | diff | stats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
author AUTOMATIC1111 <16777216c@gmail.com> 2023-01-25 16:12:29 +0000
committer GitHub <noreply@github.com> 2023-01-25 16:12:29 +0000
commit 1574e967297586d013e4cfbb6628eae595c9fba2 (patch)
tree 99374009f63cf73cadc713b02db5fbba0701e516 /modules/sd_models.py
parent 1982ef68900fe3c5eee704dfbda5416c1bb5470b (diff)
parent e3b53fd295aca784253dfc8668ec87b537a72f43 (diff)
download stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.tar.gz
stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.tar.bz2
stable-diffusion-webui-gfx803-1574e967297586d013e4cfbb6628eae595c9fba2.zip
Merge pull request #6510 from brkirch/unet16-upcast-precision
Add upcast options, full precision sampling from float16 UNet and upcasting attention for inference using SD 2.1 models without --no-half
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- modules/sd_models.py 10
1 file changed, 10 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index cddc2343..7072eb2e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -258,16 +258,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
model.half()
model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -382,6 +390,8 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer()