path: root/modules/sd_hijack_clip.py
author     AUTOMATIC <16777216c@gmail.com>  2022-12-31 15:06:35 +0000
committer  AUTOMATIC <16777216c@gmail.com>  2022-12-31 15:06:35 +0000
commit     f34c7341720fb2059992926c9f9ae6ff25f7385b (patch)
tree       be719a629f8754c206d891b1850f0b5eaf186e2e /modules/sd_hijack_clip.py
parent     3f401cdb644066fd43abf6642d2e53be53c73668 (diff)
alt-diffusion integration
Diffstat (limited to 'modules/sd_hijack_clip.py')
-rw-r--r--  modules/sd_hijack_clip.py  14
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9ea6e1ce..6ec50cca 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -4,7 +4,6 @@ import torch
 from modules import prompt_parser, devices
 from modules.shared import opts
 
-import modules.shared as shared
 
 def get_target_prompt_token_count(token_count):
     return math.ceil(max(token_count, 1) / 75) * 75
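For context, the get_target_prompt_token_count helper kept by this hunk rounds a prompt's token count up to the next whole multiple of 75, the number of usable tokens per CLIP chunk (77 minus the start and end tokens). A minimal illustration of its behaviour:

import math

def get_target_prompt_token_count(token_count):
    # round up to a whole number of 75-token chunks, with a floor of one chunk
    return math.ceil(max(token_count, 1) / 75) * 75

assert get_target_prompt_token_count(0) == 75     # an empty prompt still pads to one chunk
assert get_target_prompt_token_count(75) == 75    # exactly one chunk
assert get_target_prompt_token_count(76) == 150   # one extra token spills into a second chunk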
@@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def forward(self, text):
-        if shared.text_model_name == "XLMR-Large":
-            return self.wrapped.encode(text)
-
         use_old = opts.use_old_emphasis_implementation
         if use_old:
             batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
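The hunk above removes the XLMR-Large short-circuit from forward(), so every model now goes through the shared emphasis pipeline; together with the dropped "import modules.shared as shared", this file no longer needs to know which text model is loaded. Per the commit title, the model-specific behaviour presumably lives in alt-diffusion's own wrapper instead. A hypothetical sketch of that pattern (the class name is an assumption, not taken from this diff; the method bodies mirror the two branches deleted in this commit):

class FrozenXLMREmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.comma_token = None  # XLMR's SentencePiece vocab has no ',</w>' entry

    def forward(self, text):
        # delegate straight to the wrapped encoder, as the removed branch did
        return self.wrapped.encode(text)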
@@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
         self.tokenizer = wrapped.tokenizer
-        if shared.text_model_name == "XLMR-Large":
-            self.comma_token = None
-        else :
-            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
+        vocab = self.tokenizer.get_vocab()
+
+        self.comma_token = vocab.get(',</w>', None)
 
         self.token_mults = {}
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
             for c in text:
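The comma_token rewrite in this last hunk is what makes the deleted branch unnecessary: dict.get() returns None when the key is missing, which is exactly what the old if/else produced for a tokenizer without CLIP's end-of-word comma. A minimal illustration (the checkpoint name is an assumption, not part of the diff):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
vocab = tokenizer.get_vocab()

# CLIP's BPE vocab contains ',</w>', so this returns its integer token id;
# an XLMR-style SentencePiece vocab has no such key, so the same lookup
# returns None, matching the deleted "self.comma_token = None" branch.
comma_token = vocab.get(',</w>', None)

As a side benefit, .get() is a single dictionary lookup, whereas the old list comprehension scanned the entire vocabulary to find one key.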