From f1975b0213f5be400889ec04b3891d1cb571fe20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:01:07 +0300 Subject: initial refiner support --- modules/sd_models.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..981aa93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) -- cgit v1.2.3 From ac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:04:59 +0300 Subject: resolve merge issues --- modules/sd_models.py | 7 +++++-- modules/sd_samplers_kdiffusion.py | 3 +-- modules/shared_options.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6cb2f34..a178adca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -640,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): timer.record("send model to device") model_data.set_sd_model(already_loaded) - shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title - shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e1854980..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra -from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared diff --git a/modules/shared_options.py b/modules/shared_options.py index 9ae51f18..1e5b64ea 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -140,6 +140,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), + "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"), + "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { -- cgit v1.2.3