diff options
author | brkirch <brkirch@users.noreply.github.com> | 2023-01-04 05:40:16 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-06 05:14:20 +0000 |
commit | f6ab5a39d762a7791573d1c52ae5a3024b10e8ed (patch) | |
tree | c3958d77a6dae42457b571dbe0f1efec7ce45dd2 /modules/sd_hijack_xlmr.py | |
parent | d782a95967c9eea753df3333cd1954b6ec73eba0 (diff) | |
parent | 3e22e294135ed0327ce9d9738655ff03c53df3c0 (diff) | |
download | stable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.tar.gz stable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.tar.bz2 stable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.zip |
Merge branch 'AUTOMATIC1111:master' into sub-quad_attn_opt
Diffstat (limited to 'modules/sd_hijack_xlmr.py')
-rw-r--r-- | modules/sd_hijack_xlmr.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py new file mode 100644 index 00000000..4ac51c38 --- /dev/null +++ b/modules/sd_hijack_xlmr.py @@ -0,0 +1,34 @@ +import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.id_start = wrapped.config.bos_token_id
+ self.id_end = wrapped.config.eos_token_id
+ self.id_pad = wrapped.config.pad_token_id
+
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
+
+ def encode_with_transformers(self, tokens):
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+ # layer to work with - you have to use the last
+
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+ z = features['projection_state']
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.roberta.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+ return embedded
|