author     AUTOMATIC1111 <16777216c@gmail.com>    2023-05-14 05:06:45 +0000
committer  GitHub <noreply@github.com>            2023-05-14 05:06:45 +0000
commit     cb9a3a7809fc7247360705bca4175ccb59b9288c (patch)
tree       097747e69ab86535ae16ac5f67e5e47686ac4110 /modules/sd_samplers_kdiffusion.py
parent     4051d51caf70b8c48ddba2df980e7523c2bf31cd (diff)
parent     8abfc95013d247c8a863d048574bc1f9d1eb0443 (diff)
Merge pull request #10357 from catboxanon/sag
Add/modify CFG callbacks for Self-Attention Guidance extension
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--    modules/sd_samplers_kdiffusion.py    8
1 file changed, 7 insertions(+), 1 deletion(-)
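For context: the patch below relies on two callback parameter classes defined in modules/script_callbacks.py, which are not shown in this diff. The sketch that follows is a minimal reconstruction inferred only from how the diff uses them (the inner_model argument added to CFGDenoisedParams, and the x / output_altered fields read from AfterCFGCallbackParams); any other field names are assumptions, not the actual source.

# Minimal sketch (assumption) of the parameter containers this patch touches.
class CFGDenoisedParams:
    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
        self.x = x                                  # denoised latents for the cond/uncond batch
        self.sampling_step = sampling_step          # current sampling step
        self.total_sampling_steps = total_sampling_steps
        self.inner_model = inner_model              # new in this patch: the wrapped k-diffusion denoiser

class AfterCFGCallbackParams:
    def __init__(self, x, sampling_step, total_sampling_steps):
        self.x = x                                  # combined denoised latent after CFG is applied
        self.sampling_step = sampling_step
        self.total_sampling_steps = total_sampling_steps
        self.output_altered = False                 # a callback sets this to signal it replaced x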
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9e41818..55f0d3a3 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -8,6 +8,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -160,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
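A hedged illustration of why inner_model is now passed along: an extension's cfg_denoised callback (such as the Self-Attention Guidance extension this PR targets) can reach back into the wrapped denoiser. The registration helper on_cfg_denoised is the existing script_callbacks entry point for this hook; the callback body here is a placeholder, not the SAG extension's actual logic.

from modules import script_callbacks

def my_cfg_denoised_callback(params):
    # params.inner_model is the wrapped k-diffusion denoiser this patch now
    # exposes; an extension could use it for additional denoising passes
    # (e.g. on a perturbed copy of params.x).
    model = params.inner_model
    latents = params.x
    # ... extension-specific work would go here ...

script_callbacks.on_cfg_denoised(my_cfg_denoised_callback)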
@@ -180,6 +181,11 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
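Finally, a sketch of how an extension might consume the new after-CFG hook. It assumes script_callbacks exposes an on_cfg_after_cfg registration function to pair with the cfg_after_cfg_callback imported above; that function is not shown in this diff, so treat the registration call as an assumption.

from modules import script_callbacks

def my_after_cfg_callback(params):
    # params.x is the denoised latent for this step (after CFG and any
    # inpainting mask blending). Replacing it and flagging output_altered
    # feeds the modified latent back into the sampler, per the hunk above.
    params.x = params.x  # placeholder; a real extension would modify the latent here
    params.output_altered = True

script_callbacks.on_cfg_after_cfg(my_after_cfg_callback)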