diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-16 09:04:53 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-16 09:04:53 +0000 |
commit | 0198eaec455157a7dc1c950708d1ec95bcf4629c (patch) | |
tree | 33d8e22448356c2f7c9455b3af17353ef497bbac /modules | |
parent | 9d3dd64fe9e95873347710ca1df1f1e88d1908e1 (diff) | |
parent | 14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 (diff) | |
download | stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.gz stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.bz2 stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.zip |
Merge pull request #11757 from AUTOMATIC1111/sdxl
SD XL support
Diffstat (limited to 'modules')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 2 | ||||
-rw-r--r-- | modules/launch_utils.py | 4 | ||||
-rw-r--r-- | modules/lowvram.py | 53 | ||||
-rw-r--r-- | modules/paths.py | 25 | ||||
-rw-r--r-- | modules/processing.py | 27 | ||||
-rw-r--r-- | modules/prompt_parser.py | 95 | ||||
-rw-r--r-- | modules/sd_hijack.py | 38 | ||||
-rw-r--r-- | modules/sd_hijack_clip.py | 31 | ||||
-rw-r--r-- | modules/sd_hijack_open_clip.py | 36 | ||||
-rw-r--r-- | modules/sd_hijack_optimizations.py | 51 | ||||
-rw-r--r-- | modules/sd_models.py | 33 | ||||
-rw-r--r-- | modules/sd_models_config.py | 9 | ||||
-rw-r--r-- | modules/sd_models_xl.py | 99 | ||||
-rw-r--r-- | modules/sd_samplers.py | 3 | ||||
-rw-r--r-- | modules/sd_samplers_compvis.py | 6 | ||||
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 45 | ||||
-rw-r--r-- | modules/sd_vae_approx.py | 59 | ||||
-rw-r--r-- | modules/sd_vae_taesd.py | 26 | ||||
-rw-r--r-- | modules/shared.py | 9 |
19 files changed, 548 insertions, 103 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b87..c4821d21 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): return context_k, context_v
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q = self.to_q(x)
diff --git a/modules/launch_utils.py b/modules/launch_utils.py index ff77cbfd..434facbc 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -237,11 +237,13 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+ stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -299,6 +301,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -323,6 +326,7 @@ def prepare_environment(): exit(0)
+
def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
diff --git a/modules/lowvram.py b/modules/lowvram.py index d95bcfbf..6bbc11eb 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram): send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
- # send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+ to_remain_in_cpu = [
+ (sd_model, 'first_stage_model'),
+ (sd_model, 'depth_model'),
+ (sd_model, 'embedder'),
+ (sd_model, 'model'),
+ (sd_model, 'embedder'),
+ ]
+
+ is_sdxl = hasattr(sd_model, 'conditioner')
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+ if is_sdxl:
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
+ elif is_sd2:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+ else:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+ stored = []
+ for obj, field in to_remain_in_cpu:
+ module = getattr(obj, field, None)
+ stored.append(module)
+ setattr(obj, field, None)
+
+ # send the model to GPU.
sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+ # put modules back. the modules will be in CPU.
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
+ setattr(obj, field, module)
# register hooks for those the first three models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ if is_sdxl:
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+ elif is_sd2:
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
- parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
+ if hasattr(sd_model, 'cond_stage_model'):
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
diff --git a/modules/paths.py b/modules/paths.py index bada804e..25052339 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio import modules.safe # noqa: F401
+def mute_sdxl_imports():
+ """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+ class Dummy:
+ pass
+
+ module = Dummy()
+ module.LPIPS = None
+ sys.modules['taming.modules.losses.lpips'] = module
+
+ module = Dummy()
+ module.StableDataModuleFromConfig = None
+ sys.modules['sgm.data'] = module
+
+
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths: assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
+mute_sdxl_imports()
+
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
+ (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d)
if "atstart" in options:
sys.path.insert(0, d)
+ elif "sgm" in options:
+ # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+ # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
+
+ sys.path.insert(0, d)
+ import sgm # noqa: F401
+ sys.path.pop(0)
else:
sys.path.append(d)
paths[what] = d
diff --git a/modules/processing.py b/modules/processing.py index 49441e77..e7b10808 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -330,8 +330,21 @@ class StableDiffusionProcessing: caches is a list with items described above.
"""
+
+ cached_params = (
+ required_prompts,
+ steps,
+ opts.CLIP_stop_at_last_layers,
+ shared.sd_model.sd_checkpoint_info,
+ extra_network_data,
+ opts.sdxl_crop_left,
+ opts.sdxl_crop_top,
+ self.width,
+ self.height,
+ )
+
for cache in caches:
- if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ if cache[0] is not None and cached_params == cache[0]:
return cache[1]
cache = caches[0]
@@ -339,14 +352,17 @@ class StableDiffusionProcessing: with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+ cache[0] = cached_params
return cache[1]
def setup_conds(self):
+ prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+ negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -523,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x):
- with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae))
return x
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 0069d8b0..b29d079d 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+ Can also specify width and height of created image - SDXL needs it.
+ """
+ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
+ super().__init__()
+ self.extend(prompts)
+
+ if copy_from is None:
+ copy_from = prompts
+
+ self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+ self.width = width or getattr(copy_from, 'width', None)
+ self.height = height or getattr(copy_from, 'height', None)
+
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -139,12 +159,17 @@ def get_learned_conditioning(model, prompts, steps): res.append(cached)
continue
- texts = [x[1] for x in prompt_schedule]
+ texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, _) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ if isinstance(conds, dict):
+ cond = {k: v[i] for k, v in conds.items()}
+ else:
+ cond = conds[i]
+
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@@ -155,11 +180,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -196,6 +223,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +242,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+class DictWithShape(dict):
+ def __init__(self, x, shape):
+ super().__init__()
+ self.update(x)
+
+ @property
+ def shape(self):
+ return self["crossattn"].shape
+
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ is_dict = isinstance(param, dict)
+
+ if is_dict:
+ dict_cond = param
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+ else:
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
for i, cond_schedule in enumerate(c):
target_index = 0
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
- res[i] = cond_schedule[target_index].cond
+
+ if is_dict:
+ for k, param in cond_schedule[target_index].cond.items():
+ res[k][i] = param
+ else:
+ res[i] = cond_schedule[target_index].cond
return res
+def stack_conds(tensors):
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return torch.stack(tensors)
+
+
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
@@ -249,16 +314,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ if isinstance(tensors[0], dict):
+ keys = list(tensors[0].keys())
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+ else:
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+ return conds_list, stacked
re_attention = re.compile(r"""
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b5aae4b..f5615967 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None): ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations(): ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -168,6 +180,32 @@ class StableDiffusionModelHijack: undo_optimizations()
def hijack(self, m):
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ text_cond_models = []
+
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ typename = type(embedder).__name__
+ if typename == 'FrozenOpenCLIPEmbedder':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenCLIPEmbedder':
+ model_embeddings = embedder.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenOpenCLIPEmbedder2':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+
+ if len(text_cond_models) == 1:
+ m.cond_stage_model = text_cond_models[0]
+ else:
+ m.cond_stage_model = conditioner
+
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index c1d780a3..5443e609 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
+ self.is_trainable = getattr(wrapped, 'is_trainable', False)
+ self.input_key = getattr(wrapped, 'input_key', 'txt')
+ self.legacy_ucg_val = None
+
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
An example shape returned by this function can be: (2, 77, 768).
+ For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
@@ -242,7 +247,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): if hashes:
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
- return torch.hstack(zs)
+ if getattr(self.wrapped, 'return_pooled', False):
+ return torch.hstack(zs), zs[0].pooled
+ else:
+ return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
"""
@@ -265,9 +273,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z = z * (original_mean / new_mean)
+ z *= (original_mean / new_mean)
return z
@@ -324,3 +332,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+ if self.wrapped.layer == "last":
+ z = outputs.last_hidden_state
+ else:
+ z = outputs.hidden_states[self.wrapped.layer_idx]
+
+ return z
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index f733e852..bb0b96c7 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -32,6 +32,40 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
- embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+
+ return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+ self.id_start = tokenizer.encoder["<start_of_text>"]
+ self.id_end = tokenizer.encoder["<end_of_text>"]
+ self.id_pad = 0
+
+ def tokenize(self, texts):
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+ tokenized = [tokenizer.encode(text) for text in texts]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ d = self.wrapped.encode_with_transformer(tokens)
+ z = d[self.wrapped.layer]
+
+ pooled = d.get("pooled")
+ if pooled is not None:
+ z.pooled = pooled
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ ids = tokenizer.encode(init_text)
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 53e27ade..b5f85ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
class SdOptimization:
@@ -39,6 +43,9 @@ class SdOptimization: ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
+
class SdOptimizationXformers(SdOptimization):
name = "xformers"
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization): def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization): def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization): def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization): cmd_opt = "opt_split_attention_v1"
priority = 10
-
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization): def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def list_optimizers(res):
@@ -155,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): # taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None): # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
+
def einsum_op_slice_0(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size): r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
+
def einsum_op_slice_1(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size): r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
return r
+
def einsum_op_mps_v1(q, k, v):
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v): slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
+
def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
+
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
@@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div)
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
def einsum_op_cuda(q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
@@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v): # Divide factor of safety as there's copying and fragmentation
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
def einsum_op(q, k, v):
if q.device.type == 'cuda':
return einsum_op_cuda(q, k, v)
@@ -328,7 +354,8 @@ def einsum_op(q, k, v): # Tested on i7 with 8MB L3 cache.
return einsum_op_tensor_mem(q, k, v, 32)
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
h = self.heads
q = self.to_q(x)
@@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
+def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x
+
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None
-def xformers_attention_forward(self, x, context=None, mask=None):
+def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
+
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states)
return hidden_states
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x): return h3
+
def xformers_attnblock_forward(self, x):
try:
h_ = x
@@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x): except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
def sdp_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x): out = self.proj_out(out)
return x + out
+
def sdp_no_mem_attnblock_forward(self, x):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return sdp_attnblock_forward(self, x)
+
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007..729f03d7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -289,6 +289,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+ model.is_sdxl = hasattr(model, 'conditioner')
+ if model.is_sdxl:
+ sd_models_xl.extend_sdxl(model)
+
model.load_state_dict(state_dict, strict=False)
del state_dict
timer.record("apply weights to model")
@@ -334,7 +338,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
- model.logvar = model.logvar.to(devices.device) # fix for training
+ if hasattr(model, 'logvar'):
+ model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
@@ -391,10 +396,11 @@ def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
+ if hasattr(sd_config.model.params, 'unet_config'):
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -407,6 +413,8 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
@@ -441,6 +449,15 @@ class SdModelData: model_data = SdModelData()
+def get_empty_cond(sd_model):
+ if hasattr(sd_model, 'conditioner'):
+ d = sd_model.get_learned_conditioning([""])
+ return d['crossattn']
+ else:
+ return sd_model.cond_stage_model([""])
+
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -461,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
timer.record("find config")
@@ -513,7 +530,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
- sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9bfe1237..8266fa39 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
- if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+ if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+ return config_sdxl
+ if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+ return config_sdxl_refiner
+ elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 00000000..01320c7a --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,99 @@ +from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices, shared, prompt_parser
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
+ for embedder in self.conditioner.embedders:
+ embedder.ucg_rate = 0.0
+
+ width = getattr(self, 'target_width', 1024)
+ height = getattr(self, 'target_height', 1024)
+ is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+ aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+ devices_args = dict(device=devices.device, dtype=devices.dtype)
+
+ sdxl_conds = {
+ "txt": batch,
+ "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+ "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+ "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+ "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
+ }
+
+ force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
+ c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
+
+ return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+ return self.model(x, t, cond)
+
+
+def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
+ return x
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+ res = []
+
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+ encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+ res.append(encoded)
+
+ return torch.cat(res, dim=1)
+
+
+def process_texts(self, texts):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+ return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+ return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
+def extend_sdxl(model):
+ """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
+ dtype = next(model.model.diffusion_model.parameters()).dtype
+ model.model.diffusion_model.dtype = dtype
+ model.model.conditioning_key = 'crossattn'
+ model.cond_stage_key = 'txt'
+ # model.cond_stage_model will be set in sd_hijack
+
+ model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+ discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+ model.conditioner.wrapped = torch.nn.Module()
+
+
+sgm.modules.attention.print = lambda *args: None
+sgm.modules.diffusionmodules.model.print = lambda *args: None
+sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
+sgm.modules.encoders.modules.print = lambda *args: None
+
+# this gets the code to load the vanilla attention that we override
+sgm.modules.attention.SDP_IS_AVAILABLE = True
+sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f22aad8f..bea2684c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -28,6 +28,9 @@ def create_sampler(name, model): assert config is not None, f'bad sampler name: {name}'
+ if model.is_sdxl and config.options.get("no_sdxl", False):
+ raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
sampler = config.constructor(model)
sampler.config = config
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index bdae8b40..4a8396f9 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
- sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
]
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b76..5552a8dc 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { }
+def catenate_conds(conds):
+ if not isinstance(conds[0], dict):
+ return torch.cat(conds)
+
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+ if not isinstance(cond, dict):
+ return cond[a:b]
+
+ return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+ if not isinstance(tensor, dict):
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+ return tensor
+
+
class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
- tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
- uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+ uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
- cond_in = torch.cat([tensor, uncond, uncond])
+ cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond:
cond_in = tensor
else:
- cond_in = torch.cat([tensor, uncond])
+ cond_in = catenate_conds([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0])
if not is_edit_model:
- c_crossattn = [tensor[a:b]]
+ c_crossattn = subscript_cond(tensor, a, b)
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond:
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f00468..86bd658a 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch
from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared
-sd_vae_approx_model = None
+sd_vae_approx_models = {}
class VAEApprox(nn.Module):
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module): return x
+def download_model(model_path, model_url):
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading VAEApprox model to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
def model():
- global sd_vae_approx_model
+ model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+ loaded_model = sd_vae_approx_models.get(model_name)
- if sd_vae_approx_model is None:
- model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
- sd_vae_approx_model = VAEApprox()
+ if loaded_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
if not os.path.exists(model_path):
- model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
- sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
- sd_vae_approx_model.eval()
- sd_vae_approx_model.to(devices.device, devices.dtype)
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+ download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+ loaded_model = VAEApprox()
+ loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_approx_models[model_name] = loaded_model
- return sd_vae_approx_model
+ return loaded_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- coefs = torch.tensor([
- [0.298, 0.207, 0.208],
- [0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473],
- ]).to(sample.device)
+ if shared.sd_model.is_sdxl:
+ coeffs = [
+ [ 0.3448, 0.4168, 0.4395],
+ [-0.1953, -0.0290, 0.0250],
+ [ 0.1074, 0.0886, -0.0163],
+ [-0.3730, -0.2499, -0.2088],
+ ]
+ else:
+ coeffs = [
+ [ 0.298, 0.207, 0.208],
+ [ 0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]
+
+ coefs = torch.tensor(coeffs).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e8..5bf7c76e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder diff --git a/modules/shared.py b/modules/shared.py index f6604ef9..6162938a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -429,9 +429,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+ "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+ "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+ "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
+
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|