aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--modules/launch_utils.py20
-rw-r--r--modules/lowvram.py51
-rw-r--r--modules/paths.py8
-rw-r--r--modules/processing.py7
-rw-r--r--modules/prompt_parser.py87
-rw-r--r--modules/sd_hijack.py30
-rw-r--r--modules/sd_hijack_clip.py16
-rw-r--r--modules/sd_hijack_open_clip.py34
-rw-r--r--modules/sd_hijack_optimizations.py51
-rw-r--r--modules/sd_models.py22
-rw-r--r--modules/sd_models_config.py7
-rw-r--r--modules/sd_models_xl.py59
-rw-r--r--modules/sd_samplers_kdiffusion.py45
-rw-r--r--modules/shared.py2
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
16 files changed, 383 insertions, 58 deletions
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 0e0dbca4..aa9d1880 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -224,6 +224,20 @@ def run_extensions_installers(settings_file):
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+def mute_sdxl_imports():
+ """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+ import importlib
+
+ module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None))
+ module.LPIPS = None
+ sys.modules['taming.modules.losses.lpips'] = module
+
+ module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None))
+ module.StableDataModuleFromConfig = None
+ sys.modules['sgm.data'] = module
+
+
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
@@ -235,11 +249,13 @@ def prepare_environment():
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+ stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -297,6 +313,7 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -316,11 +333,14 @@ def prepare_environment():
if args.update_all_extensions:
git_pull_recursive(extensions_dir)
+ mute_sdxl_imports()
+
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)
+
def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
diff --git a/modules/lowvram.py b/modules/lowvram.py
index d95bcfbf..da4f33a8 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
- # send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+ to_remain_in_cpu = [
+ (sd_model, 'first_stage_model'),
+ (sd_model, 'depth_model'),
+ (sd_model, 'embedder'),
+ (sd_model, 'model'),
+ (sd_model, 'embedder'),
+ ]
+
+ is_sdxl = hasattr(sd_model, 'conditioner')
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+ if is_sdxl:
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
+ elif is_sd2:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+ else:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+ stored = []
+ for obj, field in to_remain_in_cpu:
+ module = getattr(obj, field, None)
+ stored.append(module)
+ setattr(obj, field, None)
+
+ # send the model to GPU.
sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+ # put modules back. the modules will be in CPU.
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
+ setattr(obj, field, module)
# register hooks for those the first three models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ if is_sdxl:
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+ elif is_sd2:
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
-
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff --git a/modules/paths.py b/modules/paths.py
index bada804e..1100a8dc 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
+ (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +36,13 @@ for d, must_exist, what, options in path_dirs:
d = os.path.abspath(d)
if "atstart" in options:
sys.path.insert(0, d)
+ elif "sgm" in options:
+ # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+ # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
+
+ sys.path.insert(0, d)
+ import sgm
+ sys.path.pop(0)
else:
sys.path.append(d)
paths[what] = d
diff --git a/modules/processing.py b/modules/processing.py
index cd568a20..85d35423 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -343,10 +343,13 @@ class StableDiffusionProcessing:
return cache[1]
def setup_conds(self):
+ prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+ negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
+
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0..33810669 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+ Can also specify width and height of created image - SDXL needs it.
+ """
+ def __init__(self, prompts, width=None, height=None):
+ super().__init__()
+ self.extend(prompts)
+ self.width = width or getattr(prompts, 'width', None)
+ self.height = height or getattr(prompts, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -144,7 +158,12 @@ def get_learned_conditioning(model, prompts, steps):
cond_schedule = []
for i, (end_at_step, _) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ if isinstance(conds, dict):
+ cond = {k: v[i] for k, v in conds.items()}
+ else:
+ cond = conds[i]
+
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@@ -155,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -196,6 +217,7 @@ class MulticondLearnedConditioning:
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +236,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+class DictWithShape(dict):
+ def __init__(self, x, shape):
+ super().__init__()
+ self.update(x)
+
+ @property
+ def shape(self):
+ return self["crossattn"].shape
+
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ is_dict = isinstance(param, dict)
+
+ if is_dict:
+ dict_cond = param
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+ else:
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
for i, cond_schedule in enumerate(c):
target_index = 0
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
- res[i] = cond_schedule[target_index].cond
+
+ if is_dict:
+ for k, param in cond_schedule[target_index].cond.items():
+ res[k][i] = param
+ else:
+ res[i] = cond_schedule[target_index].cond
return res
+def stack_conds(tensors):
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return torch.stack(tensors)
+
+
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
@@ -249,16 +308,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ if isinstance(tensors[0], dict):
+ keys = list(tensors[0].keys())
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+ else:
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+ return conds_list, stacked
re_attention = re.compile(r"""
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 3b6f95ce..266811f9 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -166,6 +178,24 @@ class StableDiffusionModelHijack:
undo_optimizations()
def hijack(self, m):
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ typename = type(embedder).__name__
+ if typename == 'FrozenOpenCLIPEmbedder':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ conditioner.embedders[i] = m.cond_stage_model
+ if typename == 'FrozenCLIPEmbedder':
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
+ conditioner.embedders[i] = m.cond_stage_model
+ if typename == 'FrozenOpenCLIPEmbedder2':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 3b5a7666..6c17a81d 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
+ self.is_trainable = getattr(wrapped, 'is_trainable', False)
+ self.input_key = getattr(wrapped, 'input_key', 'txt')
+ self.legacy_ucg_val = None
+
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
"""
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
An example shape returned by this function can be: (2, 77, 768).
+ For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
@@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
- return torch.hstack(zs)
+ if getattr(self.wrapped, 'return_pooled', False):
+ return torch.hstack(zs), zs[0].pooled
+ else:
+ return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
"""
@@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z = z * (original_mean / new_mean)
+ z *= (original_mean / new_mean)
return z
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index f733e852..fcf5ad07 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -35,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+ self.id_start = tokenizer.encoder["<start_of_text>"]
+ self.id_end = tokenizer.encoder["<end_of_text>"]
+ self.id_pad = 0
+
+ def tokenize(self, texts):
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+ tokenized = [tokenizer.encode(text) for text in texts]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ d = self.wrapped.encode_with_transformer(tokens)
+ z = d[self.wrapped.layer]
+
+ pooled = d.get("pooled")
+ if pooled is not None:
+ z.pooled = pooled
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ ids = tokenizer.encode(init_text)
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 53e27ade..e99c9ba5 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
class SdOptimization:
@@ -39,6 +43,9 @@ class SdOptimization:
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
+
class SdOptimizationXformers(SdOptimization):
name = "xformers"
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
cmd_opt = "opt_split_attention_v1"
priority = 10
-
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def list_optimizers(res):
@@ -155,7 +173,7 @@ def get_available_vram():
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
h = self.heads
q_in = self.to_q(x)
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
h = self.heads
q_in = self.to_q(x)
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
+
def einsum_op_slice_0(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
+
def einsum_op_slice_1(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
return r
+
def einsum_op_mps_v1(q, k, v):
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
+
def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
+
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
@@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
def einsum_op_cuda(q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
@@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
# Divide factor of safety as there's copying and fragmentation
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
def einsum_op(q, k, v):
if q.device.type == 'cuda':
return einsum_op_cuda(q, k, v)
@@ -328,7 +354,8 @@ def einsum_op(q, k, v):
# Tested on i7 with 8MB L3 cache.
return einsum_op_tensor_mem(q, k, v, 32)
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
h = self.heads
q = self.to_q(x)
@@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
+def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
return x
+
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None
-def xformers_attention_forward(self, x, context=None, mask=None):
+def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
+
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
hidden_states = self.to_out[1](hidden_states)
return hidden_states
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)