From 125319988984987801dc4b4ab1e5ed36e9b211c5 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 03:30:20 -0800 Subject: Working UniPC (for batch size 1) --- modules/models/diffusion/uni_pc/__init__.py | 1 + modules/models/diffusion/uni_pc/sampler.py | 85 +++ modules/models/diffusion/uni_pc/uni_pc.py | 858 ++++++++++++++++++++++++++++ modules/processing.py | 2 +- modules/sd_samplers_compvis.py | 35 +- 5 files changed, 975 insertions(+), 6 deletions(-) create mode 100644 modules/models/diffusion/uni_pc/__init__.py create mode 100644 modules/models/diffusion/uni_pc/sampler.py create mode 100644 modules/models/diffusion/uni_pc/uni_pc.py (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py new file mode 100644 index 00000000..e1265e3f --- /dev/null +++ b/modules/models/diffusion/uni_pc/__init__.py @@ -0,0 +1 @@ +from .sampler import UniPCSampler diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py new file mode 100644 index 00000000..7cccd8a2 --- /dev/null +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -0,0 +1,85 @@ +"""SAMPLING ONLY.""" + +import torch + +from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC + +class UniPCSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.before_sample = None + self.after_sample = None + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def set_hooks(self, before, after): + self.before_sample = before + self.after_sample = after + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
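+ # This signature mirrors ldm's DDIMSampler.sample so the class can act as a drop-in
+ # replacement; several DDIM-specific arguments (eta, temperature, noise_dropout, ...)
+ # are accepted for compatibility but are not used by UniPC.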
+ **kwargs
+ ):
+ if conditioning is not None:
+     if isinstance(conditioning, dict):
+         cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+         if cbs != batch_size:
+             print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+     else:
+         if conditioning.shape[0] != batch_size:
+             print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ device = self.model.betas.device
+ if x_T is None:
+     img = torch.randn(size, device=device)
+ else:
+     img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+     lambda x, t, c: self.model.apply_model(x, t, c),
+     ns,
+     model_type="noise",
+     guidance_type="classifier-free",
+     #condition=conditioning,
+     #unconditional_condition=unconditional_conditioning,
+     guidance_scale=unconditional_guidance_scale,
+ )
+
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
+ x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
+
+ return x.to(device), None
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
new file mode 100644
index 00000000..ec6b37da
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -0,0 +1,858 @@
+import torch
+import torch.nn.functional as F
+import math
+
+
+class NoiseScheduleVP:
+ def __init__(
+     self,
+     schedule='discrete',
+     betas=None,
+     alphas_cumprod=None,
+     continuous_beta_0=0.1,
+     continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
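+
+ (Illustrative example) With N = 1000 discrete steps, index n maps to t_n = (n + 1) / 1000,
+ so n = 0 gives t = 1e-3 and n = 999 gives t = 1.0; `log_alpha_array` stores
+ 0.5 * log(alphas_cumprod[n]) at each t_n, and `marginal_log_mean_coeff` interpolates
+ linearly between these keypoints.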
+ + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. 
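+ For the 'discrete' schedule this is a piecewise linear interpolation over the
+ precomputed (t_array, log_alpha_array) keypoints; see `interpolate_fn` below.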
+ """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + #condition=None, + #unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). 
+ Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
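+ # e.g. with total_N = 1000: t_continuous = 1.0 maps to t_input = 999.0,
+ # and t_continuous = 1 / 1000 maps to t_input = 0.0 -- the model's discrete labels 0..999.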
+ else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, None, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input, condition): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous, condition, unconditional_condition): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input, condition) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([ + unconditional_condition[k][i], + condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([ + unconditional_condition[k], + condition[k]]) + elif isinstance(condition, list): + c_in = list() + assert isinstance(unconditional_condition, list) + for i in range(len(condition)): + c_in.append(torch.cat([unconditional_condition[i], condition[i]])) + else: + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + predict_x0=True, + thresholding=False, + max_val=1., + variant='bh1', + condition=None, + unconditional_condition=None, + before_sample=None, + after_sample=None + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. 
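+
+ In this integration the sampler is constructed with predict_x0=True (data prediction),
+ and the optional `before_sample`/`after_sample` hooks let the caller adjust inputs and
+ outputs around every denoiser call.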
+ """ + self.model_fn_ = model_fn + self.noise_schedule = noise_schedule + self.variant = variant + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + self.condition = condition + self.unconditional_condition = unconditional_condition + self.before_sample = before_sample + self.after_sample = after_sample + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model(self, x, t): + cond = self.condition + uncond = self.unconditional_condition + if self.before_sample is not None: + x, t, cond, uncond = self.before_sample(x, t, cond, uncond) + res = self.model_fn_(x, t, cond, uncond) + if self.after_sample is not None: + x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res) + + if isinstance(res, tuple): + # (None, pred_x0) + res = res[1] + + return res + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + from pprint import pp + print("X:") + pp(x) + print("sigma_t:") + pp(sigma_t) + print("noise:") + pp(noise) + print("alpha_t:") + pp(alpha_t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. 
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) 
+ rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + dims = x.dim() + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) 
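+ # As in the vary_coeff variant, rks are the relative lambda offsets of the cached outputs
+ # (plus 1. for the current step) and D1s their scaled differences; the B(h) variant
+ # instead solves the small linear system R @ rhos = b below, with R a Vandermonde matrix
+ # in rks and b built from the phi-functions.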
rks = torch.tensor(rks, device=x.device)
+
+ R = []
+ b = []
+
+ hh = -h[0] if self.predict_x0 else h[0]
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.variant == 'bh1':
+     B_h = hh
+ elif self.variant == 'bh2':
+     B_h = torch.expm1(hh)
+ else:
+     raise NotImplementedError()
+
+ for i in range(1, order + 1):
+     R.append(torch.pow(rks, i - 1))
+     b.append(h_phi_k * factorial_i / B_h)
+     factorial_i *= (i + 1)
+     h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=x.device)
+
+ # now predictor
+ use_predictor = len(D1s) > 0 and x_t is None
+ if len(D1s) > 0:
+     D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
+     if x_t is None:
+         # for order 2, we use a simplified version
+         if order == 2:
+             rhos_p = torch.tensor([0.5], device=b.device)
+         else:
+             rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+     D1s = None
+
+ if use_corrector:
+     print('using corrector')
+     # for order 1, we use a simplified version
+     if order == 1:
+         rhos_c = torch.tensor([0.5], device=b.device)
+     else:
+         rhos_c = torch.linalg.solve(R, b)
+
+ model_t = None
+ if self.predict_x0:
+     x_t_ = (
+         expand_dims(sigma_t / sigma_prev_0, dims) * x
+         - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
+     )
+
+     if x_t is None:
+         if use_predictor:
+             pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+         else:
+             pred_res = 0
+         x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
+
+     if use_corrector:
+         model_t = self.model_fn(x_t, t)
+         if D1s is not None:
+             corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+         else:
+             corr_res = 0
+         D1_t = (model_t - model_prev_0)
+         x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ else:
+     x_t_ = (
+         expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+         - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
+     )
+     if x_t is None:
+         if use_predictor:
+             pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+         else:
+             pred_res = 0
+         x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
+
+     if use_corrector:
+         model_t = self.model_fn(x_t, t)
+         if D1s is not None:
+             corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+         else:
+             corr_res = 0
+         D1_t = (model_t - model_prev_0)
+         x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ return x_t, model_t
+
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+     method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+     atol=0.0078, rtol=0.05, corrector=False,
+ ):
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'multistep':
+     assert steps >= order
+     timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+     assert timesteps.shape[0] - 1 == steps
+     with torch.no_grad():
+         vec_t = timesteps[0].expand((x.shape[0]))
+         model_prev_list = [self.model_fn(x, vec_t)]
+         t_prev_list = [vec_t]
+         # Init the first `order` values by lower order multistep DPM-Solver.
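+         # A k-th order multistep update needs k previous model outputs, so the first
+         # order - 1 steps below run at increasing orders 1, 2, ... with the corrector
+         # enabled, before the main loop proceeds at the full requested order.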
+ for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + print('this step order:', step_order) + if step == steps: + print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x + else: + raise NotImplementedError() + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
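+
+ For example, expand_dims(v, 4) turns a tensor of shape [N] into one of shape
+ [N, 1, 1, 1], which broadcasts against image batches of shape [N, C, H, W].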
+ """ + return v[(...,) + (None,)*(dims - 1)] diff --git a/modules/processing.py b/modules/processing.py index e1b53ac0..11e726df 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index d03131cd..86fa1c5b 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -7,19 +7,27 @@ import torch from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +import modules.models.diffusion.uni_pc samplers_data_compvis = [ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), ] class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) + self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim + self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) + self.orig_p_sample_ddim = None + if self.is_plms: + self.orig_p_sample_ddim = self.sampler.p_sample_plms + elif self.is_ddim: + self.orig_p_sample_ddim = self.sampler.p_sample_ddim self.mask = None self.nmask = None self.init_latent = None @@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler: return self.last_latent def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) + + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + + x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) + + return res + + def before_sample(self, x, ts, cond, unconditional_conditioning): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler: if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec + x = img_orig * self.mask + self.nmask * x # Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Note that they need to be lists because it just concatenates them later. 
@@ -84,7 +101,13 @@ class VanillaStableDiffusionSampler:
 cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
 unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ return x, ts, cond, unconditional_conditioning
+
+ def after_sample(self, x, ts, cond, uncond, res):
+     if self.is_unipc:
+         # unipc model_fn returns (pred_x0)
+         # p_sample_ddim returns (x_prev, pred_x0)
+         res = (None, res[0])
 if self.mask is not None:
     self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
@@ -97,7 +120,7 @@ class VanillaStableDiffusionSampler:
 state.sampling_step = self.step
 shared.total_tqdm.update()
- return res
+ return x, ts, cond, uncond, res
 def initialize(self, p):
 self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
@@ -107,12 +130,14 @@ class VanillaStableDiffusionSampler:
 for fieldname in ['p_sample_ddim', 'p_sample_plms']:
     if hasattr(self.sampler, fieldname):
         setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+ if self.is_unipc:
+     self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
 self.mask = p.mask if hasattr(p, 'mask') else None
 self.nmask = p.nmask if hasattr(p, 'nmask') else None
 def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
 valid_step = 999 / (1000 // num_steps)
 if valid_step == math.floor(valid_step):
 return int(valid_step) + 1
-- cgit v1.2.3

From 21880eb9e57b884635a07d2360831b4186afddf4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 04:47:08 -0800 Subject: Fix logspam and live previews --- modules/models/diffusion/uni_pc/sampler.py | 20 ++++++++++++++----- modules/models/diffusion/uni_pc/uni_pc.py | 32 ++++++++++++++---------------- modules/sd_samplers_compvis.py | 20 ++++++++++--------- 3 files changed, 41 insertions(+), 31 deletions(-) (limited to 'modules')

diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 7cccd8a2..219e9862 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -19,9 +19,10 @@ class UniPCSampler(object):
 attr = attr.to(torch.device("cuda"))
 setattr(self, name, attr)
- def set_hooks(self, before, after):
-     self.before_sample = before
-     self.after_sample = after
+ def set_hooks(self, before_sample, after_sample, after_update):
+     self.before_sample = before_sample
+     self.after_sample = after_sample
+     self.after_update = after_update
 @torch.no_grad()
 def sample(self,
@@ -50,9 +51,17 @@ class UniPCSampler(object):
 ):
 if conditioning is not None:
     if isinstance(conditioning, dict):
-         cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+         ctmp = conditioning[list(conditioning.keys())[0]]
+         while isinstance(ctmp, list): ctmp = ctmp[0]
+         cbs = ctmp.shape[0]
         if cbs != batch_size:
             print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+     elif isinstance(conditioning, list):
+         for ctmp in conditioning:
+             if ctmp.shape[0] != batch_size:
+                 print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
     else:
         if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") @@ -60,6 +69,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) + print(f'Data shape for UniPC sampling is {size}, eta {eta}') device = self.model.betas.device if x_T is None: @@ -79,7 +89,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index ec6b37da..31ee81a6 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -378,7 +378,8 @@ class UniPC: condition=None, unconditional_condition=None, before_sample=None, - after_sample=None + after_sample=None, + after_update=None ): """Construct a UniPC. @@ -394,6 +395,7 @@ class UniPC: self.unconditional_condition = unconditional_condition self.before_sample = before_sample self.after_sample = after_sample + self.after_update = after_update def dynamic_thresholding_fn(self, x0, t=None): """ @@ -434,15 +436,6 @@ class UniPC: noise = self.noise_prediction_fn(x, t) dims = x.dim() alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - from pprint import pp - print("X:") - pp(x) - print("sigma_t:") - pp(sigma_t) - print("noise:") - pp(noise) - print("alpha_t:") - pp(alpha_t) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) if self.thresholding: p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
@@ -524,7 +517,7 @@ class UniPC: return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') ns = self.noise_schedule assert order <= len(model_prev_list) @@ -568,7 +561,7 @@ class UniPC: A_p = C_inv_p if use_corrector: - print('using corrector') + #print('using corrector') C_inv = torch.linalg.inv(C) A_c = C_inv @@ -627,7 +620,7 @@ class UniPC: return x_t, model_t def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') ns = self.noise_schedule assert order <= len(model_prev_list) dims = x.dim() @@ -695,7 +688,7 @@ class UniPC: D1s = None if use_corrector: - print('using corrector') + #print('using corrector') # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], device=b.device) @@ -755,8 +748,9 @@ class UniPC: t_T = self.noise_schedule.T if t_start is None else t_start device = x.device if method == 'multistep': - assert steps >= order + assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) @@ -768,6 +762,8 @@ class UniPC: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) + if self.after_update is not None: + self.after_update(x, model_x) model_prev_list.append(model_x) t_prev_list.append(vec_t) for step in range(order, steps + 1): @@ -776,13 +772,15 @@ class UniPC: step_order = min(order, steps + 1 - step) else: step_order = order - print('this step order:', step_order) + #print('this step order:', step_order) if step == steps: - print('do not run corrector at the last step') + #print('do not run corrector at the last step') use_corrector = False else: use_corrector = True x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 86fa1c5b..946079ae 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler: return x, ts, cond, unconditional_conditioning - def after_sample(self, x, ts, cond, uncond, res): - if self.is_unipc: - # unipc model_fn returns (pred_x0) - # p_sample_ddim returns (x_prev, pred_x0) - res = (None, res[0]) - + def update_step(self, last_latent): if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] + self.last_latent = self.init_latent * self.mask + self.nmask * last_latent else: - self.last_latent = res[1] + self.last_latent = last_latent 
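# store_latent publishes the latent to shared state so the UI can render live previews;
# routing UniPC through update_step via the new after_update hook is what fixes previews
# for this sampler.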
sd_samplers_common.store_latent(self.last_latent) @@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step shared.total_tqdm.update() + def after_sample(self, x, ts, cond, uncond, res): + if not self.is_unipc: + self.update_step(res[1]) + return x, ts, cond, uncond, res + def unipc_after_update(self, x, model_x): + self.update_step(x) + def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim if self.eta != 0.0: @@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None -- cgit v1.2.3 From c88dcc20d495dab4be2692bdff30277112dbe416 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:00:09 -0800 Subject: UniPC does not support img2img (for now) --- modules/processing.py | 2 +- modules/sd_samplers.py | 2 +- modules/sd_samplers_compvis.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 11e726df..b7cf5357 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 28c2136f..ff361f22 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,7 +32,7 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 946079ae..ad39ab2b 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -139,7 +139,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 -- cgit v1.2.3 From 79ffb9453f8eddbdd4e316b9d9c75812b0eea4e1 Mon Sep 17 00:00:00 2001 From: space-nuko 
<24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:27:05 -0800 Subject: Add UniPC sampler settings --- modules/models/diffusion/uni_pc/sampler.py | 5 +++-- modules/models/diffusion/uni_pc/uni_pc.py | 2 +- modules/shared.py | 5 +++++ 3 files changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 219e9862..e66a21e3 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -3,6 +3,7 @@ import torch from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC +from modules import shared class UniPCSampler(object): def __init__(self, model, **kwargs): @@ -89,7 +90,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) - x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 31ee81a6..df63d1bc 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -750,7 +750,7 @@ class UniPC: if method == 'multistep': assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) diff --git a/modules/shared.py b/modules/shared.py index 79fbf724..34242073 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -480,6 +480,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), + 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { -- cgit v1.2.3 From 
06cb0dc92095647e4856be10b4d7dc12f5e11fa1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:36:41 -0800 Subject: Fix UniPC order --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 34242073..670d4954 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -482,7 +482,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), - 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) -- cgit v1.2.3 From fb274229b2c5c1a89dac0b3da28c08c92d71fd95 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 14:30:35 -0800 Subject: bug fix --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/processing.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index e66a21e3..0bef6eed 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -70,7 +70,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for UniPC sampling is {size}, eta {eta}') + print(f'Data shape for UniPC sampling is {size}') device = self.model.betas.device if x_T is None: diff --git a/modules/processing.py b/modules/processing.py index b7cf5357..0ca15491 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = self.sampler_name + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] -- cgit v1.2.3 From 716a69237cefb385f71105dbbf50e92d664e0f42 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 11 Feb 2023 06:18:34 -0800 Subject: support SD2.X models --- modules/models/diffusion/uni_pc/sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 0bef6eed..708a9b2b 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -80,10 +80,13 @@ class UniPCSampler(object): ns = NoiseScheduleVP('discrete', 
alphas_cumprod=self.alphas_cumprod) + # SD 1.X is "noise", SD 2.X is "v" + model_type = "v" if self.model.parameterization == "v" else "noise" + model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), ns, - model_type="noise", + model_type=model_type, guidance_type="classifier-free", #condition=conditioning, #unconditional_condition=unconditional_conditioning, -- cgit v1.2.3 From a320d157ec0221fa4e9c756327e31d881b9921ae Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 13 Feb 2023 20:26:47 -0500 Subject: allow hiding of ui tabs --- modules/shared.py | 1 + modules/ui.py | 3 +++ 2 files changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 79fbf724..ded28925 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -455,6 +455,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "hidden_tabs": OptionInfo("", "Hidden UI tabs"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), diff --git a/modules/ui.py b/modules/ui.py index f5df1ffe..c99e55ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1568,7 +1568,10 @@ def create_ui(): parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: + hidden_tabs = [x.lower().strip() for x in shared.opts.hidden_tabs.split(",")] for interface, label, ifid in interfaces: + if label.lower() in hidden_tabs: + continue with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() -- cgit v1.2.3 From 7df7e4d22796fda11629463f2fcbe859b98b1d19 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 14 Feb 2023 03:55:42 -0800 Subject: Allow extensions to declare paste fields for "Send to X" buttons --- modules/generation_parameters_copypaste.py | 5 +++-- modules/scripts.py | 9 +++++++++ modules/ui_common.py | 9 ++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fc9e17aa..93d955db 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -23,13 +23,14 @@ registered_param_bindings = [] class ParamBinding: - def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None): + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]): self.paste_button = paste_button self.tabname = tabname self.source_text_component = source_text_component self.source_image_component = source_image_component self.source_tabname = source_tabname self.override_settings_component = 
override_settings_component + self.paste_field_names = paste_field_names def reset(): @@ -133,7 +134,7 @@ def connect_paste_params_buttons(): connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname) if binding.source_tabname is not None and fields is not None: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names binding.paste_button.click( fn=lambda *x: x, inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], diff --git a/modules/scripts.py b/modules/scripts.py index 24056a12..ac0785ce 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -33,6 +33,11 @@ class Script: parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example """ + paste_field_names = None + """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the + various "Send to " buttons when clicked + """ + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -256,6 +261,7 @@ class ScriptRunner: self.alwayson_scripts = [] self.titles = [] self.infotext_fields = [] + self.paste_field_names = [] def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -304,6 +310,9 @@ class ScriptRunner: if script.infotext_fields is not None: self.infotext_fields += script.infotext_fields + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names + inputs += controls inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) diff --git a/modules/ui_common.py b/modules/ui_common.py index fd047f31..a12433d2 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -198,9 +198,16 @@ Requested path was: {f} html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}') + paste_field_names = [] + if tabname == "txt2img": + paste_field_names = modules.scripts.scripts_txt2img.paste_field_names + elif tabname == "img2img": + paste_field_names = modules.scripts.scripts_img2img.paste_field_names + for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_field_names=paste_field_names )) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log -- cgit v1.2.3 From 83829471decbde64d335eb510d4a5670baf68773 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 19 Feb 2023 09:21:44 -0500 Subject: make ui as multiselect instead of string list --- modules/shared.py | 3 ++- modules/ui.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 1a1abeb2..a7c5f58e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -305,6 +305,7 @@ def list_samplers(): hide_dirs = {"visible": not 
cmd_opts.hide_ui_dir_config} +tab_names = [] options_templates = {} @@ -460,7 +461,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), - "hidden_tabs": OptionInfo("", "Hidden UI tabs"), + "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), diff --git a/modules/ui.py b/modules/ui.py index a4ecd41b..5ac249b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,6 +1563,10 @@ def create_ui(): extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] + shared.tab_names = [] + for _interface, label, _ifid in interfaces: + shared.tab_names.append(label) + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings", variant="compact"): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): @@ -1572,9 +1576,8 @@ def create_ui(): parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: - hidden_tabs = [x.lower().strip() for x in shared.opts.hidden_tabs.split(",")] for interface, label, ifid in interfaces: - if label.lower() in hidden_tabs: + if label in shared.opts.hidden_tabs: continue with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() -- cgit v1.2.3 From b0f2653541219486c8b8cbdbf8ce80caf3f62ef8 Mon Sep 17 00:00:00 2001 From: xSinStarx <98231899+xSinStarx@users.noreply.github.com> Date: Mon, 20 Feb 2023 12:39:38 -0800 Subject: Fixes img2img Negative Token Counter The img2img negative token counter is counting the txt2img negative prompt. 
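The one-line fix below simply rebinds the img2img counter button to the textbox its own tab owns. As a self-contained illustration of this class of wiring bug (hypothetical names; assumes only the gradio package), each tab's click handler must reference that tab's components:

import gradio as gr

def count_tokens(text):
    return str(len(text.split()))  # stand-in for the real CLIP token count

with gr.Blocks() as demo:
    for tab in ["txt2img", "img2img"]:
        with gr.Tab(tab):
            prompt = gr.Textbox(label=f"{tab} negative prompt")
            counter = gr.Textbox(label="tokens", interactive=False)
            # passing a lookalike textbox from another tab here reproduces the bug
            gr.Button("Count").click(fn=count_tokens, inputs=[prompt], outputs=[counter])

demo.launch()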
--- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0516c643..a37b1739 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -939,7 +939,7 @@ def create_ui(): ) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) -- cgit v1.2.3 From 32a4c8d961df3da4534c98fd0573d854cff1bb91 Mon Sep 17 00:00:00 2001 From: Kilvoctu Date: Mon, 20 Feb 2023 15:14:06 -0600 Subject: use emojis for extra network buttons MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🔄 for refresh ❌ for close --- modules/ui_extra_networks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 71f1d81f..8786fde6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,6 +13,10 @@ from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] allowed_dirs = set() +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +refresh_symbol = '\U0001f504' # 🔄 +close_symbol = '\U0000274C' # ❌ def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -182,8 +186,8 @@ def create_ui(container, button, tabname): ui.pages.append(page_elem) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") - button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + button_refresh = gr.Button(refresh_symbol, elem_id=tabname+"_extra_refresh") + button_close = gr.Button(close_symbol, elem_id=tabname+"_extra_close") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) -- cgit v1.2.3 From a2d635ad135241a0a40f67f7e1638c9c8a4ded04 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 22 Feb 2023 01:52:53 -0800 Subject: Add before_process_batch script callback --- modules/processing.py | 3 +++ modules/scripts.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2009d3bf..187e98fd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + if p.scripts is not None: + p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) + if len(prompts) == 0: break diff --git a/modules/scripts.py b/modules/scripts.py index 24056a12..e6a505b3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -80,6 +80,20 @@ 
class Script: pass + def before_process_batch(self, p, *args, **kwargs): + """ + Called before extra networks are parsed from the prompt, so you can add + new extra network keywords to the prompt with this callback. + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + """ + + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -388,6 +402,15 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def before_process_batch(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process_batch(p, *script_args, **kwargs) + except Exception: + print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: -- cgit v1.2.3 From b90cad7f3136bbe04efeee2a00e95d0cc6ce1a4a Mon Sep 17 00:00:00 2001 From: "fkunn1326@users.noreply.github.com" Date: Thu, 23 Feb 2023 03:29:22 +0000 Subject: Add .mjs support for extensions --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0516c643..2509ce2d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1754,6 +1754,9 @@ def reload_javascript(): for script in modules.scripts.list_scripts("javascript", ".js"): head += f'\n' + for script in modules.scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + head += f'\n' def template_response(*args, **kwargs): -- cgit v1.2.3 From ac4c7f05cd38dfa99cf64f7ddb9b1656e70a13c5 Mon Sep 17 00:00:00 2001 From: Tpinion Date: Fri, 24 Feb 2023 00:42:29 +0800 Subject: Filter out temporary files that will be generated if the download fails. 
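The change below passes ext_filter=['.pth'] so that only finished checkpoint files are considered when the model directory is scanned. A minimal sketch of that filtering idea, using a hypothetical list_models helper rather than the real modelloader API:

import os

def list_models(dirname, ext_filter=(".pth",)):
    # keep only files whose extension is explicitly allowed, so half-downloaded
    # temporaries (e.g. *.tmp, *.part) left behind by a failed download are skipped
    return [
        os.path.join(dirname, name)
        for name in sorted(os.listdir(dirname))
        if os.path.splitext(name)[1].lower() in ext_filter
    ]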
--- modules/codeformer_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 01fb7bd8..8d84bbc9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -55,7 +55,7 @@ def setup_model(dirname): if self.net is not None and self.face_helper is not None: self.net.to(devices.device_codeformer) return self.net, self.face_helper - model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth') + model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) if len(model_paths) != 0: ckpt_path = model_paths[0] else: -- cgit v1.2.3 From 327186b484c344598522a989a4b4859d6b90fb04 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:03:46 +0900 Subject: Update script_callbacks.py --- modules/script_callbacks.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index edd0e2a7..c5bb3e71 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -29,7 +29,7 @@ class ImageSaveParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, tensor, uncond): self.x = x """Latent image representation in the process of being denoised""" @@ -44,6 +44,12 @@ class CFGDenoiserParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" + + self.tensor = tensor + """ Encoder hidden states of condtioning""" + + self.uncond = uncond + """ Encoder hidden states of unconditioning""" class CFGDenoisedParams: -- cgit v1.2.3 From 9a1435946ce7cc7f2cdaa0a312e1c0d296b8c276 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:04:23 +0900 Subject: Update sd_samplers_kdiffusion.py --- modules/sd_samplers_kdiffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 528f513f..ea974be0 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond sigma_in = denoiser_params.sigma + tensor = denoiser_params.tensor + uncond = denoiser_params.uncond if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: -- cgit v1.2.3 From 534cf60afb547de72891e1e87b59a1433aadeee3 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:26:55 +0900 Subject: Update script_callbacks.py --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 
c5bb3e71..d1703135 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,7 +46,7 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" self.tensor = tensor - """ Encoder hidden states of condtioning""" + """ Encoder hidden states of conditioning""" self.uncond = uncond """ Encoder hidden states of unconditioning""" -- cgit v1.2.3 From b15bc73c99e6fbbeffdbdbeab39ba30276021d4b Mon Sep 17 00:00:00 2001 From: Brad Smith Date: Fri, 24 Feb 2023 14:22:58 -0500 Subject: sort upscalers by name --- modules/modelloader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index fc3f6249..a7ac338c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url from modules import shared -from modules.upscaler import Upscaler +from modules.upscaler import Upscaler, UpscalerNone from modules.paths import script_path, models_path @@ -169,4 +169,8 @@ def load_upscalers(): scaler = cls(commandline_options.get(cmd_name, None)) datas += scaler.scalers - shared.sd_upscalers = datas + shared.sd_upscalers = sorted( + datas, + # Special case for UpscalerNone keeps it at the beginning of the list. + key=lambda x: x.name if not isinstance(x.scaler, UpscalerNone) else "" + ) -- cgit v1.2.3 From aa108bd02a8282e8213fa6c5967e3c47e49bb43f Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 24 Feb 2023 20:57:18 -0700 Subject: Add lossless webp option --- modules/images.py | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 5b80c23e..7df2b08c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i elif image_to_save.mode == 'I;16': image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L") - image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality) + image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless) if opts.enable_pnginfo and info is not None: exif_bytes = piexif.dump({ diff --git a/modules/shared.py b/modules/shared.py index 805f9cc1..51101988 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,6 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), -- cgit v1.2.3 From 6d92d95a33e46aea7a7b8c136a76a621d7fc4f52 Mon Sep 17 00:00:00 2001 From: Adam Huganir Date: Sat, 25 Feb 
2023 19:15:06 +0000 Subject: git 3.1.30 api change --- modules/extensions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 3eef9eaf..ed4b58fe 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -66,7 +66,7 @@ class Extension: def check_updates(self): repo = git.Repo(self.path) - for fetch in repo.remote().fetch("--dry-run"): + for fetch in repo.remote().fetch(dry_run=True): if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True self.status = "behind" @@ -79,8 +79,8 @@ class Extension: repo = git.Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. - repo.git.fetch('--all') - repo.git.reset('--hard', 'origin') + repo.git.fetch(all=True) + repo.git.reset('origin', hard=True) def list_extensions(): -- cgit v1.2.3 From 3c6459154fb115ea7cf1a0c5f3f0761a192dfea3 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 27 Feb 2023 17:28:04 -0500 Subject: add check for resulting image size --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 805f9cc1..ec08b7be 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -330,6 +330,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), + "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number), "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), -- cgit v1.2.3 From 3b6de96467b0ee6d59f3aaf4bafd1467633e14c6 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sun, 26 Feb 2023 19:17:58 -0500 Subject: Added alwayson_script_name and alwayson_script_args to api Added 2 additional possible entries in the api request: alwayson_script_name, a string list, and alwayson_script_args, a list of lists containing the args of each script. 
This allows us to send args to always on script and keep backwards compatibility with old script_name and script_arg api params --- modules/api/api.py | 111 ++++++++++++++++++++++++++++++++++++++++++++------ modules/api/models.py | 4 +- 2 files changed, 100 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1..a1cdebb8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -163,20 +163,26 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def get_script(self, script_name, script_runner): - if script_name is None: + def get_selectable_script(self, script_name, script_runner): + if script_name is None or script_name == "": return None, None - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx + def get_script(self, script_name, script_runner): + for script in script_runner.scripts: + if script_name.lower() == script.title().lower(): + return script + return None + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) + script_runner = scripts.scripts_txt2img + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -184,22 +190,59 @@ class Api: "do_not_save_grid": True } ) + if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) + args.pop('script_args', None) # will refeed them later with script_args + args.pop('alwayson_script_name', None) + args.pop('alwayson_script_args', None) + + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere exepct position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run + if api_selectable_scripts: + script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = txt2imgreq.script_args + script_args[0] = api_selectable_script_idx + 1 + else: + # if 0 then none + script_args[0] = 0 + + # Now check for always on scripts + if len(txt2imgreq.alwayson_script_name) > 0: + # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args + if len(txt2imgreq.alwayson_script_name) != len(txt2imgreq.alwayson_script_args): + raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") + + for alwayson_script_name, alwayson_script_args in zip(txt2imgreq.alwayson_script_name, txt2imgreq.alwayson_script_args): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + 
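# --- editorial sketch, not part of the patch: the check above pairs script
# names with arg lists one-to-one; a stripped-down equivalent would be:
def pair_alwayson_scripts(names, arg_lists):
    if len(names) != len(arg_lists):
        raise ValueError("Number of script names and number of script arg lists doesn't match")
    return dict(zip(names, arg_lists))
# --- end of editorial sketch ---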
# Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script_args != []: + script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.scripts = script_runner shared.state.begin() - if script is not None: + if api_selectable_scripts != None: + p.script_args = script_args p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) else: + p.script_args = tuple(script_args) processed = process_images(p) shared.state.end() @@ -212,12 +255,16 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) - mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, @@ -225,24 +272,62 @@ class Api: "mask": mask } ) + if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. 
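# --- editorial sketch, not part of the patch: how the flat script_args vector
# built here is laid out. Slot 0 holds the 1-based index of the selectable
# script (0 means none), and each script owns the half-open slice
# [args_from, args_to) of the same list. FakeScript is a hypothetical stand-in.
from dataclasses import dataclass

@dataclass
class FakeScript:
    args_from: int
    args_to: int

def layout_script_args(scripts, selectable=None, selectable_idx=0, request_args=()):
    size = max((s.args_to for s in scripts), default=1)
    script_args = [None] * size
    script_args[0] = 0  # 0 selects no runnable script
    if selectable is not None:
        script_args[selectable.args_from:selectable.args_to] = list(request_args)
        script_args[0] = selectable_idx + 1
    return script_args

scripts = [FakeScript(1, 3), FakeScript(3, 6)]
assert layout_script_args(scripts, scripts[1], 1, [7, 8, 9]) == [2, None, None, 7, 8, 9]
# --- end of editorial sketch ---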
args.pop('script_name', None) + args.pop('script_args', None) # will refeed them later with script_args + args.pop('alwayson_script_name', None) + args.pop('alwayson_script_args', None) + + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere exepct position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run + if api_selectable_scripts: + script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = img2imgreq.script_args + script_args[0] = api_selectable_script_idx + 1 + else: + # if 0 then none + script_args[0] = 0 + + # Now check for always on scripts + if len(img2imgreq.alwayson_script_name) > 0: + # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args + if len(img2imgreq.alwayson_script_name) != len(img2imgreq.alwayson_script_args): + raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") + + for alwayson_script_name, alwayson_script_args in zip(img2imgreq.alwayson_script_name, img2imgreq.alwayson_script_args): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script_args != []: + script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner shared.state.begin() - if script is not None: + if api_selectable_scripts != None: + p.script_args = script_args p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_img2img.run(p, *p.script_args) else: + p.script_args = tuple(script_args) processed = process_images(p) shared.state.end() diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b..86c70178 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,13 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": 
"init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): -- cgit v1.2.3 From a39c4cf766a9b0f18972fc52bcf5189173f434c6 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Mon, 27 Feb 2023 23:27:33 -0500 Subject: small refactor of api.py --- modules/api/api.py | 125 ++++++++++++++++++++++------------------------------- 1 file changed, 51 insertions(+), 74 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a1cdebb8..d4c0c152 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -172,56 +172,37 @@ class Api: return script, script_idx def get_script(self, script_name, script_runner): - for script in script_runner.scripts: - if script_name.lower() == script.title().lower(): - return script - return None - - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - script_runner = scripts.scripts_txt2img - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) - - populate = txt2imgreq.copy(update={ # Override __init__ params - "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True - } - ) - - if populate.sampler_name: - populate.sampler_index = None # prevent a warning later on - - args = vars(populate) - args.pop('script_name', None) - args.pop('script_args', None) # will refeed them later with script_args - args.pop('alwayson_script_name', None) - args.pop('alwayson_script_args', None) + if script_name is None or script_name == "": + return None, None + + script_idx = script_name_to_index(script_name, script_runner.scripts) + return script_runner.scripts[script_idx] + def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner): #find max idx from the scripts in runner and generate a none array to init script_args last_arg_index = 1 for script in script_runner.scripts: if last_arg_index < script.args_to: last_arg_index = script.args_to - # None everywhere exepct position 0 to initialize script args + # None everywhere except position 0 to initialize script args script_args = [None]*last_arg_index - # position 0 in script_arg is the idx+1 of the selectable script that is going to be run - if api_selectable_scripts: - script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = txt2imgreq.script_args - script_args[0] = api_selectable_script_idx + 1 + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() + 
if selectable_scripts: + script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args + script_args[0] = selectable_idx + 1 else: # if 0 then none script_args[0] = 0 # Now check for always on scripts - if len(txt2imgreq.alwayson_script_name) > 0: + if request.alwayson_script_name and (len(request.alwayson_script_name) > 0): # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if len(txt2imgreq.alwayson_script_name) != len(txt2imgreq.alwayson_script_args): + if not request.alwayson_script_args: + raise HTTPException(status_code=422, detail=f"Script {request.alwayson_script_name} has no arg list") + if len(request.alwayson_script_name) != len(request.alwayson_script_args): raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - for alwayson_script_name, alwayson_script_args in zip(txt2imgreq.alwayson_script_name, txt2imgreq.alwayson_script_args): + for alwayson_script_name, alwayson_script_args in zip(request.alwayson_script_name, request.alwayson_script_args): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script == None: raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") @@ -230,19 +211,45 @@ class Api: raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") if alwayson_script_args != []: script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + return script_args + + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + script_runner = scripts.scripts_txt2img + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) + + populate = txt2imgreq.copy(update={ # Override __init__ params + "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), + "do_not_save_samples": True, + "do_not_save_grid": True + } + ) + + if populate.sampler_name: + populate.sampler_index = None # prevent a warning later on + + args = vars(populate) + args.pop('script_name', None) + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them + args.pop('alwayson_script_name', None) + args.pop('alwayson_script_args', None) + + script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) p.scripts = script_runner shared.state.begin() - if api_selectable_scripts != None: + if selectable_scripts != None: p.script_args = script_args p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - processed = scripts.scripts_txt2img.run(p, *p.script_args) + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: - p.script_args = tuple(script_args) + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() @@ -263,7 +270,7 @@ class Api: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() - api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + selectable_scripts, selectable_script_idx = 
self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), @@ -279,41 +286,11 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) - args.pop('script_args', None) # will refeed them later with script_args + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_script_name', None) args.pop('alwayson_script_args', None) - #find max idx from the scripts in runner and generate a none array to init script_args - last_arg_index = 1 - for script in script_runner.scripts: - if last_arg_index < script.args_to: - last_arg_index = script.args_to - # None everywhere exepct position 0 to initialize script args - script_args = [None]*last_arg_index - # position 0 in script_arg is the idx+1 of the selectable script that is going to be run - if api_selectable_scripts: - script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = img2imgreq.script_args - script_args[0] = api_selectable_script_idx + 1 - else: - # if 0 then none - script_args[0] = 0 - - # Now check for always on scripts - if len(img2imgreq.alwayson_script_name) > 0: - # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if len(img2imgreq.alwayson_script_name) != len(img2imgreq.alwayson_script_args): - raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - - for alwayson_script_name, alwayson_script_args in zip(img2imgreq.alwayson_script_name, img2imgreq.alwayson_script_args): - alwayson_script = self.get_script(alwayson_script_name, script_runner) - if alwayson_script == None: - raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") - # Selectable script in always on script param check - if alwayson_script.alwayson == False: - raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") - if alwayson_script_args != []: - script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args - + script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) @@ -321,13 +298,13 @@ class Api: p.scripts = script_runner shared.state.begin() - if api_selectable_scripts != None: + if selectable_scripts != None: p.script_args = script_args p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - processed = scripts.scripts_img2img.run(p, *p.script_args) + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: - p.script_args = tuple(script_args) + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() -- cgit v1.2.3 From c6c2a59333c77dffff49a748bfed8c54af6e2abd Mon Sep 17 00:00:00 2001 From: Vespinian Date: Mon, 27 Feb 2023 23:45:59 -0500 Subject: comment clarification --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 
d4c0c152..248922d2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -191,7 +191,7 @@ class Api: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args script_args[0] = selectable_idx + 1 else: - # if 0 then none + # when [0] = 0 no selectable script to run script_args[0] = 0 # Now check for always on scripts -- cgit v1.2.3 From 23d4fb5bf2400622d00ca5fe489fadb160ee7c47 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 3 Mar 2023 08:29:10 -0500 Subject: allow saving of images via api --- modules/api/api.py | 8 ++++---- modules/api/models.py | 4 ++-- modules/images.py | 3 +++ 3 files changed, 9 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1..6b939daa 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -180,8 +180,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True + "do_not_save_samples": True if not 'do_not_save_samples' in vars(txt2imgreq) else txt2imgreq.do_not_save_samples, + "do_not_save_grid": True if not 'do_not_save_grid' in vars(txt2imgreq) else txt2imgreq.do_not_save_grid, } ) if populate.sampler_name: @@ -220,8 +220,8 @@ class Api: populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True, + "do_not_save_samples": True if not 'do_not_save_samples' in img2imgreq else img2imgreq.do_not_save_samples, + "do_not_save_grid": True if not 'do_not_save_grid' in img2imgreq else img2imgreq.do_not_save_grid, "mask": mask } ) diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b..a947e6ac 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -14,8 +14,8 @@ API_NOT_ALLOWED = [ "outpath_samples", "outpath_grids", "sampler_index", - "do_not_save_samples", - "do_not_save_grid", + # "do_not_save_samples", + # "do_not_save_grid", "extra_generation_params", "overlay_images", "do_not_reload_embeddings", diff --git a/modules/images.py b/modules/images.py index 5b80c23e..f8e62b71 100644 --- a/modules/images.py +++ b/modules/images.py @@ -489,6 +489,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt, image) + if path is None: # set default path to avoid errors when functions are triggered manually or via api and param is not set + path = opts.outdir_save + if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) -- cgit v1.2.3 From f8e219bad9f33cde94cd31fff3edd70946612541 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 3 Mar 2023 09:00:52 -0500 Subject: allow api requests to specify do not send images in response --- modules/api/api.py | 10 ++++++++-- modules/api/models.py | 18 ++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6b939daa..7da9081b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -190,6 +190,9 @@ class Api: args = vars(populate) args.pop('script_name', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] + args.pop('do_not_send_images', None) + with self.queue_lock: p = 
StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) @@ -203,7 +206,7 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -232,6 +235,9 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] + args.pop('do_not_send_images', None) + with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -246,7 +252,7 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: img2imgreq.init_images = None diff --git a/modules/api/models.py b/modules/api/models.py index a947e6ac..aa4ea5d5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,27 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [ + {"key": "sampler_index", "type": str, "default": "Euler"}, + {"key": "script_name", "type": str, "default": None}, + {"key": "script_args", "type": list, "default": []}, + {"key": "do_not_send_images", "type": bool, "default": False} + ] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [ + {"key": "sampler_index", "type": str, "default": "Euler"}, + {"key": "init_images", "type": list, "default": None}, + {"key": "denoising_strength", "type": float, "default": 0.75}, + {"key": "mask", "type": str, "default": None}, + {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, + {"key": "script_name", "type": str, "default": None}, + {"key": "script_args", "type": list, "default": []}, + {"key": "do_not_send_images", "type": bool, "default": False} + ] ).generate_model() class TextToImageResponse(BaseModel): -- cgit v1.2.3 From c48bbccf12f13cf309f532d70494ff04c27bcc2a Mon Sep 17 00:00:00 2001 From: Yea chen Date: Sat, 4 Mar 2023 11:46:07 +0800 Subject: add: /sdapi/v1/scripts in API API for get scripts list --- modules/api/api.py | 13 +++++++++++++ modules/api/models.py | 4 ++++ 2 files changed, 17 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1..46cb7c81 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -150,6 +150,7 @@ class Api: 
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) + self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -174,6 +175,18 @@ class Api: script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx + + def get_scripts_list(self): + t2ilist = [] + i2ilist = [] + + for a in scripts.scripts_txt2img.titles: + t2ilist.append(str(a.lower())) + + for b in scripts.scripts_img2img.titles: + i2ilist.append(str(b.lower())) + + return ScriptsList(txt2img = t2ilist, img2img = i2ilist) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b..db739f2b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -267,3 +267,7 @@ class EmbeddingsResponse(BaseModel): class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") + +class ScriptsList(BaseModel): + txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") + img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file -- cgit v1.2.3 From b012d70f15641d6b85c9257b83cec892e941609c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 4 Mar 2023 17:51:37 -0500 Subject: update using original defaults --- modules/api/api.py | 17 +++++++++++------ modules/api/models.py | 6 ++++-- 2 files changed, 15 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7da9081b..a6bb439c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -180,8 +180,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True if not 'do_not_save_samples' in vars(txt2imgreq) else txt2imgreq.do_not_save_samples, - "do_not_save_grid": True if not 'do_not_save_grid' in vars(txt2imgreq) else txt2imgreq.do_not_save_grid, + "do_not_save_samples": txt2imgreq.do_not_save, + "do_not_save_grid": txt2imgreq.do_not_save, } ) if populate.sampler_name: @@ -190,8 +190,9 @@ class Api: args = vars(populate) args.pop('script_name', None) - send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] - args.pop('do_not_send_images', None) + send_images = True if not 'do_not_send' in args else not args['do_not_send'] + args.pop('do_not_send', None) + args.pop('do_not_save', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) @@ -223,8 +224,8 @@ class Api: populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": True if not 'do_not_save_samples' in img2imgreq else img2imgreq.do_not_save_samples, 
- "do_not_save_grid": True if not 'do_not_save_grid' in img2imgreq else img2imgreq.do_not_save_grid, + "do_not_save_samples": img2imgreq.do_not_save, + "do_not_save_grid": img2imgreq.do_not_save, "mask": mask } ) @@ -235,6 +236,10 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + send_images = True if not 'do_not_send' in args else not args['do_not_send'] + args.pop('do_not_send', None) + args.pop('do_not_save', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] args.pop('do_not_send_images', None) diff --git a/modules/api/models.py b/modules/api/models.py index aa4ea5d5..2b66e1f0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -104,7 +104,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send_images", "type": bool, "default": False} + {"key": "do_not_send", "type": bool, "default": False}, + {"key": "do_not_save", "type": bool, "default": True} ] ).generate_model() @@ -119,7 +120,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send_images", "type": bool, "default": False} + {"key": "do_not_send", "type": bool, "default": False}, + {"key": "do_not_save", "type": bool, "default": True} ] ).generate_model() -- cgit v1.2.3 From d118cb6ea3f1a410b5e030519dc021eafc1d6b52 Mon Sep 17 00:00:00 2001 From: Brad Smith Date: Mon, 6 Mar 2023 13:18:35 -0500 Subject: use lowercase name for sorting; keep `UpscalerLanczos` and `UpscalerNearest` at the start of the list with `UpscalerNone` Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com> --- modules/modelloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index a7ac338c..e351d808 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url from modules import shared -from modules.upscaler import Upscaler, UpscalerNone +from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -172,5 +172,5 @@ def load_upscalers(): shared.sd_upscalers = sorted( datas, # Special case for UpscalerNone keeps it at the beginning of the list. 
- key=lambda x: x.name if not isinstance(x.scaler, UpscalerNone) else "" + key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) -- cgit v1.2.3 From 49b1dc5e07825e76c85ac4ac078fd63aa835e8bd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 6 Mar 2023 21:00:34 +0200 Subject: Deduplicate extra network preview-search code --- modules/ui_extra_networks.py | 10 ++++++++++ modules/ui_extra_networks_checkpoints.py | 11 +---------- modules/ui_extra_networks_hypernets.py | 9 +-------- modules/ui_extra_networks_textual_inversion.py | 8 +------- 4 files changed, 13 insertions(+), 25 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 71f1d81f..1a10a5df 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,6 +2,7 @@ import glob import os.path import urllib.parse from pathlib import Path +from typing import Optional from modules import shared import gradio as gr @@ -137,6 +138,15 @@ class ExtraNetworksPage: return self.card_page.format(**args) + def _find_preview(self, path: str) -> Optional[str]: + """ + Find a preview PNG for a given path (without extension) and call link_preview on it. + """ + for file in [path + ".png", path + ".preview.png"]: + if os.path.isfile(file): + return self.link_preview(file) + return None + def intialize(): extra_pages.clear() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 04097a79..b712d12b 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,7 +1,6 @@ import html import json import os -import urllib.parse from modules import shared, ui_extra_networks, sd_models @@ -17,18 +16,10 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): checkpoint: sd_models.CheckpointInfo for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) - previews = [path + ".png", path + ".preview.png"] - - preview = None - for file in previews: - if os.path.isfile(file): - preview = self.link_preview(file) - break - yield { "name": checkpoint.name_for_extra, "filename": path, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 57851088..89f33242 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -14,18 +14,11 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def list_items(self): for name, path in shared.hypernetworks.items(): path, ext = os.path.splitext(path) - previews = [path + ".png", path + ".preview.png"] - - preview = None - for file in previews: - if os.path.isfile(file): - preview = self.link_preview(file) - break yield { "name": name, "filename": path, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f"<hypernet:{name}:{shared.opts.extra_networks_default_multiplier}>"), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index bb64eb81..f7057390 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ 
b/modules/ui_extra_networks_textual_inversion.py @@ -15,16 +15,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def list_items(self): for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): path, ext = os.path.splitext(embedding.filename) - preview_file = path + ".preview.png" - - preview = None - if os.path.isfile(preview_file): - preview = self.link_preview(preview_file) - yield { "name": embedding.name, "filename": embedding.filename, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", -- cgit v1.2.3 From 06f167da37cd00ea8241bd2a6a3c12d8c5fb9eaf Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 6 Mar 2023 21:14:06 +0200 Subject: Extra networks: support .txt description sidecar file --- modules/ui_extra_networks.py | 15 +++++++++++++++ modules/ui_extra_networks_checkpoints.py | 1 + modules/ui_extra_networks_hypernets.py | 1 + modules/ui_extra_networks_textual_inversion.py | 1 + 4 files changed, 18 insertions(+) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 1a10a5df..cd61a569 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,6 +1,7 @@ import glob import os.path import urllib.parse +from functools import lru_cache from pathlib import Path from typing import Optional @@ -131,6 +132,7 @@ class ExtraNetworksPage: "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], + "description": (item.get("description") or ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), @@ -147,6 +149,19 @@ class ExtraNetworksPage: return self.link_preview(file) return None + @lru_cache(maxsize=512) + def _find_description(self, path: str) -> Optional[str]: + """ + Find and read a description file for a given path (without extension). 
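The body of this method follows in the next hunk lines; here is the same sidecar lookup as a self-contained sketch, using only the standard library (the example path is illustrative):

def find_description(path: str):
    # Try "<path>.txt" then "<path>.description.txt"; a missing file raises OSError.
    for file in [f"{path}.txt", f"{path}.description.txt"]:
        try:
            with open(file, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        except OSError:
            pass
    return None

print(find_description("models/Lora/example"))  # None when no sidecar file exists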
+ """ + for file in [f"{path}.txt", f"{path}.description.txt"]: + try: + with open(file, "r", encoding="utf-8", errors="replace") as f: + return f.read() + except OSError: + pass + return None + def intialize(): extra_pages.clear() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index b712d12b..1deb785a 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -20,6 +20,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "name": checkpoint.name_for_extra, "filename": path, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 89f33242..80cc2a24 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -19,6 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index f7057390..f3bae666 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -19,6 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", -- cgit v1.2.3 From fec0a895119a124a295e3dad5205de5766031dc7 Mon Sep 17 00:00:00 2001 From: Pam Date: Tue, 7 Mar 2023 00:33:13 +0500 Subject: scaled dot product attention --- modules/sd_hijack.py | 4 ++++ modules/sd_hijack_optimizations.py | 42 ++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 47 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 79476783..76cb9120 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -42,6 +42,10 @@ def apply_optimizations(): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' + elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c02d954c..a324a592 100644 
--- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -346,6 +346,48 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) +# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +def scaled_dot_product_attention_forward(self, x, context=None, mask=None): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/shared.py b/modules/shared.py index 805f9cc1..12d0756b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,6 +69,7 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) -- cgit v1.2.3 From f85a192f998dcc17d0a59d8755c76100e8e31bde Mon Sep 17 00:00:00 2001 From: Yea Chen Date: Tue, 7 Mar 2023 04:04:35 +0800 Subject: Update modules/api/api.py Suggested change by @akx Co-authored-by: Aarni Koskela --- modules/api/api.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 46cb7c81..12b38386 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -177,14 +177,8 @@ class Api: return script, script_idx def get_scripts_list(self): - t2ilist = [] - i2ilist = [] - - for a in scripts.scripts_txt2img.titles: - t2ilist.append(str(a.lower())) - - for b in scripts.scripts_img2img.titles: - i2ilist.append(str(b.lower())) + t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] + i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] return ScriptsList(txt2img = t2ilist, img2img = i2ilist) -- cgit v1.2.3 From 37acba263389e22bc46cfffc80b2ca8b76a85287 Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 12:19:36 +0500 Subject: argument to disable memory efficient for sdp --- modules/sd_hijack.py | 11 ++++++++--- modules/sd_hijack_optimizations.py | 4 ++++ modules/shared.py | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 76cb9120..f62e9adb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -43,9 +43,14 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - optimization_method = 'sdp' + if cmd_opts.opt_sdp_no_mem_attention: + print("Applying scaled dot product cross attention optimization (without memory efficient attention).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + optimization_method = 'sdp-no-mem' + else: + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = 
sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a324a592..68b1dd84 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -388,6 +388,10 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/shared.py b/modules/shared.py index 12d0756b..4b81c591 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -70,6 +70,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="disables memory efficient sdp, makes image generation deterministic; requires --opt-sdp-attention") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) -- cgit v1.2.3 From 0981dea94832f34d638b1aa8964cfaeffd223b47 Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 12:58:10 +0500 Subject: sdp refactoring --- modules/sd_hijack.py | 19 ++++++++++--------- modules/shared.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f62e9adb..e98ae51a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,20 +37,21 @@ def apply_optimizations(): optimization_method = None + can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' - elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): - if cmd_opts.opt_sdp_no_mem_attention: - print("Applying scaled dot 
product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - optimization_method = 'sdp-no-mem' - else: - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - optimization_method = 'sdp' + elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization (without memory efficient attention).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + optimization_method = 'sdp-no-mem' + elif cmd_opts.opt_sdp_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/shared.py b/modules/shared.py index 4b81c591..66a6bfa5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -70,7 +70,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") -parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="disables memory efficient sdp, makes image generation deterministic; requires --opt-sdp-attention") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) -- cgit v1.2.3 From 8d7fa2f67cb0554d8902d5d407166876020e067e Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 22:48:41 +0500 Subject: sdp_attnblock_forward hijack --- modules/sd_hijack.py | 2 ++ modules/sd_hijack_optimizations.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e98ae51a..f4bb0266 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -47,10 +47,12 @@ def apply_optimizations(): elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: print("Applying scaled dot product cross attention optimization (without memory efficient attention).") ldm.modules.attention.CrossAttention.forward = 
sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward optimization_method = 'sdp-no-mem' elif cmd_opts.opt_sdp_attention and can_use_sdp: print("Applying scaled dot product cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 68b1dd84..2e307b5d 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -473,6 +473,30 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) +def sdp_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + +def sdp_no_mem_attnblock_forward(self, x): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) -- cgit v1.2.3 From 5fef67f6ee949a61826a3a043ea8610fd89fc371 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Mar 2023 19:56:14 -0500 Subject: Requested changes --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/sd_samplers_compvis.py | 4 +++- modules/shared.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 708a9b2b..6bb3bb21 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -93,7 +93,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index ad39ab2b..7d07c4a5 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -140,10 
+140,12 @@ class VanillaStableDiffusionSampler: def adjust_steps_if_invalid(self, p, num_steps): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): + if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: + num_steps = shared.opts.uni_pc_order valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 - + return num_steps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): diff --git a/modules/shared.py b/modules/shared.py index 7c559fa4..29f8dccb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -485,10 +485,9 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), - 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), - 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) -- cgit v1.2.3 From f261a4a53c153c630a506bc5282e9955c36b3ef2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 11:56:05 +0300 Subject: use selected device instead of always cuda for UniPC sampler --- modules/models/diffusion/uni_pc/sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 6bb3bb21..bf346ff4 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -3,7 +3,8 @@ import torch from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC -from modules import shared +from modules import shared, devices + class UniPCSampler(object): def __init__(self, model, **kwargs): @@ -16,8 +17,8 @@ class UniPCSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != devices.device: + attr = attr.to(devices.device) setattr(self, name, attr) def set_hooks(self, before_sample, after_sample, after_update): -- cgit v1.2.3 From 58b5b7c2f1d3b843803c1fc7a0aae8b1d6be5763 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 12:09:36 +0300 Subject: add UniPC options to infotext --- modules/generation_parameters_copypaste.py | 8 +++++++- modules/sd_samplers_compvis.py | 14 ++++++++++++++ modules/shared.py | 9 +++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 89dc23bf..cb367655 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -288,6 
+288,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model settings_map = {} + + infotext_to_setting_name_mapping = [ ('Clip skip', 'CLIP_stop_at_last_layers', ), ('Conditional mask weight', 'inpainting_mask_weight'), @@ -296,7 +298,11 @@ infotext_to_setting_name_mapping = [ ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma') + ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), + ('UniPC variant', 'uni_pc_variant'), + ('UniPC skip type', 'uni_pc_skip_type'), + ('UniPC order', 'uni_pc_order'), + ('UniPC lower order final', 'uni_pc_lower_order_final'), ] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 7d07c4a5..083da18c 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -129,6 +129,19 @@ class VanillaStableDiffusionSampler: if self.eta != 0.0: p.extra_generation_params["Eta DDIM"] = self.eta + if self.is_unipc: + keys = [ + ('UniPC variant', 'uni_pc_variant'), + ('UniPC skip type', 'uni_pc_skip_type'), + ('UniPC order', 'uni_pc_order'), + ('UniPC lower order final', 'uni_pc_lower_order_final'), + ] + + for name, key in keys: + v = getattr(shared.opts, key) + if v != shared.opts.get_default(key): + p.extra_generation_params[name] = v + for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) @@ -138,6 +151,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None + def adjust_steps_if_invalid(self, p, num_steps): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: diff --git a/modules/shared.py b/modules/shared.py index 29f8dccb..d481c25b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -563,6 +563,15 @@ class Options: return True + def get_default(self, key): + """returns the default value for the key""" + + data_label = self.data_labels.get(key) + if data_label is None: + return None + + return data_label.default + def save(self, filename): assert not cmd_opts.freeze_settings, "saving settings is disabled" -- cgit v1.2.3 From 3531a50080e63197752dd4d9b49f0ac34a758e12 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 13:22:59 +0300 Subject: rename fields for API for saving/sending images save images to correct directories --- modules/api/api.py | 41 +++++++++++++++++------------------------ modules/api/models.py | 8 ++++---- modules/images.py | 3 --- 3 files changed, 21 insertions(+), 31 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a6bb439c..fbd50552 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -178,29 +178,27 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) - populate = txt2imgreq.copy(update={ # Override __init__ params + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": txt2imgreq.do_not_save, - "do_not_save_grid": txt2imgreq.do_not_save, - } - ) + "do_not_save_samples": 
not txt2imgreq.save_images, + "do_not_save_grid": not txt2imgreq.save_images, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) - send_images = True if not 'do_not_send' in args else not args['do_not_send'] - args.pop('do_not_send', None) - args.pop('do_not_save', None) + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples shared.state.begin() if script is not None: - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) else: @@ -222,13 +220,12 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": img2imgreq.do_not_save, - "do_not_save_grid": img2imgreq.do_not_save, - "mask": mask - } - ) + "do_not_save_samples": not img2imgreq.save_images, + "do_not_save_grid": not img2imgreq.save_images, + "mask": mask, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on @@ -236,21 +233,17 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) - send_images = True if not 'do_not_send' in args else not args['do_not_send'] - args.pop('do_not_send', None) - args.pop('do_not_save', None) - - send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] - args.pop('do_not_send_images', None) + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples shared.state.begin() if script is not None: - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_img2img.run(p, *p.script_args) else: diff --git a/modules/api/models.py b/modules/api/models.py index 2b66e1f0..ff3fb344 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -104,8 +104,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send", "type": bool, "default": False}, - {"key": "do_not_save", "type": bool, "default": True} + {"key": "send_images", "type": bool, "default": True}, + {"key": "save_images", "type": bool, "default": False}, ] ).generate_model() @@ -120,8 +120,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, 
"default": []}, - {"key": "do_not_send", "type": bool, "default": False}, - {"key": "do_not_save", "type": bool, "default": True} + {"key": "send_images", "type": bool, "default": True}, + {"key": "save_images", "type": bool, "default": False}, ] ).generate_model() diff --git a/modules/images.py b/modules/images.py index f8e62b71..5b80c23e 100644 --- a/modules/images.py +++ b/modules/images.py @@ -489,9 +489,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt, image) - if path is None: # set default path to avoid errors when functions are triggered manually or via api and param is not set - path = opts.outdir_save - if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) -- cgit v1.2.3 From aaa367e35ce4e823219c2954ca141ca1ed14800e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 14:18:18 +0300 Subject: new setting: Extra text to add before <...> when adding extra network to prompt --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index dbab0018..28d952dd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -442,6 +442,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), })) -- cgit v1.2.3 From 7f2005127ff20170e2d92f353416c7f0705c593b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 14:52:29 +0300 Subject: rename CFGDenoiserParams fields for #8064 --- modules/script_callbacks.py | 10 +++++----- modules/sd_samplers_kdiffusion.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d1703135..07911876 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -29,7 +29,7 @@ class ImageSaveParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, tensor, uncond): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x """Latent image representation in the process of being denoised""" @@ -45,11 +45,11 @@ class CFGDenoiserParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - self.tensor = tensor - """ Encoder hidden states of conditioning""" + self.text_cond = text_cond + """ Encoder hidden states of text conditioning from prompt""" - self.uncond = uncond - """ Encoder hidden states of unconditioning""" + self.text_uncond = text_uncond + """ Encoder hidden states of text conditioning from negative prompt""" class CFGDenoisedParams: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 
ea974be0..93f0e55a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -106,8 +106,8 @@ class CFGDenoiser(torch.nn.Module): x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond sigma_in = denoiser_params.sigma - tensor = denoiser_params.tensor - uncond = denoiser_params.uncond + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: -- cgit v1.2.3 From 8106117a474a5e62dbd73687dadc6e7b7637267d Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 09:18:08 +0800 Subject: models/ui.py: make the path of script.js absolute --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0516c643..56b55f29 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1745,7 +1745,8 @@ def create_ui(): def reload_javascript(): - head = f'<script type="text/javascript" src="file=script.js?{os.path.getmtime("script.js")}"></script>\n' + script_js = os.path.join(script_path, "script.js") + head = f'<script type="text/javascript" src="file={script_js}?{os.path.getmtime(script_js)}"></script>\n' inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: -- cgit v1.2.3 From 8e0d16e746759b1f9b4bf1b5abfc30f3d985415e Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 12:22:59 +0800 Subject: modules/sd_vae_approx.py: fix VAE-approx path --- modules/sd_vae_approx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 0027343a..e2f00468 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -35,8 +35,11 @@ def model(): global sd_vae_approx_model if sd_vae_approx_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") sd_vae_approx_model = VAEApprox() - sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) + if not os.path.exists(model_path): + model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") + sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) sd_vae_approx_model.eval() sd_vae_approx_model.to(devices.device, devices.dtype) -- cgit v1.2.3 From ce68ab8d0dfdd25e820c58dbc9d3b0148a2022a4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:27:42 +0300 Subject: remove underscores from function names in #8366 remove LRU from #8366 because I don't know why it's there --- modules/ui_extra_networks.py | 8 +++----- modules/ui_extra_networks_checkpoints.py | 4 ++-- modules/ui_extra_networks_hypernets.py | 4 ++-- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 4 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index cd61a569..21c52287 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,9 +1,7 @@ import glob import os.path import urllib.parse -from functools import lru_cache from pathlib import Path -from typing import Optional from modules import shared import gradio as gr @@ -140,17 +138,17 @@ class ExtraNetworksPage: return self.card_page.format(**args) - def _find_preview(self, path: str) -> Optional[str]: + def find_preview(self, path): """ Find a preview PNG for a given path (without extension) and call link_preview on it. 
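Backing up to the VAE-approx fix a few hunks above, before the docstring closes below: the prefer-then-fallback lookup generalizes to a small helper. A sketch with illustrative arguments (the function name is ours; the directory names come from that diff):

import os

def resolve_vae_approx_path(models_path: str, script_path: str) -> str:
    # Prefer the user's models directory; fall back to the copy bundled with the source tree.
    model_path = os.path.join(models_path, "VAE-approx", "model.pt")
    if not os.path.exists(model_path):
        model_path = os.path.join(script_path, "models", "VAE-approx", "model.pt")
    return model_path

print(resolve_vae_approx_path("models", "."))  # falls back when models/VAE-approx/model.pt is absent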
""" for file in [path + ".png", path + ".preview.png"]: if os.path.isfile(file): return self.link_preview(file) + return None - @lru_cache(maxsize=512) - def _find_description(self, path: str) -> Optional[str]: + def find_description(self, path): """ Find and read a description file for a given path (without extension). """ diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 1deb785a..7d1aa203 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -19,8 +19,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): yield { "name": checkpoint.name_for_extra, "filename": path, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 80cc2a24..8e49e93b 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -18,8 +18,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): yield { "name": name, "filename": path, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index f3bae666..1ad806eb 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -18,8 +18,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): yield { "name": embedding.name, "filename": embedding.filename, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", -- cgit v1.2.3 From 9320139bd8340e8b12c178fe80411c8e25f78d9e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:33:24 +0300 Subject: support three extensions for preview instead of one: png, jpg, webp --- modules/ui_extra_networks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 21c52287..3b476f3a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -142,7 +142,11 @@ class ExtraNetworksPage: """ Find a preview PNG for a given path (without extension) and call link_preview on it. """ - for file in [path + ".png", path + ".preview.png"]: + + preview_extensions = ["png", "jpg", "webp"] + potential_files = sum([[path + "." + ext, path + ".preview." 
+ ext] for ext in preview_extensions], []) + + for file in potential_files: if os.path.isfile(file): return self.link_preview(file) -- cgit v1.2.3 From 6da2027213c3bf132c54489d34c48ec084f8dc11 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:46:20 +0300 Subject: save previews for extra networks in the selected format --- modules/ui_extra_networks.py | 3 +++ modules/ui_extra_networks_checkpoints.py | 2 +- modules/ui_extra_networks_hypernets.py | 2 +- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 4 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 3b476f3a..85f0af4c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -144,6 +144,9 @@ class ExtraNetworksPage: """ preview_extensions = ["png", "jpg", "webp"] + if shared.opts.samples_format not in preview_extensions: + preview_extensions.append(shared.opts.samples_format) + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) for file in potential_files: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 7d1aa203..a17aa9c9 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -23,7 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', - "local_preview": path + ".png", + "local_preview": f"{path}.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8e49e93b..6187e000 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -22,7 +22,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f"<hypernet:{name}:{shared.opts.extra_networks_default_multiplier}>"), - "local_preview": path + ".png", + "local_preview": f"{path}.preview.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 1ad806eb..6944d559 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ import json import os -from modules import ui_extra_networks, sd_hijack +from modules import ui_extra_networks, sd_hijack, shared class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): @@ -22,7 +22,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), - "local_preview": path + ".preview.png", + "local_preview": f"{path}.preview.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): -- cgit v1.2.3 From 7fd19fa4e7039746dc98990883acfa500b90b6c7 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sat, 11 Mar 2023 07:22:22 -0800 Subject: initial fix for filename length limits on *nix systems --- modules/images.py | 3 +++ 1 file changed, 3 insertions(+)
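The images.py hunk below trims the decorated filename to the filesystem's reported limit. Extracted as a sketch: os.statvfs exists only on Unix-like systems (hence the hasattr guard), and the constant 5 leaves room for an extension such as ".png":

import os

def clamp_file_decoration(file_decoration: str, path: str) -> str:
    if hasattr(os, 'statvfs'):  # absent on Windows
        max_name_len = os.statvfs(path).f_namemax  # maximum filename length on this mount
        file_decoration = file_decoration[:max_name_len - 5]
    return file_decoration

print(len(clamp_file_decoration("x" * 300, ".")))  # typically 250 on ext4 (f_namemax == 255)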
(limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 7df2b08c..4c204fca 100644 --- a/modules/images.py +++ b/modules/images.py @@ -512,6 +512,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file_decoration = "-" + file_decoration file_decoration = namegen.apply(file_decoration) + suffix + if hasattr(os, 'statvfs'): + max_name_len = os.statvfs(path).f_namemax + file_decoration = file_decoration[:max_name_len - 5] if add_number: basecount = get_next_sequence_number(path, basename) -- cgit v1.2.3 From 2174f58daee1e077eec1125e196d34cc93dbaf23 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sat, 11 Mar 2023 12:21:33 -0500 Subject: Changed alwayson_script_name and alwayson_script_args api params to 1 alwayson_scripts param dict --- modules/api/api.py | 23 +++++++---------------- modules/api/models.py | 4 ++-- 2 files changed, 9 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 248922d2..8a17017b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -195,22 +195,17 @@ class Api: script_args[0] = 0 # Now check for always on scripts - if request.alwayson_script_name and (len(request.alwayson_script_name) > 0): - # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if not request.alwayson_script_args: - raise HTTPException(status_code=422, detail=f"Script {request.alwayson_script_name} has no arg list") - if len(request.alwayson_script_name) != len(request.alwayson_script_args): - raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - - for alwayson_script_name, alwayson_script_args in zip(request.alwayson_script_name, request.alwayson_script_args): + if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script == None: raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") # Selectable script in always on script param check if alwayson_script.alwayson == False: raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") - if alwayson_script_args != []: - script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + # always on script with no arg should always run so you don't really need to add them to the requests + if "args" in request.alwayson_scripts[alwayson_script_name]: + script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"] return script_args def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): @@ -226,15 +221,13 @@ class Api: "do_not_save_grid": True } ) - if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them - args.pop('alwayson_script_name', None) - args.pop('alwayson_script_args', None) + args.pop('alwayson_scripts', None) script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) @@ -279,7 +272,6 @@ class Api: "mask": mask } ) - if populate.sampler_name: populate.sampler_index = None # prevent a warning later on @@ -287,8 +279,7 
@@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them - args.pop('alwayson_script_name', None) - args.pop('alwayson_script_args', None) + args.pop('alwayson_scripts', None) script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) diff --git a/modules/api/models.py b/modules/api/models.py index 86c70178..e273469d 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,13 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}] ).generate_model() class TextToImageResponse(BaseModel): -- cgit v1.2.3 From 5546e71a105033989adacc2df9dfb53f81a0534c Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sat, 11 Mar 2023 12:35:20 -0500 Subject: Fixed whitespace --- modules/api/models.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 6c1c5546..4a70f440 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -113,7 +113,6 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [ {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, -- cgit v1.2.3 From 27e319dc4f09a2f040043948e5c52965976f8491 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 21:22:52 +0300 Subject: alternative solution for #8089 --- modules/shared.py | 5 ++++- 1 file 
From 247a34498b337798a371d69483bbcab49b5c320c Mon Sep 17 00:00:00 2001
From: Kilvoctu
Date: Sat, 11 Mar 2023 13:11:26 -0600
Subject: restore text, remove 'close'

don't use emojis for extra network buttons; remove 'close'
---
 modules/ui_extra_networks.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

(limited to 'modules')

diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8786fde6..3f56bf63 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -13,10 +13,6 @@ from modules.generation_parameters_copypaste import image_from_url_text
 extra_pages = []
 allowed_dirs = set()
 
-# Using constants for these since the variation selector isn't visible.
-# Important that they exactly match script.js for tooltip to work.
-refresh_symbol = '\U0001f504'  # 🔄
-close_symbol = '\U0000274C'  # ❌
 
 def register_page(page):
     """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
@@ -186,8 +182,7 @@ def create_ui(container, button, tabname):
             ui.pages.append(page_elem)
 
     filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
-    button_refresh = gr.Button(refresh_symbol, elem_id=tabname+"_extra_refresh")
-    button_close = gr.Button(close_symbol, elem_id=tabname+"_extra_close")
+    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
 
     ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
     ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
@@ -198,7 +193,6 @@ def create_ui(container, button, tabname):
     state_visible = gr.State(value=False)
     button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
-    button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
 
     def refresh():
         res = []
-- 
cgit v1.2.3


From 49bbdbe4476a7325e184f3022166c3dc02d086ba Mon Sep 17 00:00:00 2001
From: Vespinian
Date: Sat, 11 Mar 2023 14:34:56 -0500
Subject: small diff whitespace cleanup

---
 modules/api/api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index abdbb6a7..35e17afc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -274,7 +274,7 @@ class Api:
         ui.create_ui()
 
         selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
-        populate = img2imgreq.copy(update={ # Override __init__ params
+        populate = img2imgreq.copy(update={  # Override __init__ params
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
             "do_not_save_samples": not img2imgreq.save_images,
             "do_not_save_grid": not img2imgreq.save_images,
-- 
cgit v1.2.3
From 48f4abd2e61e545104f72eb50c9ab9b100726948 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Sat, 11 Mar 2023 15:52:14 -0500
Subject: fix dims typo in unipc

---
 modules/models/diffusion/uni_pc/uni_pc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index df63d1bc..e9a093a2 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -719,7 +719,7 @@ class UniPC:
                 x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
         else:
             x_t_ = (
-                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
+                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                 - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
             )
             if x_t is None:
-- 
cgit v1.2.3


From a4cb96d4ae82741be9f0d072a37af3ae39521379 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sat, 11 Mar 2023 17:35:17 -0500
Subject: Remove test, use bool tensor fix by default

The test isn't working correctly on macOS 13.3 and the bool tensor fix for
cumsum is currently always needed anyway, so enable the fix by default.
---
 modules/mac_specific.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         output_dtype = kwargs.get('dtype', input.dtype)
         if output_dtype == torch.int64:
             return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
             return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
     return cumsum_func(input, *args, **kwargs)
 
@@ -45,7 +45,6 @@ if has_mps:
         CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
     elif version.parse(torch.__version__) > version.parse("1.13.1"):
         cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
-- 
cgit v1.2.3
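Editor's note: a minimal sketch of the workaround that the mac_specific.py change above now applies unconditionally for bool tensors. It runs on any backend, though the bug it works around is specific to MPS on some torch versions.

import torch

x = torch.tensor([True, True, False, True])
# cumsum on bool (and int8/int16) tensors has produced wrong results on the
# MPS backend, so the fix routes the operation through int32 and casts the
# result to int64, matching CPU semantics:
y = x.to(torch.int32).cumsum(0).to(torch.int64)
print(y)  # tensor([1, 2, 2, 3])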
From bd67c41f5415c4c42d2383f128fe9ee656153bd0 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sun, 12 Mar 2023 09:19:23 -0400
Subject: force refresh tqdm before close

---
 modules/shared.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules')

diff --git a/modules/shared.py b/modules/shared.py
index 2fb9e3b5..f28a12cc 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -714,6 +714,7 @@ class TotalTQDM:
 
     def clear(self):
         if self._tqdm is not None:
+            self._tqdm.refresh()
             self._tqdm.close()
             self._tqdm = None
-- 
cgit v1.2.3


From 27eedb696661d031b9a7b8641b50eaec8dabf64f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 12 Mar 2023 17:20:04 +0300
Subject: change extension index link to the new dedicated repo instead of wiki

---
 modules/ui_extensions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index bd4308ef..df75a925 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -304,7 +304,7 @@ def create_ui():
         with gr.TabItem("Available"):
             with gr.Row():
                 refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
-                available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+                available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
 
                 extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                 install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
-- 
cgit v1.2.3


From a00cd8b9c1e9866a58d135f3b64cc7e0f29c6d47 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 12 Mar 2023 21:04:17 +0300
Subject: attempt to fix memory monitor with multiple CUDA devices

---
 modules/memmon.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

(limited to 'modules')

diff --git a/modules/memmon.py b/modules/memmon.py
index a7060f58..4018edcc 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
         self.data = defaultdict(int)
 
         try:
-            torch.cuda.mem_get_info()
+            self.cuda_mem_get_info()
             torch.cuda.memory_stats(self.device)
         except Exception as e:  # AMD or whatever
             print(f"Warning: caught exception '{e}', memory monitor disabled")
             self.disabled = True
 
+    def cuda_mem_get_info(self):
+        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+        return torch.cuda.mem_get_info(index)
+
     def run(self):
         if self.disabled:
             return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
                 self.run_flag.clear()
                 continue
 
-            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+            self.data["min_free"] = self.cuda_mem_get_info()[0]
 
             while self.run_flag.is_set():
-                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                free, total = self.cuda_mem_get_info()
                 self.data["min_free"] = min(self.data["min_free"], free)
                 time.sleep(1 / self.opts.memmon_poll_rate)
 
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
-            free, total = torch.cuda.mem_get_info()
+            free, total = self.cuda_mem_get_info()
             self.data["free"] = free
             self.data["total"] = total
-- 
cgit v1.2.3
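Editor's note: the memmon fix above in isolation. torch.cuda.mem_get_info accepts a device index, so the new helper resolves one explicitly instead of relying on whichever device happens to be current. A sketch, guarded so it also runs on machines without a GPU:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")  # may carry no index, like shared.device
    # mirror cuda_mem_get_info() above: fall back to the current device
    index = device.index if device.index is not None else torch.cuda.current_device()
    free, total = torch.cuda.mem_get_info(index)
    print(f"{free / 1024**2:.0f} MiB free of {total / 1024**2:.0f} MiB")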
From dfeee786f903e392dbef1519c7c246b9856ebab3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 12 Mar 2023 21:25:22 +0300
Subject: display correct timings after restarting UI

---
 modules/timer.py | 3 +++
 1 file changed, 3 insertions(+)

(limited to 'modules')

diff --git a/modules/timer.py b/modules/timer.py
index 57a4f17a..ba92be33 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -33,3 +33,6 @@ class Timer:
             res += ")"
 
         return res
+
+    def reset(self):
+        self.__init__()
-- 
cgit v1.2.3


From a71b7b5ec09a24e8a9bb3385e32862da905af6f1 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Sun, 12 Mar 2023 12:30:31 -0700
Subject: relocate filename length limit to better spot

---
 modules/images.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

(limited to 'modules')

diff --git a/modules/images.py b/modules/images.py
index 4c204fca..2ce5c67c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -512,9 +512,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         file_decoration = "-" + file_decoration
 
     file_decoration = namegen.apply(file_decoration) + suffix
-    if hasattr(os, 'statvfs'):
-        max_name_len = os.statvfs(path).f_namemax
-        file_decoration = file_decoration[:max_name_len - 5]
 
     if add_number:
         basecount = get_next_sequence_number(path, basename)
@@ -576,6 +573,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
             os.replace(temp_file_path, filename_without_extension + extension)
 
     fullfn_without_extension, extension = os.path.splitext(params.filename)
+    if hasattr(os, 'statvfs'):
+        max_name_len = os.statvfs(path).f_namemax
+        fullfn_without_extension = fullfn_without_extension[:max_name_len - len(extension)]
+        params.filename = fullfn_without_extension + extension
 
     _atomically_save_image(image, fullfn_without_extension, extension)
 
     image.already_saved_as = fullfn
-- 
cgit v1.2.3
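Editor's note: a self-contained sketch of where this filename-length logic ends up, folding in the short-extension guard that the next commit below adds. The function name and parameters are hypothetical; the real code operates on save_image's params.filename in place.

import os

def clamp_filename(fullfn_without_extension: str, extension: str, directory: str) -> str:
    # f_namemax is the maximum length of a single name component on the
    # filesystem holding `directory` (statvfs is POSIX-only, hence the guard)
    if hasattr(os, 'statvfs'):
        max_name_len = os.statvfs(directory).f_namemax
        # reserve room for the extension; max(4, ...) keeps the reservation
        # sane for short or empty extensions (see the commit below)
        return fullfn_without_extension[:max_name_len - max(4, len(extension))]
    return fullfn_without_extension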
From 48df6d66ea627a1c538aba4d37d9141798fff4d7 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Sun, 12 Mar 2023 12:33:29 -0700
Subject: add safety check in case of short extensions

so e.g. if a two-letter or empty extension is used, `.txt` would break; this
`max` call protects against that.
---
 modules/images.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/images.py b/modules/images.py
index 2ce5c67c..18d1de2f 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -575,7 +575,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     fullfn_without_extension, extension = os.path.splitext(params.filename)
     if hasattr(os, 'statvfs'):
         max_name_len = os.statvfs(path).f_namemax
-        fullfn_without_extension = fullfn_without_extension[:max_name_len - len(extension)]
+        fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
         params.filename = fullfn_without_extension + extension
 
     _atomically_save_image(image, fullfn_without_extension, extension)
-- 
cgit v1.2.3


From af9158a8c708c2f710823b298708c51a4d8ba08f Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Sun, 12 Mar 2023 12:36:04 -0700
Subject: update `fullfn` properly

---
 modules/images.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules')

diff --git a/modules/images.py b/modules/images.py
index 18d1de2f..2da988ee 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -577,6 +577,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         max_name_len = os.statvfs(path).f_namemax
         fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
         params.filename = fullfn_without_extension + extension
+        fullfn = params.filename
 
     _atomically_save_image(image, fullfn_without_extension, extension)
 
     image.already_saved_as = fullfn
-- 
cgit v1.2.3


From 4d26c7da57a621815f25929b35977bd6a3958711 Mon Sep 17 00:00:00 2001
From: high_byte
Date: Mon, 13 Mar 2023 17:37:29 +0200
Subject: initialize extra_network_data before use

---
 modules/processing.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/processing.py b/modules/processing.py
index 06e7a440..59717b4c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if state.job_count == -1:
             state.job_count = p.n_iter
 
+        extra_network_data = None
         for n in range(p.n_iter):
             p.iteration = n
@@ -712,7 +713,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if opts.grid_save:
                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
-    if not p.disable_extra_networks:
+    if not p.disable_extra_networks and extra_network_data:
         extra_networks.deactivate(p, extra_network_data)
 
     devices.torch_gc()
-- 
cgit v1.2.3
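Editor's note: the processing.py fix above, reduced to a toy sketch. Initializing before the loop means the cleanup guard cannot hit a NameError when the loop body raises, or never runs at all, before assigning. The payload dict is a hypothetical placeholder.

extra_network_data = None  # the fix: defined even if the loop below bails out early
for n in range(0):  # e.g. an interrupted job: zero iterations
    extra_network_data = {"lora": ["hypothetical-network"]}

# `and extra_network_data` skips deactivation when nothing was ever activated
if extra_network_data:
    print("deactivating", extra_network_data)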
From 03a80f198e8cda66cb6206cb3ae505d66d9eed9d Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 13 Mar 2023 12:35:30 -0400
Subject: add pbar to unipc

---
 modules/models/diffusion/uni_pc/sampler.py | 2 +-
 modules/models/diffusion/uni_pc/uni_pc.py  | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

(limited to 'modules')

diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index bf346ff4..a241c8a7 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -71,7 +71,7 @@ class UniPCSampler(object):
         # sampling
         C, H, W = shape
         size = (batch_size, C, H, W)
-        print(f'Data shape for UniPC sampling is {size}')
+        # print(f'Data shape for UniPC sampling is {size}')
 
         device = self.model.betas.device
         if x_T is None:
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index e9a093a2..eb5f4e76 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 import math
+from tqdm.auto import trange
 
 
 class NoiseScheduleVP:
@@ -750,7 +751,7 @@ class UniPC:
         if method == 'multistep':
             assert steps >= order, "UniPC order must be < sampling steps"
             timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
-            print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
+            #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
             assert timesteps.shape[0] - 1 == steps
             with torch.no_grad():
                 vec_t = timesteps[0].expand((x.shape[0]))
@@ -766,7 +767,7 @@ class UniPC:
                     self.after_update(x, model_x)
                     model_prev_list.append(model_x)
                     t_prev_list.append(vec_t)
-                for step in range(order, steps + 1):
+                for step in trange(order, steps + 1):
                     vec_t = timesteps[step].expand(x.shape[0])
                     if lower_order_final:
                         step_order = min(order, steps + 1 - step)
-- 
cgit v1.2.3
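Editor's note: tqdm.auto.trange is a drop-in replacement for range that draws a progress bar, which is all the commit above changes in the multistep sampling loop. A sketch with hypothetical values:

from tqdm.auto import trange

order, steps = 3, 20  # hypothetical values standing in for the loop above
for step in trange(order, steps + 1):
    pass  # each UniPC step would run here, advancing the bar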
From c19530f1a590d758463f84523dd4c48c34d723e6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 14 Mar 2023 09:10:26 +0300
Subject: Add view metadata button for Lora cards.

---
 modules/sd_models.py         | 24 ++++++++++++++++++++++++
 modules/ui_extra_networks.py |  7 +++++++
 2 files changed, 31 insertions(+)

(limited to 'modules')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 93959f55..5f57ec0c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
 
+def read_metadata_from_safetensors(filename):
+    import json
+
+    with open(filename, mode="rb") as file:
+        metadata_len = file.read(8)
+        metadata_len = int.from_bytes(metadata_len, "little")
+        json_start = file.read(2)
+
+        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+        json_data = json_start + file.read(metadata_len-2)
+        json_obj = json.loads(json_data)
+
+        res = {}
+        for k, v in json_obj.get("__metadata__", {}).items():
+            res[k] = v
+            if isinstance(v, str) and v[0] == '{':
+                try:
+                    res[k] = json.loads(v)
+                except Exception as e:
+                    pass
+
+        return res
+
+
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 01df5e90..b5847299 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -124,6 +124,12 @@ class ExtraNetworksPage:
         if onclick is None:
             onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
 
+        metadata_button = ""
+        metadata = item.get("metadata")
+        if metadata:
+            metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
+            metadata_button = f""
+
         args = {
             "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
             "prompt": item.get("prompt", None),
@@ -134,6 +140,7 @@ class ExtraNetworksPage:
             "card_clicked": onclick,
             "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
             "search_term": item.get("search_term", ""),
+            "metadata_button": metadata_button,
         }
 
         return self.card_page.format(**args)
-- 
cgit v1.2.3


From 4281432594084bc846cc834c807ac57f59457eae Mon Sep 17 00:00:00 2001
From: willtakasan <126040162+willtakasan@users.noreply.github.com>
Date: Tue, 14 Mar 2023 15:36:08 +0900
Subject: Update ui_extra_networks.py

I updated it so that no error message is displayed when setting a webp as the
preview image.
---
 modules/ui_extra_networks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index b5847299..cdfd6f2a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -30,8 +30,8 @@ def add_pages_to_demo(app):
         raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
 
     ext = os.path.splitext(filename)[1].lower()
-    if ext not in (".png", ".jpg"):
-        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+    if ext not in (".png", ".jpg", ".webp"):
+        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
 
     # would profit from returning 304
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
-- 
cgit v1.2.3


From 6a04a7f20fcc4a992ae017b06723e9ceffe17b37 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 14 Mar 2023 11:22:29 +0300
Subject: fix an error loading Lora with empty values in metadata

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 5f57ec0c..f0cb1240 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -225,7 +225,7 @@ def read_metadata_from_safetensors(filename):
         res = {}
         for k, v in json_obj.get("__metadata__", {}).items():
             res[k] = v
-            if isinstance(v, str) and v[0] == '{':
+            if isinstance(v, str) and v[0:1] == '{':
                 try:
                     res[k] = json.loads(v)
                 except Exception as e:
                     pass
-- 
cgit v1.2.3
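Editor's note: taken together, the two sd_models.py commits above read Lora metadata straight from the safetensors header. A standalone sketch of the file layout they rely on — an 8-byte little-endian length N, then N bytes of JSON whose optional "__metadata__" object holds string key/value pairs. The v[0:1] slice is the empty-string-safe form of v[0], which is what the final commit fixes.

import json

def read_safetensors_metadata(filename):
    # layout: [8-byte little-endian header length][JSON header][tensor data]
    with open(filename, "rb") as file:
        header_len = int.from_bytes(file.read(8), "little")
        header = json.loads(file.read(header_len))

    res = {}
    for k, v in header.get("__metadata__", {}).items():
        res[k] = v
        # metadata values are strings; ones that look like JSON get decoded
        if isinstance(v, str) and v[0:1] == '{':  # safe even for empty strings
            try:
                res[k] = json.loads(v)
            except json.JSONDecodeError:
                pass
    return res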