diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-06 14:01:07 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-06 14:01:07 +0000 |
commit | f1975b0213f5be400889ec04b3891d1cb571fe20 (patch) | |
tree | 874e4bd221209a5197f1f578f907cdc28b33a6b7 /modules/sd_samplers_common.py | |
parent | 57e8a11d17a6646fdf551320f5f714fba752987a (diff) | |
download | stable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.tar.gz stable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.tar.bz2 stable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.zip |
initial refiner support
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40..3f3e83e3 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np
import torch
from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -127,3 +127,20 @@ def replace_torchsde_browinan(): replace_torchsde_browinan()
+
+
+def apply_refiner(sampler):
+ completed_ratio = sampler.step / sampler.steps
+ if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+
+ sampler.update_inner_model()
+
+ sampler.p.setup_conds()
|