diff options
-rw-r--r-- | configs/instruct-pix2pix.yaml | 99 | ||||
-rw-r--r-- | configs/v1-inpainting-inference.yaml (renamed from v2-inference-v.yaml) | 36 | ||||
-rw-r--r-- | modules/api/api.py | 5 | ||||
-rw-r--r-- | modules/api/models.py | 2 | ||||
-rw-r--r-- | modules/devices.py | 14 | ||||
-rw-r--r-- | modules/processing.py | 25 | ||||
-rw-r--r-- | modules/scripts.py | 32 | ||||
-rw-r--r-- | modules/scripts_auto_postprocessing.py | 42 | ||||
-rw-r--r-- | modules/scripts_postprocessing.py | 11 | ||||
-rw-r--r-- | modules/sd_hijack_inpainting.py | 9 | ||||
-rw-r--r-- | modules/sd_hijack_utils.py | 2 | ||||
-rw-r--r-- | modules/sd_models.py | 236 | ||||
-rw-r--r-- | modules/sd_models_config.py | 68 | ||||
-rw-r--r-- | modules/sd_samplers.py | 2 | ||||
-rw-r--r-- | modules/shared.py | 20 | ||||
-rw-r--r-- | modules/shared_items.py | 23 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 2 | ||||
-rw-r--r-- | modules/timer.py | 35 | ||||
-rw-r--r-- | modules/ui_components.py | 8 | ||||
-rw-r--r-- | scripts/postprocessing_upscale.py | 25 | ||||
-rw-r--r-- | style.css | 6 | ||||
-rw-r--r-- | webui-macos-env.sh | 2 |
22 files changed, 518 insertions, 186 deletions
diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml new file mode 100644 index 00000000..437ddcef --- /dev/null +++ b/configs/instruct-pix2pix.yaml @@ -0,0 +1,99 @@ +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +model: + base_learning_rate: 1.0e-04 + target: modules.models.diffusion.ddpm_edit.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: edited + cond_stage_key: edit + # image_size: 64 + # image_size: 32 + image_size: 16 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: true + load_ema: true + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 0 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + num_workers: 1 + wrap: false + validation: + target: edit_dataset.EditDataset + params: + path: data/clip-filtered-dataset + cache_dir: data/ + cache_name: data_10k + split: val + min_text_sim: 0.2 + min_image_sim: 0.75 + min_direction_sim: 0.2 + max_samples_per_prompt: 1 + min_resize_res: 512 + max_resize_res: 512 + crop_res: 512 + output_as_edit: False + real_input: True diff --git a/v2-inference-v.yaml b/configs/v1-inpainting-inference.yaml index 513cd635..f9eec37d 100644 --- a/v2-inference-v.yaml +++ b/configs/v1-inpainting-inference.yaml @@ -1,8 +1,7 @@ model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion params: - parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 @@ -12,29 +11,36 @@ model: cond_stage_key: "txt" image_size: 64 channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: - use_checkpoint: True - use_fp16: True image_size: 32 # unused - in_channels: 4 + in_channels: 9 # 4 data + 4 downscaled image + 1 mask out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn + num_heads: 8 use_spatial_transformer: True - use_linear_in_transformer: True transformer_depth: 1 - context_dim: 1024 + context_dim: 768 + use_checkpoint: True legacy: False first_stage_config: @@ -43,7 +49,6 @@ model: embed_dim: 4 monitor: val/rec_loss ddconfig: - #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -62,7 +67,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate"
\ No newline at end of file + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..eb7b1da5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, find_checkpoint_config +from modules.sd_models import checkpoints_list +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -387,7 +388,7 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index 805bd8f7..cba43d3b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -228,7 +228,7 @@ class SDModelItem(BaseModel): hash: Optional[str] = Field(title="Short hash") sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") - config: str = Field(title="Config file") + config: Optional[str] = Field(title="Config file") class HypernetworkItem(BaseModel): name: str = Field(title="Name") diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..4687944e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): @@ -139,6 +143,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." + message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) diff --git a/modules/processing.py b/modules/processing.py index 9e5a2f38..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
@@ -185,7 +185,12 @@ class StableDiffusionProcessing: conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+ def edit_image_conditioning(self, source_image):
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+ return conditioning_image
+
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -212,7 +217,7 @@ class StableDiffusionProcessing: )
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -228,6 +233,9 @@ class StableDiffusionProcessing: if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+ if self.sd_model.cond_stage_key == "edit":
+ return self.edit_image_conditioning(source_image)
+
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
@@ -409,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
return x
@@ -650,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample)
+ if p.scripts is not None:
+ pp = scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image(p, pp)
+ image = pp.image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
@@ -993,7 +1006,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
+ image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,12 +6,16 @@ from collections import namedtuple import gradio as gr
-from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
+class PostprocessImageArgs:
+ def __init__(self, image):
+ self.image = image
+
+
class Script:
filename = None
args_from = None
@@ -65,7 +69,7 @@ class Script: args contains all values returned by components from ui()
"""
- raise NotImplementedError()
+ pass
def process(self, p, *args):
"""
@@ -100,6 +104,13 @@ class Script: pass
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = []
def initialize_scripts(self, is_img2img):
+ from modules import scripts_auto_postprocessing
+
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- for script_class, path, basedir, script_module in scripts_data:
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+
+ for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
@@ -332,7 +347,7 @@ class ScriptRunner: return inputs
- def run(self, p: StableDiffusionProcessing, *args):
+ def run(self, p, *args):
script_index = args[0]
if script_index == 0:
@@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_image(p, pp, *script_args)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared
+
+
+class ScriptPostprocessingForMainUI(scripts.Script):
+ def __init__(self, script_postproc):
+ self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
+ self.postprocessing_controls = None
+
+ def title(self):
+ return self.script.name
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ self.postprocessing_controls = self.script.ui()
+ return self.postprocessing_controls.values()
+
+ def postprocess_image(self, p, script_pp, *args):
+ args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+
+ pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
+ pp.info = {}
+ self.script.process(pp, **args_dict)
+ p.extra_generation_params.update(pp.info)
+ script_pp.image = pp.image
+
+
+def create_auto_preprocessing_script_data():
+ from modules import scripts
+
+ res = []
+
+ for name in shared.opts.postprocessing_enable_in_main_ui:
+ script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
+ if script is None:
+ continue
+
+ constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
+ res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
+
+ return res
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass
+
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class()
script.filename = path
+ if script.name == "Simple Upscale":
+ continue
+
self.scripts.append(script)
def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: import modules.scripts
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
- scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+ scripts_order = shared.opts.postprocessing_operation_order
def script_score(name):
- name = name.lower()
for i, possible_match in enumerate(scripts_order):
- if possible_match in name:
+ if possible_match == name:
return i
return len(self.scripts)
@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: def image_changed(self):
for script in self.scripts_in_preferred_order():
script.image_changed()
+
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 31d2c898..478cd499 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def should_hijack_inpainting(checkpoint_info): - from modules import sd_models - - ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() - - return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename - - def do_inpainting_hijack(): # p_sample_plms is needed because PLMS can't work with dicts as conditionings diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f81b169a..f8684475 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -5,7 +5,7 @@ class CondFunc: self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
- for i in range(len(func_path)-2, -1, -1):
+ for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072eb2e..37dad18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path
import sys
import gc
-import time
-from collections import namedtuple
import torch
import re
import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
-def find_checkpoint_config(info):
- if info is None:
- return shared.cmd_opts.config
-
- config = os.path.splitext(info.filename)[0] + ".yaml"
- if os.path.exists(config):
- return config
-
- return shared.cmd_opts.config
-
-
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- device = map_location or shared.weight_load_location
- if device is None:
- device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
+ timer.record("calculate hash")
+
+ if checkpoint_info in checkpoints_loaded:
+ # use checkpoint cache
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ return checkpoints_loaded[checkpoint_info]
+
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+ res = read_state_dict(checkpoint_info.filename)
+ timer.record("load weights from disk")
+
+ return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ timer.record("calculate hash")
+
if checkpoint_info.title != title:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
- cache_enabled = shared.opts.sd_checkpoint_cache > 0
+ if state_dict is None:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- if cache_enabled and checkpoint_info in checkpoints_loaded:
- # use checkpoint cache
- print(f"Loading weights [{sd_model_hash}] from cache")
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
- else:
- # load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+ model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ timer.record("apply weights to model")
- sd = read_state_dict(checkpoint_info.filename)
- model.load_state_dict(sd, strict=False)
- del sd
-
- if cache_enabled:
- # cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ if shared.opts.sd_checkpoint_cache > 0:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
+ if shared.cmd_opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
+ timer.record("apply channels_last")
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
+ if not shared.cmd_opts.no_half:
+ vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
- if not shared.cmd_opts.no_half:
- vae = model.first_stage_model
- depth_model = getattr(model, 'depth_model', None)
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.cmd_opts.no_half_vae:
+ model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
- # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
- if shared.cmd_opts.no_half_vae:
- model.first_stage_model = None
- # with --upcast-sampling, don't convert the depth model weights to float16
- if shared.cmd_opts.upcast_sampling and depth_model:
- model.depth_model = None
+ model.half()
+ model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
- model.half()
- model.first_stage_model = vae
- if depth_model:
- model.depth_model = depth_model
+ timer.record("apply half()")
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- devices.dtype_unet = model.model.diffusion_model.dtype
- devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
- model.first_stage_model.to(devices.dtype_vae)
+ model.first_stage_model.to(devices.dtype_vae)
+ timer.record("apply dtype to VAE")
# clean up cache if limit is reached
- if cache_enabled:
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
- checkpoints_loaded.popitem(last=False) # LRU
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False)
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
+ timer.record("load VAE")
def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper
-class Timer:
- def __init__(self):
- self.start = time.time()
+def repair_config(sd_config):
+
+ if not hasattr(sd_config.model.params, "use_ema"):
+ sd_config.model.params.use_ema = False
- def elapsed(self):
- end = time.time()
- res = end - self.start
- self.start = end
- return res
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
-def load_model(checkpoint_info=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- checkpoint_config = find_checkpoint_config(checkpoint_info)
-
- if checkpoint_config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_config)
-
- if should_hijack_inpainting(checkpoint_info):
- # Hardcoded config for now...
- sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
- sd_config.model.params.conditioning_key = "hybrid"
- sd_config.model.params.unet_config.params.in_channels = 9
- sd_config.model.params.finetune_keys = None
-
- if should_hijack_ip2p(checkpoint_info):
- sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
- sd_config.model.params.conditioning_key = "hybrid"
- sd_config.model.params.first_stage_key = "edited"
- sd_config.model.params.cond_stage_key = "edit"
- sd_config.model.params.image_size = 16
- sd_config.model.params.unet_config.params.in_channels = 8
- sd_config.model.params.unet_config.params.out_channels = 4
+ do_inpainting_hijack()
- if not hasattr(sd_config.model.params, "use_ema"):
- sd_config.model.params.use_ema = False
+ timer = Timer()
- do_inpainting_hijack()
+ if already_loaded_state_dict is not None:
+ state_dict = already_loaded_state_dict
+ else:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|