diff options
Diffstat (limited to 'modules/models/diffusion/uni_pc/uni_pc.py')
-rw-r--r-- | modules/models/diffusion/uni_pc/uni_pc.py | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index ec6b37da..31ee81a6 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -378,7 +378,8 @@ class UniPC: condition=None, unconditional_condition=None, before_sample=None, - after_sample=None + after_sample=None, + after_update=None ): """Construct a UniPC. @@ -394,6 +395,7 @@ class UniPC: self.unconditional_condition = unconditional_condition self.before_sample = before_sample self.after_sample = after_sample + self.after_update = after_update def dynamic_thresholding_fn(self, x0, t=None): """ @@ -434,15 +436,6 @@ class UniPC: noise = self.noise_prediction_fn(x, t) dims = x.dim() alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - from pprint import pp - print("X:") - pp(x) - print("sigma_t:") - pp(sigma_t) - print("noise:") - pp(noise) - print("alpha_t:") - pp(alpha_t) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) if self.thresholding: p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. @@ -524,7 +517,7 @@ class UniPC: return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') ns = self.noise_schedule assert order <= len(model_prev_list) @@ -568,7 +561,7 @@ class UniPC: A_p = C_inv_p if use_corrector: - print('using corrector') + #print('using corrector') C_inv = torch.linalg.inv(C) A_c = C_inv @@ -627,7 +620,7 @@ class UniPC: return x_t, model_t def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') ns = self.noise_schedule assert order <= len(model_prev_list) dims = x.dim() @@ -695,7 +688,7 @@ class UniPC: D1s = None if use_corrector: - print('using corrector') + #print('using corrector') # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], device=b.device) @@ -755,8 +748,9 @@ class UniPC: t_T = self.noise_schedule.T if t_start is None else t_start device = x.device if method == 'multistep': - assert steps >= order + assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) @@ -768,6 +762,8 @@ class UniPC: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) + if self.after_update is not None: + self.after_update(x, model_x) model_prev_list.append(model_x) t_prev_list.append(vec_t) for step in range(order, steps + 1): @@ -776,13 +772,15 @@ class UniPC: step_order = min(order, steps + 1 - step) else: step_order = order - print('this step order:', step_order) + #print('this step order:', step_order) if step == steps: - print('do not run corrector at the last step') + #print('do not run corrector at the last step') use_corrector = False else: use_corrector = True x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] |