diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-12-24 15:38:16 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-12-24 15:38:16 +0000 |
commit | 0b8acce6a9a1418fa88a506450cd1b92e2d48986 (patch) | |
tree | 4c6c06b8a67cc100246054a0f9a794031c0cde17 /modules/sd_samplers.py | |
parent | 03d7b394539558f6f560155d87a4fc66eb675e30 (diff) | |
download | stable-diffusion-webui-gfx803-0b8acce6a9a1418fa88a506450cd1b92e2d48986.tar.gz stable-diffusion-webui-gfx803-0b8acce6a9a1418fa88a506450cd1b92e2d48986.tar.bz2 stable-diffusion-webui-gfx803-0b8acce6a9a1418fa88a506450cd1b92e2d48986.zip |
separate part of denoiser code into a function to make it easier for extensions to override it
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d26e48dc..8efe74df 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -288,6 +288,16 @@ class CFGDenoiser(torch.nn.Module): self.init_latent = None
self.step = 0
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -329,12 +339,7 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
|