diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-03-02 03:54:11 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2024-03-02 03:55:03 +0000 |
commit | 141a17e9693065c33a2b1d30f04a0083bb687775 (patch) | |
tree | a0ef7f513a7f8d5fadcb126135b9b565947ab8af | |
parent | da67afe5f68497a04d1fd9173bbd256b73d9d251 (diff) | |
download | stable-diffusion-webui-gfx803-141a17e9693065c33a2b1d30f04a0083bb687775.tar.gz stable-diffusion-webui-gfx803-141a17e9693065c33a2b1d30f04a0083bb687775.tar.bz2 stable-diffusion-webui-gfx803-141a17e9693065c33a2b1d30f04a0083bb687775.zip |
style changes for #14979
-rw-r--r-- | modules/sd_models.py | 70 |
1 files changed, 41 insertions, 29 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index db72e120..747fc39e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -552,36 +552,48 @@ def repair_config(sd_config):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+
+def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+
def apply_alpha_schedule_override(sd_model, p=None):
- def rescale_zero_terminal_snr_abar(alphas_cumprod):
- alphas_bar_sqrt = alphas_cumprod.sqrt()
-
- # Store old values.
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
- # Shift so the last timestep is zero.
- alphas_bar_sqrt -= (alphas_bar_sqrt_T)
-
- # Scale so the first timestep is back to the old value.
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-
- # Convert alphas_bar_sqrt to betas
- alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
- alphas_bar[-1] = 4.8973451890853435e-08
- return alphas_bar
-
- if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'):
- sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
-
- if opts.use_downcasted_alpha_bar:
- if p is not None:
- p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
- sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
- if opts.sd_noise_schedule == "Zero Terminal SNR":
- if p is not None:
- p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
- sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
+ """
+ Applies an override to the alpha schedule of the model according to settings.
+ - downcasts the alpha schedule to half precision
+ - rescales the alpha schedule to have zero terminal SNR
+ """
+
+ if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
+ return
+
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ if p is not None:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
+
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ if p is not None:
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
+
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|