diff options
Diffstat (limited to 'modules/sd_samplers_timesteps.py')
-rw-r--r-- | modules/sd_samplers_timesteps.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 16572c7e..6aed2974 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -51,10 +51,9 @@ class CFGDenoiserTimesteps(CFGDenoiser): self.alphas = shared.sd_model.alphas_cumprod
def get_pred_x0(self, x_in, x_out, sigma):
- ts = int(sigma.item())
+ ts = sigma.to(dtype=int)
- s_in = x_in.new_ones([x_in.shape[0]])
- a_t = self.alphas[ts].item() * s_in
+ a_t = self.alphas[ts][:, None, None, None]
sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|