aboutsummaryrefslogtreecommitdiffstats
path: root/modules/models/diffusion/uni_pc/uni_pc.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/models/diffusion/uni_pc/uni_pc.py')
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index a4c4ef4e..6f8ad631 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -178,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
- model_kwargs={},
+ model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
- classifier_kwargs={},
+ classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.
@@ -275,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
+ model_kwargs = model_kwargs or []
+ classifier_kwargs = classifier_kwargs or []
+
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.