From 52cc83d36b7663a77b79fd2258d2ca871af73e55 Mon Sep 17 00:00:00 2001
From: zhaohu xing <920232796@qq.com>
Date: Wed, 30 Nov 2022 14:56:12 +0800
Subject: fix bugs
Signed-off-by: zhaohu xing <920232796@qq.com>
---
modules/sd_hijack_clip.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index b451d1cf..9ea6e1ce 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -4,7 +4,7 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts
-
+import modules.shared as shared
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
@@ -177,6 +177,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
+ if shared.text_model_name == "XLMR-Large":
+ return self.wrapped.encode(text)
+
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -254,7 +257,10 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
+ if shared.text_model_name == "XLMR-Large":
+ self.comma_token = None
+ else :
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
--
cgit v1.2.3
From f34c7341720fb2059992926c9f9ae6ff25f7385b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 31 Dec 2022 18:06:35 +0300
Subject: alt-diffusion integration
---
configs/alt-diffusion-inference.yaml | 72 ++++++++++++++++++++++++++++++++++
configs/altdiffusion/ad-inference.yaml | 72 ----------------------------------
configs/v1-inference.yaml | 70 +++++++++++++++++++++++++++++++++
modules/sd_hijack.py | 18 +++++----
modules/sd_hijack_clip.py | 14 +++----
modules/sd_hijack_xlmr.py | 34 ++++++++++++++++
modules/shared.py | 10 +----
v1-inference.yaml | 70 ---------------------------------
8 files changed, 192 insertions(+), 168 deletions(-)
create mode 100644 configs/alt-diffusion-inference.yaml
delete mode 100644 configs/altdiffusion/ad-inference.yaml
create mode 100644 configs/v1-inference.yaml
create mode 100644 modules/sd_hijack_xlmr.py
delete mode 100644 v1-inference.yaml
(limited to 'modules/sd_hijack_clip.py')
diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml
new file mode 100644
index 00000000..cfbee72d
--- /dev/null
+++ b/configs/alt-diffusion-inference.yaml
@@ -0,0 +1,72 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: modules.xlmr.BertSeriesModelWithTransformation
+ params:
+ name: "XLMR-Large"
\ No newline at end of file
diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml
deleted file mode 100644
index cfbee72d..00000000
--- a/configs/altdiffusion/ad-inference.yaml
+++ /dev/null
@@ -1,72 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: "jpg"
- cond_stage_key: "txt"
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 10000 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1. ]
- f_min: [ 1. ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: True
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: modules.xlmr.BertSeriesModelWithTransformation
- params:
- name: "XLMR-Large"
\ No newline at end of file
diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml
new file mode 100644
index 00000000..d4effe56
--- /dev/null
+++ b/configs/v1-inference.yaml
@@ -0,0 +1,70 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bce23b03..edcbaf52 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -68,6 +68,7 @@ def fix_checkpoint():
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -79,21 +80,22 @@ class StableDiffusionModelHijack:
def hijack(self, m):
- if shared.text_model_name == "XLMR-Large":
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
- m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
- apply_optimizations()
+
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
- apply_optimizations()
-
+
+ apply_optimizations()
+
self.clip = m.cond_stage_model
fix_checkpoint()
@@ -109,7 +111,7 @@ class StableDiffusionModelHijack:
def undo_hijack(self, m):
- if shared.text_model_name == "XLMR-Large":
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9ea6e1ce..6ec50cca 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -4,7 +4,6 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts
-import modules.shared as shared
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
@@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
- if shared.text_model_name == "XLMR-Large":
- return self.wrapped.encode(text)
-
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
- if shared.text_model_name == "XLMR-Large":
- self.comma_token = None
- else :
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
+
+ vocab = self.tokenizer.get_vocab()
+
+ self.comma_token = vocab.get(',', None)
self.token_mults = {}
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.id_start = wrapped.config.bos_token_id
+ self.id_end = wrapped.config.eos_token_id
+ self.id_pad = wrapped.config.pad_token_id
+
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma
+
+ def encode_with_transformers(self, tokens):
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+ # layer to work with - you have to use the last
+
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+ z = features['projection_state']
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.roberta.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+ return embedded
diff --git a/modules/shared.py b/modules/shared.py
index 2b31e717..715b9169 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -23,7 +23,7 @@ demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -108,14 +108,6 @@ restricted_opts = {
"outdir_txt2img_grids",
"outdir_save",
}
-from omegaconf import OmegaConf
-config = OmegaConf.load(f"{cmd_opts.config}")
-# XLMR-Large
-try:
- text_model_name = config.model.params.cond_stage_config.params.name
-
-except :
- text_model_name = "stable_diffusion"
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
diff --git a/v1-inference.yaml b/v1-inference.yaml
deleted file mode 100644
index d4effe56..00000000
--- a/v1-inference.yaml
+++ /dev/null
@@ -1,70 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: "jpg"
- cond_stage_key: "txt"
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 10000 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1. ]
- f_min: [ 1. ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: True
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
--
cgit v1.2.3
From 210449b374d522c94a67fe54289a9eb515933a9f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 02:41:15 +0300
Subject: fix 'RuntimeError: Expected all tensors to be on the same device'
error preventing models from loading on lowvram/medvram.
---
modules/sd_hijack_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 6ec50cca..ca92b142 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -298,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
--
cgit v1.2.3
From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 7 Jan 2023 01:45:28 +0300
Subject: CLIP hijack rework
---
modules/sd_hijack.py | 6 +-
modules/sd_hijack_clip.py | 348 ++++++++++++-------------
modules/sd_hijack_clip_old.py | 81 ++++++
modules/textual_inversion/textual_inversion.py | 1 -
modules/ui.py | 2 +-
5 files changed, 256 insertions(+), 182 deletions(-)
create mode 100644 modules/sd_hijack_clip_old.py
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa2cd4bb..71cc145a 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
- def tokenize(self, text):
- _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+ def get_prompt_lengths(self, text):
+ _, token_count = self.clip.process_texts([text])
- return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index ca92b142..ac3020d7 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,12 +1,28 @@
import math
+from collections import namedtuple
import torch
from modules import prompt_parser, devices
from modules.shared import opts
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+ """
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+ so just 75 tokens from prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""This is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt chunk"""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
@@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
raise NotImplementedError
def encode_with_transformers(self, tokens):
+ """
+ converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
+ All python lists with tokens are assumed to have same length, usually 77.
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+ model - can be 768 and 1024
+ """
+
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+ transformers. nvpt is used as a maximum length in tokens. If text produces fewer tokens than nvpt, only that many are returned."""
+
raise NotImplementedError
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ def tokenize_line(self, line):
+ """
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
- fixes = []
- remade_tokens = []
- multipliers = []
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
last_comma = -1
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
+ last_comma = len(chunk.tokens)
+
+ # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [self.id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [self.id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+
token_count = 0
cache = {}
- batch_multipliers = []
+ batch_chunks = []
for line in texts:
if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
+ chunks = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
- cache[line] = (remade_tokens, fixes, multipliers)
+ cache[line] = chunks
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
+ batch_chunks.append(chunks)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+ return batch_chunks, token_count
- def process_text_old(self, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
+ def forward(self, texts):
+ """
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ An example shape returned by this function can be: (2, 77, 768).
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.id_end] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
+ batch_chunks, token_count = self.process_texts(texts)
- return z
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+ sends one single prompt chunk to be encoded by transformers neural network.
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+ corresponds to one token.
+ """
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
new file mode 100644
index 00000000..6d9fbbe6
--- /dev/null
+++ b/modules/sd_hijack_clip_old.py
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f9f5e8cd..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -79,7 +79,6 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
diff --git a/modules/ui.py b/modules/ui.py
index b79d24ee..5d2f5bad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -368,7 +368,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"{token_count}/{max_length}"
--
cgit v1.2.3
From 08066676a47b560235d4c085dd3cfcb470b80997 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 7 Jan 2023 07:22:07 +0300
Subject: make it not break on empty inputs; thank you tarded, we are
---
modules/sd_hijack_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index ac3020d7..16aef76a 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -147,7 +147,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens
- if len(chunk.tokens) > 0:
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
next_chunk()
return chunks, token_count
--
cgit v1.2.3
From 1740c33547b62f692834c95914a2b295d51684c7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 7 Jan 2023 07:48:44 +0300
Subject: more comments
---
modules/sd_hijack_clip.py | 21 ++++++++++++++++-----
1 file changed, 16 insertions(+), 5 deletions(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 16aef76a..5520c9b2 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -3,7 +3,7 @@ from collections import namedtuple
import torch
-from modules import prompt_parser, devices
+from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
@@ -22,14 +22,24 @@ class PromptChunk:
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
-"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt
+chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+ have unlimited prompt length and assign weights to tokens in prompt.
+ """
+
def __init__(self, wrapped, hijack):
super().__init__()
+
self.wrapped = wrapped
- self.hijack = hijack
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+ depending on model."""
+
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
def empty_chunk(self):
@@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
All python lists with tokens are assumed to have same length, usually 77.
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
- model - can be 768 and 1024
+ model - can be 768 and 1024.
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
"""
raise NotImplementedError
@@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
last_comma = len(chunk.tokens)
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
- # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next.
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
break_location = last_comma + 1
--
cgit v1.2.3
From df3b31eb559ab9fabf7e513bdeddd5282c16f124 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sat, 7 Jan 2023 07:04:59 -0500
Subject: In-place operations can break gradient calculation
---
modules/sd_hijack_clip.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5520c9b2..852afc66 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= original_mean / new_mean
+ z = z * (original_mean / new_mean)
return z
--
cgit v1.2.3
From 8e2aeee4a127b295bfc880800e4a312e0f049b85 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 22:29:53 +0300
Subject: add BREAK keyword to end current text chunk and start the next
---
modules/prompt_parser.py | 7 ++++++-
modules/sd_hijack_clip.py | 17 +++++++++++++----
2 files changed, 19 insertions(+), 5 deletions(-)
(limited to 'modules/sd_hijack_clip.py')
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 870218db..69665372 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -274,6 +274,7 @@ re_attention = re.compile(r"""
:
""", re.X)
+re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
"""
@@ -339,7 +340,11 @@ def parse_prompt_attention(text):
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
- res.append([text, 1.0])
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 852afc66..9fa5c5c5 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
token_count = 0
last_comma = -1
- def next_chunk():
- """puts current chunk into the list of results and produces the next one - empty"""
+ def next_chunk(is_last=False):
+ """puts current chunk into the list of results and produces the next one - empty;
+ if is_last is true, the end-of-text tokens padded onto the end of the chunk won't add to token_count"""
nonlocal token_count
nonlocal last_comma
nonlocal chunk
- token_count += len(chunk.tokens)
+ if is_last:
+ token_count += len(chunk.tokens)
+ else:
+ token_count += self.chunk_length
+
to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed):
+ if text == 'BREAK' and weight == -1:
+ next_chunk()
+ continue
+
position = 0
while position < len(tokens):
token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk()
+ next_chunk(is_last=True)
return chunks, token_count
--
cgit v1.2.3