aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py62
1 files changed, 41 insertions, 21 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 35c4d657..07fc4434 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s
from modules.shared import opts, state
import k_diffusion.sampling
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+class SamplerData(SamplerDataTuple):
+ def total_steps(self, steps):
+ if self.options.get("second_order", False):
+ steps = steps * 2
+
+ return steps
def setup_img2img_steps(p, steps=None):
@@ -83,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ if len(image) > 1:
+ x_latent = torch.stack([
+ model.get_first_stage_encoding(
+ model.encode_first_stage(torch.unsqueeze(img, 0))
+ )[0]
+ for img in image
+ ])
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
@@ -131,31 +148,29 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
-def apply_refiner(sampler):
- completed_ratio = sampler.step / sampler.steps
+def apply_refiner(cfg_denoiser):
+ completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
+ refiner_switch_at = cfg_denoiser.p.refiner_switch_at
+ refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
- if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
- if shared.opts.sd_refiner_checkpoint == "None":
+ if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
- if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+ if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
return False
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
-
- sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
- sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+ cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
- sampler.p.setup_conds()
- sampler.update_inner_model()
+ cfg_denoiser.p.setup_conds()
+ cfg_denoiser.update_inner_model()
return True
@@ -192,7 +207,7 @@ class Sampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None # set by the function calling the constructor
+ self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.s_churn = 0.0
@@ -208,6 +223,7 @@ class Sampler:
self.p = None
self.model_wrap_cfg = None
self.sampler_extra_args = None
+ self.options = {}
def callback_state(self, d):
step = d['i']
@@ -220,6 +236,7 @@ class Sampler:
def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
+ self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps
state.sampling_step = 0
@@ -267,19 +284,19 @@ class Sampler:
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)
- if s_churn != self.s_churn:
+ if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
+ if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
+ if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
+ if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise
@@ -296,5 +313,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
-
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()