diff options
author | papuSpartan <30642826+papuSpartan@users.noreply.github.com> | 2023-05-12 03:40:17 +0000 |
---|---|---|
committer | papuSpartan <30642826+papuSpartan@users.noreply.github.com> | 2023-05-12 03:40:17 +0000 |
commit | 75b3692920e8dceb9031dd405b9226b55d286ce1 (patch) | |
tree | b7bb9db2aca00e54525b82ed1d902eac273766b9 /modules/models/diffusion/uni_pc/uni_pc.py | |
parent | f0efc8c211fc2d2c2f8caf6e2f92501922d18c99 (diff) | |
parent | abe32cefa39dee36d7f661d4e63c28ea8dd60c4f (diff) | |
download | stable-diffusion-webui-gfx803-75b3692920e8dceb9031dd405b9226b55d286ce1.tar.gz stable-diffusion-webui-gfx803-75b3692920e8dceb9031dd405b9226b55d286ce1.tar.bz2 stable-diffusion-webui-gfx803-75b3692920e8dceb9031dd405b9226b55d286ce1.zip |
Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into tomesd
Diffstat (limited to 'modules/models/diffusion/uni_pc/uni_pc.py')
-rw-r--r-- | modules/models/diffusion/uni_pc/uni_pc.py | 86 |
1 files changed, 46 insertions, 40 deletions
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index eb5f4e76..d257a728 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -1,7 +1,6 @@ import torch -import torch.nn.functional as F import math -from tqdm.auto import trange +import tqdm class NoiseScheduleVP: @@ -94,7 +93,7 @@ class NoiseScheduleVP: """ if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'") self.schedule = schedule if schedule == 'discrete': @@ -179,13 +178,13 @@ def model_wrapper( model, noise_schedule, model_type="noise", - model_kwargs={}, + model_kwargs=None, guidance_type="uncond", #condition=None, #unconditional_condition=None, guidance_scale=1., classifier_fn=None, - classifier_kwargs={}, + classifier_kwargs=None, ): """Create a wrapper function for the noise prediction model. @@ -276,6 +275,9 @@ def model_wrapper( A noise prediction model that accepts the noised data and the continuous time as the inputs. """ + model_kwargs = model_kwargs or {} + classifier_kwargs = classifier_kwargs or {} + def get_model_input_time(t_continuous): """ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. @@ -342,7 +344,7 @@ def model_wrapper( t_in = torch.cat([t_continuous] * 2) if isinstance(condition, dict): assert isinstance(unconditional_condition, dict) - c_in = dict() + c_in = {} for k in condition: if isinstance(condition[k], list): c_in[k] = [torch.cat([ @@ -353,7 +355,7 @@ def model_wrapper( unconditional_condition[k], condition[k]]) elif isinstance(condition, list): - c_in = list() + c_in = [] assert isinstance(unconditional_condition, list) for i in range(len(condition)): c_in.append(torch.cat([unconditional_condition[i], condition[i]])) @@ -469,7 +471,7 @@ class UniPC: t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) return t else: - raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'") def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -757,40 +759,44 @@ class UniPC: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) - if model_x is None: - model_x = self.model_fn(x, vec_t) - if self.after_update is not None: - self.after_update(x, model_x) - model_prev_list.append(model_x) - t_prev_list.append(vec_t) - for step in trange(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final: - step_order = min(order, steps + 1 - step) - else: - step_order = order - #print('this step order:', step_order) - if step == steps: - #print('do not run corrector at the last step') - use_corrector = False - else: - use_corrector = True - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) - if self.after_update is not None: - self.after_update(x, model_x) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: + with tqdm.tqdm(total=steps) as pbar: + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) - model_prev_list[-1] = model_x + if self.after_update is not None: + self.after_update(x, model_x) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + pbar.update() + + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + #print('this step order:', step_order) + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x + pbar.update() else: raise NotImplementedError() if denoise_to_zero: |