Diffstat (limited to 'modules/models')
-rw-r--r--  modules/models/diffusion/ddpm_edit.py        | 30
-rw-r--r--  modules/models/diffusion/uni_pc/__init__.py  |  2
-rw-r--r--  modules/models/diffusion/uni_pc/sampler.py   |  3
-rw-r--r--  modules/models/diffusion/uni_pc/uni_pc.py    |  5
4 files changed, 17 insertions, 23 deletions
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index f3d49c44..611c2b69 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -223,7 +223,7 @@ class DDPM(pl.LightningModule):
         for k in keys:
             for ik in ignore_keys:
                 if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
+                    print(f"Deleting key {k} from state_dict.")
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
@@ -386,7 +386,7 @@ class DDPM(pl.LightningModule):
         _, loss_dict_no_ema = self.shared_step(batch)
         with self.ema_scope():
             _, loss_dict_ema = self.shared_step(batch)
-            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+            loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
         self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
         self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
@@ -479,7 +479,7 @@ class LatentDiffusion(DDPM):
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -891,16 +891,6 @@ class LatentDiffusion(DDPM):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):
 
         if isinstance(cond, dict):
@@ -1171,8 +1161,10 @@ class LatentDiffusion(DDPM):
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates
 
     @torch.no_grad()
@@ -1219,8 +1211,10 @@ class LatentDiffusion(DDPM):
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         if return_intermediates:
             return img, intermediates
@@ -1337,7 +1331,7 @@ class LatentDiffusion(DDPM):
 
         if inpaint:
             # make a simple center square
-            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            h, w = z.shape[2], z.shape[3]
             mask = torch.ones(N, h, w).to(self.device)
             # zeros will be filled in
             mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
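A note on the `except:` → `except Exception:` change in the `@@ -479,7` hunk above: a bare `except` also catches `KeyboardInterrupt` and `SystemExit`, which derive from `BaseException` rather than `Exception`, so it can make a process impossible to interrupt cleanly. A minimal standalone sketch of the narrowed pattern (the function and constant names here are hypothetical, not part of the patch):

    DEFAULT_NUM_DOWNS = 0  # hypothetical fallback, mirrors the patch's `self.num_downs = 0`

    def resolve_num_downs(first_stage_config):
        # The config may be missing params/ddconfig/ch_mult; `except Exception`
        # still covers AttributeError/KeyError/TypeError here, but lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            return len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            return DEFAULT_NUM_DOWNS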
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
index e1265e3f..dbb35964 100644
--- a/modules/models/diffusion/uni_pc/__init__.py
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler  # noqa: F401
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index a241c8a7..0a9defa1 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -54,7 +54,8 @@ class UniPCSampler(object):
         if conditioning is not None:
             if isinstance(conditioning, dict):
                 ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
                 cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index eb5f4e76..a4c4ef4e 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 import math
 from tqdm.auto import trange
@@ -94,7 +93,7 @@ class NoiseScheduleVP:
         """
 
         if schedule not in ['discrete', 'linear', 'cosine']:
-            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+            raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
         self.schedule = schedule
         if schedule == 'discrete':
@@ -469,7 +468,7 @@ class UniPC:
             t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
             return t
         else:
-            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+            raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
 
     def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
         """
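A note on the `# noqa: F401` added to `uni_pc/__init__.py` above: flake8 and ruff flag imports that are never referenced as F401 ("imported but unused"), but in a package `__init__.py` the import often *is* the public API, re-exporting a symbol so callers can import it from the package root. The comment suppresses the warning for that one line. A minimal sketch of the pattern (package and class names are hypothetical):

    # mypkg/__init__.py
    # Re-export the sampler so `from mypkg import Sampler` works;
    # F401 is suppressed because the "unused" import is intentional.
    from .sampler import Sampler  # noqa: F401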