aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py43
-rw-r--r--modules/api/models.py31
-rw-r--r--modules/codeformer_model.py2
-rw-r--r--modules/extensions.py6
-rw-r--r--modules/generation_parameters_copypaste.py13
-rw-r--r--modules/images.py2
-rw-r--r--modules/modelloader.py8
-rw-r--r--modules/models/diffusion/uni_pc/__init__.py1
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py100
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py856
-rw-r--r--modules/processing.py7
-rw-r--r--modules/script_callbacks.py8
-rw-r--r--modules/scripts.py32
-rw-r--r--modules/sd_hijack.py12
-rw-r--r--modules/sd_hijack_optimizations.py70
-rw-r--r--modules/sd_samplers.py2
-rw-r--r--modules/sd_samplers_compvis.py59
-rw-r--r--modules/sd_samplers_kdiffusion.py4
-rw-r--r--modules/shared.py20
-rw-r--r--modules/ui.py11
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_extra_networks.py30
-rw-r--r--modules/ui_extra_networks_checkpoints.py14
-rw-r--r--modules/ui_extra_networks_hypernets.py12
-rw-r--r--modules/ui_extra_networks_textual_inversion.py13
25 files changed, 1292 insertions, 73 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 8a17017b..abdbb6a7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -150,6 +150,7 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
@@ -170,6 +171,12 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
+
+ def get_scripts_list(self):
+ t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+ i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+ return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
def get_script(self, script_name, script_runner):
if script_name is None or script_name == "":
@@ -215,12 +222,11 @@ class Api:
ui.create_ui()
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
- populate = txt2imgreq.copy(update={ # Override __init__ params
+ populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True
- }
- )
+ "do_not_save_samples": not txt2imgreq.save_images,
+ "do_not_save_grid": not txt2imgreq.save_images,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
@@ -231,22 +237,25 @@ class Api:
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
+
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
if selectable_scripts != None:
p.script_args = script_args
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -267,11 +276,10 @@ class Api:
populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True,
- "mask": mask
- }
- )
+ "do_not_save_samples": not img2imgreq.save_images,
+ "do_not_save_grid": not img2imgreq.save_images,
+ "mask": mask,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
@@ -283,23 +291,26 @@ class Api:
script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
+
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
if selectable_scripts != None:
p.script_args = script_args
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
diff --git a/modules/api/models.py b/modules/api/models.py
index e273469d..6c1c5546 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
"outpath_samples",
"outpath_grids",
"sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
+ # "do_not_save_samples",
+ # "do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
@@ -100,13 +100,32 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}]
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}]
+
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "init_images", "type": list, "default": None},
+ {"key": "denoising_strength", "type": float, "default": 0.75},
+ {"key": "mask", "type": str, "default": None},
+ {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
class TextToImageResponse(BaseModel):
@@ -267,3 +286,7 @@ class EmbeddingsResponse(BaseModel):
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+ txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 01fb7bd8..8d84bbc9 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -55,7 +55,7 @@ def setup_model(dirname):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
diff --git a/modules/extensions.py b/modules/extensions.py
index 3eef9eaf..ed4b58fe 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -66,7 +66,7 @@ class Extension:
def check_updates(self):
repo = git.Repo(self.path)
- for fetch in repo.remote().fetch("--dry-run"):
+ for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
@@ -79,8 +79,8 @@ class Extension:
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
- repo.git.fetch('--all')
- repo.git.reset('--hard', 'origin')
+ repo.git.fetch(all=True)
+ repo.git.reset('origin', hard=True)
def list_extensions():
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 89dc23bf..7c0b5b4e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -23,13 +23,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
+ self.paste_field_names = paste_field_names
def reset():
@@ -134,7 +135,7 @@ def connect_paste_params_buttons():
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
if binding.source_tabname is not None and fields is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
binding.paste_button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
@@ -288,6 +289,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
settings_map = {}
+
+
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
@@ -296,7 +299,11 @@ infotext_to_setting_name_mapping = [
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
- ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
]
diff --git a/modules/images.py b/modules/images.py
index 5b80c23e..7df2b08c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif image_to_save.mode == 'I;16':
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
if opts.enable_pnginfo and info is not None:
exif_bytes = piexif.dump({
diff --git a/modules/modelloader.py b/modules/modelloader.py
index fc3f6249..e351d808 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -6,7 +6,7 @@ from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
-from modules.upscaler import Upscaler
+from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -169,4 +169,8 @@ def load_upscalers():
scaler = cls(commandline_options.get(cmd_name, None))
datas += scaler.scalers
- shared.sd_upscalers = datas
+ shared.sd_upscalers = sorted(
+ datas,
+ # Special case for UpscalerNone keeps it at the beginning of the list.
+ key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
+ )
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
new file mode 100644
index 00000000..e1265e3f
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -0,0 +1 @@
+from .sampler import UniPCSampler
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
new file mode 100644
index 00000000..bf346ff4
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -0,0 +1,100 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+from modules import shared, devices
+
+
+class UniPCSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.before_sample = None
+ self.after_sample = None
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+ attr = attr.to(devices.device)
+ setattr(self, name, attr)
+
+ def set_hooks(self, before_sample, after_sample, after_update):
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for UniPC sampling is {size}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ # SD 1.X is "noise", SD 2.X is "v"
+ model_type = "v" if self.model.parameterization == "v" else "noise"
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=model_type,
+ guidance_type="classifier-free",
+ #condition=conditioning,
+ #unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
+ x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x.to(device), None
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
new file mode 100644
index 00000000..df63d1bc
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -0,0 +1,856 @@
+import torch
+import torch.nn.functional as F
+import math
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0**2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ #condition=None,
+ #unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``