author | zhaohu xing <920232796@qq.com> | 2022-11-29 02:28:41 +0000
committer | zhaohu xing <920232796@qq.com> | 2022-11-29 02:28:41 +0000
commit | 75c4511e6b81ae8fb0dbd932043e8eb35cd09f72
tree | 6f4662507be1d532a4e992f54f82d905fc450f3a /modules/sd_hijack.py
parent | 828438b4a190759807f9054932cae3a8b880ddf1
add AltDiffusion to webui
Signed-off-by: zhaohu xing <920232796@qq.com>
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 23
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index eaedac13..26280fe4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -70,14 +70,19 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
     def hijack(self, m):
-        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+
+        if shared.text_model_name == "XLMR-Large":
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+        else:
+            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

         self.clip = m.cond_stage_model

-        apply_optimizations()
+        # apply_optimizations()

         def flatten(el):
             flattened = [flatten(children) for children in el.children()]
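The branch is needed because the two text encoders expose their embedding tables under different attribute names: CLIP's text model keeps its table at `text_model.embeddings.token_embedding`, while AltDiffusion's XLM-R encoder keeps it at `roberta.embeddings.word_embeddings`. Both get wrapped in `EmbeddingsWithFixes` so textual-inversion vectors can be patched into the lookup. A minimal sketch of that wrapper, simplified from the class defined elsewhere in `sd_hijack.py` (the actual substitution logic is elided):

```python
import torch.nn as nn

class EmbeddingsWithFixes(nn.Module):
    """Simplified sketch: stands in for the encoder's original embedding
    table and gives the hijack a hook to substitute textual-inversion
    vectors into the looked-up embeddings."""

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped        # the encoder's original nn.Embedding
        self.embeddings = embeddings  # the hijack object holding pending fixes

    def forward(self, input_ids):
        inputs_embeds = self.wrapped(input_ids)
        # the real class splices recorded embedding tensors in here
        return inputs_embeds
```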
@@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.tokenizer = wrapped.tokenizer
         self.token_mults = {}
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
-
+        try:
+            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+        except:
+            self.comma_token = None
+
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
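The new try/except guards the comma-token lookup: `',</w>'` is a CLIP BPE end-of-word marker, and XLM-R's SentencePiece vocabulary has no such key, so the comprehension comes back empty and the `[0]` index raises. The same guard can be written without the bare except, as a sketch (`find_comma_token` is a hypothetical helper name; `get_vocab()` is the standard Hugging Face tokenizer method):

```python
def find_comma_token(tokenizer):
    # Returns the id of CLIP's ',</w>' comma token, or None for tokenizers
    # (such as XLM-R's) whose vocabularies lack CLIP-style BPE markers.
    return tokenizer.get_vocab().get(',</w>')
```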
@@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
     def forward(self, text):
+        if shared.text_model_name == "XLMR-Large":
+            return self.wrapped.encode(text)
+
         use_old = opts.use_old_emphasis_implementation
         if use_old:
             batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
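The early return routes XLM-R prompts straight to the wrapped model's own `encode()`, bypassing everything below it: emphasis parsing, 75-token chunking, and multiplier application all assume CLIP's tokenizer. A hypothetical sketch of the resulting dispatch (names other than `"XLMR-Large"` are illustrative):

```python
from typing import Callable, List

def make_forward(text_model_name: str,
                 xlmr_encode: Callable[[List[str]], object],
                 clip_weighted_encode: Callable[[List[str]], object]):
    """Builds a forward() mirroring the dispatch above: XLM-R prompts
    skip the CLIP-specific emphasis pipeline entirely."""
    def forward(text):
        if text_model_name == "XLMR-Large":
            return xlmr_encode(text)       # tokenization handled internally
        return clip_weighted_encode(text)  # CLIP path with prompt weighting
    return forward
```

One practical consequence of this design: prompt-weighting syntax like `(word:1.2)` has no effect on the XLM-R path.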
@@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             z = self.wrapped.transformer.text_model.final_layer_norm(z)
         else:
             z = outputs.last_hidden_state
-
+
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
         batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
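This last hunk is whitespace-only, but the comment it touches is worth unpacking: after scaling each token's embedding by its prompt weight, the surrounding code rescales the whole tensor so its mean matches the unweighted output. A sketch of that renormalization, under the assumption that the surrounding code follows the comment (variable names match the hunk above):

```python
import torch

def apply_multipliers(z: torch.Tensor, batch_multipliers: torch.Tensor) -> torch.Tensor:
    """Scale token embeddings by their prompt weights, then restore the
    tensor's original mean (the heuristic the comment above describes)."""
    original_mean = z.mean()
    z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
    new_mean = z.mean()
    return z * (original_mean / new_mean)
```

With all multipliers at 1.0 the function is an identity, which is why unweighted prompts are unaffected by this step.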