author    AUTOMATIC1111 <16777216c@gmail.com>    2023-01-04 14:40:19 +0000
committer GitHub <noreply@github.com>            2023-01-04 14:40:19 +0000
commit    da5c1e8a732c173ed8ccda9fa32f9a194ff91ab6 (patch)
tree      a2eec9c47e820e7ab351337f73c99d874b4b904f /modules/sd_hijack_xlmr.py
parent    cffc240a7327ae60671ff533469fc4ed4bf605de (diff)
parent    47df0849019abac6722c49512f4dd2285bff5b7d (diff)
Merge branch 'master' into inpaint_textual_inversion
Diffstat (limited to 'modules/sd_hijack_xlmr.py')
-rw-r--r--  modules/sd_hijack_xlmr.py  34
1 file changed, 34 insertions(+), 0 deletions(-)
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.id_start = wrapped.config.bos_token_id
+        self.id_end = wrapped.config.eos_token_id
+        self.id_pad = wrapped.config.pad_token_id
+
+        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # Alt Diffusion's tokenizer doesn't use the </w> suffix on the comma token
+
+    def encode_with_transformers(self, tokens):
+        # there is no CLIP Skip here: every hidden layer is 1024-wide, and only the last one is followed
+        # by a trained projection that maps those 1024 features down to the 768 the unet expects, so
+        # there is no choosing which transformer layer to take output from - it has to be the last
+
+        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)  # mask out padding tokens
+        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+        z = features['projection_state']  # final hidden state projected from 1024 down to 768 dims
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.roberta.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)  # .wrapped is the original word embedding module underneath the hijack
+
+        return embedded
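
For reference, here is a minimal standalone sketch of the padding-mask step from encode_with_transformers above. It is an illustration, not part of the commit: the token ids are made up, and id_pad stands in for wrapped.config.pad_token_id (1 for XLM-RoBERTa):

import torch

id_pad = 1  # stand-in for wrapped.config.pad_token_id; XLM-RoBERTa uses 1
tokens = torch.tensor([[0, 9064, 14, 2, 1, 1]])  # one prompt, right-padded with id_pad

# same expression as in encode_with_transformers: 1 over real tokens, 0 over padding
attention_mask = (tokens != id_pad).to(device=tokens.device, dtype=torch.int64)
print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])

Building the mask explicitly keeps the XLM-RoBERTa encoder from attending to padding positions; the parent CLIP implementation, as of this commit, passes no attention mask at all.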