diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-01-01 11:45:12 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-01 11:45:12 +0000 |
commit | 267fd5d76b00b0c22edffa83c1a078680ba8b42f (patch) | |
tree | c4092b8ec7430f15aaac7d9f8a0fa2199de28140 /modules/processing.py | |
parent | d613cd17c72c753bd1e314dff74dc22d9a949374 (diff) | |
parent | 5381405eaa1e809e5cfb97522bd4c19d3c946079 (diff) | |
download | stable-diffusion-webui-gfx803-267fd5d76b00b0c22edffa83c1a078680ba8b42f.tar.gz stable-diffusion-webui-gfx803-267fd5d76b00b0c22edffa83c1a078680ba8b42f.tar.bz2 stable-diffusion-webui-gfx803-267fd5d76b00b0c22edffa83c1a078680ba8b42f.zip |
Merge pull request #14145 from drhead/zero-terminal-snr
Implement zero terminal SNR noise schedule option
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py index b30df60d..846e4796 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -898,6 +898,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|