diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-05-21 14:37:09 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-21 14:37:09 +0000 |
commit | 1f3182924ba8e70d0e0fc3ed270782f324376ba3 (patch) | |
tree | 27a9e5167e5b981dfe56f5084ea8e1e8743f3fc0 /modules/sd_models.py | |
parent | 89f9faa63388756314e8a1d96cf86bf5e0663045 (diff) | |
parent | fdaf0147b6d2a5f599464bb7c65817ef5832eff1 (diff) | |
download | stable-diffusion-webui-gfx803-1f3182924ba8e70d0e0fc3ed270782f324376ba3.tar.gz stable-diffusion-webui-gfx803-1f3182924ba8e70d0e0fc3ed270782f324376ba3.tar.bz2 stable-diffusion-webui-gfx803-1f3182924ba8e70d0e0fc3ed270782f324376ba3.zip |
Merge branch 'dev' into release_candidate
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 53 |
1 files changed, 42 insertions, 11 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 36f643e1..b1afbaa7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,9 +15,9 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -99,7 +98,6 @@ def setup_model(): if not os.path.exists(model_path):
os.makedirs(model_path)
- list_models()
enable_midas_autodownload()
@@ -167,7 +165,7 @@ def model_hash(filename): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -239,7 +237,7 @@ def read_metadata_from_safetensors(filename): if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -374,7 +372,7 @@ def enable_midas_autodownload(): if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -410,11 +408,18 @@ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_w class SdModelData:
def __init__(self):
self.sd_model = None
+ self.was_loaded_at_least_once = False
self.lock = threading.Lock()
def get_sd_model(self):
+ if self.was_loaded_at_least_once:
+ return self.sd_model
+
if self.sd_model is None:
with self.lock:
+ if self.sd_model is not None or self.was_loaded_at_least_once:
+ return self.sd_model
+
try:
load_model()
except Exception as e:
@@ -467,7 +472,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -493,6 +498,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model.eval()
model_data.sd_model = sd_model
+ model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -538,13 +544,12 @@ def reload_model_weights(sd_model=None, info=None): if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -565,7 +570,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
@@ -580,3 +585,29 @@ def unload_model_weights(sd_model=None, info=None): print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+ """
+ Applies speed and memory optimizations from tomesd.
+ """
+
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio
|