diff options
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 97bc0804..35c4d657 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -3,7 +3,7 @@ from collections import namedtuple import numpy as np
import torch
from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
@@ -131,6 +131,35 @@ def replace_torchsde_browinan(): replace_torchsde_browinan()
+def apply_refiner(sampler):
+ completed_ratio = sampler.step / sampler.steps
+
+ if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ return False
+
+ if shared.opts.sd_refiner_checkpoint == "None":
+ return False
+
+ if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+ return False
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+ sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+ sampler.p.setup_conds()
+ sampler.update_inner_model()
+
+ return True
+
+
class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.
@@ -176,8 +205,9 @@ class Sampler: self.conditioning_key = shared.sd_model.model.conditioning_key
- self.model_wrap = None
+ self.p = None
self.model_wrap_cfg = None
+ self.sampler_extra_args = None
def callback_state(self, d):
step = d['i']
@@ -189,6 +219,7 @@ class Sampler: shared.total_tqdm.update()
def launch_sampling(self, steps, func):
+ self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -208,6 +239,8 @@ class Sampler: return p.steps
def initialize(self, p) -> dict:
+ self.p = p
+ self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
|