diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-14 08:46:27 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-14 08:46:27 +0000 |
commit | 80adb6979d46bbb832254004cac4f4f9bec9efb3 (patch) | |
tree | e206ee60f9be21e9e20d483213b7d0a610d2bdbd /modules/sd_models.py | |
parent | 1dcd6723242c3d691610f9ed937951baea49c2d1 (diff) | |
parent | 3ddc76342298ad0b2d14cb571ceb48c0b0c4176d (diff) | |
download | stable-diffusion-webui-gfx803-80adb6979d46bbb832254004cac4f4f9bec9efb3.tar.gz stable-diffusion-webui-gfx803-80adb6979d46bbb832254004cac4f4f9bec9efb3.tar.bz2 stable-diffusion-webui-gfx803-80adb6979d46bbb832254004cac4f4f9bec9efb3.zip |
Merge branch 'dev' into find_vae
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 97 |
1 files changed, 72 insertions, 25 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 4f7613a1..4c9a0a1f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,8 @@ import collections import os.path
import sys
import gc
+import threading
+
import torch
import re
import safetensors.torch
@@ -13,9 +15,9 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -45,7 +47,7 @@ class CheckpointInfo: self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
@@ -67,7 +69,7 @@ class CheckpointInfo: checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None:
return
@@ -85,8 +87,7 @@ class CheckpointInfo: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -165,7 +166,7 @@ def model_hash(filename): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -237,7 +238,7 @@ def read_metadata_from_safetensors(filename): if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -372,7 +373,7 @@ def enable_midas_autodownload(): if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -404,13 +405,39 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+ def __init__(self):
+ self.sd_model = None
+ self.lock = threading.Lock()
+
+ def get_sd_model(self):
+ if self.sd_model is None:
+ with self.lock:
+ try:
+ load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load", file=sys.stderr)
+ self.sd_model = None
+
+ return self.sd_model
+
+ def set_sd_model(self, v):
+ self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- if shared.sd_model:
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
gc.collect()
devices.torch_gc()
@@ -439,7 +466,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -464,7 +491,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ timer.record("hijack")
sd_model.eval()
- shared.sd_model = sd_model
+ model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -484,7 +511,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_info = info or select_checkpoint()
if not sd_model:
- sd_model = shared.sd_model
+ sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
@@ -512,11 +539,11 @@ def reload_model_weights(sd_model=None, info=None): del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
- return shared.sd_model
+ return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -535,17 +562,15 @@ def reload_model_weights(sd_model=None, info=None): return sd_model
+
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
- if shared.sd_model:
-
- # shared.sd_model.cond_stage_model.to(devices.cpu)
- # shared.sd_model.first_stage_model.to(devices.cpu)
- shared.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ model_data.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
@@ -554,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None): print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
|