| author | AUTOMATIC <16777216c@gmail.com> | 2022-12-31 15:06:35 +0000 |
|---|---|---|
| committer | AUTOMATIC <16777216c@gmail.com> | 2022-12-31 15:06:35 +0000 |
| commit | f34c7341720fb2059992926c9f9ae6ff25f7385b (patch) | |
| tree | be719a629f8754c206d891b1850f0b5eaf186e2e /modules/sd_hijack_xlmr.py | |
| parent | 3f401cdb644066fd43abf6642d2e53be53c73668 (diff) | |
alt-diffusion integration
Diffstat (limited to 'modules/sd_hijack_xlmr.py')
| -rw-r--r-- | modules/sd_hijack_xlmr.py | 34 |

1 file changed, 34 insertions, 0 deletions
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.id_start = wrapped.config.bos_token_id
+        self.id_end = wrapped.config.eos_token_id
+        self.id_pad = wrapped.config.pad_token_id
+
+        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion's vocab doesn't use </w> suffixes, so the comma token is plain ','
+
+    def encode_with_transformers(self, tokens):
+        # there is no CLIP Skip here: every hidden layer has a size of 1024, and a trained projection on
+        # the last layer maps those 1024 features to the 768 the unet expects, so you can't choose which
+        # transformer layer to work with - you have to use the last one
+
+        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+        z = features['projection_state']
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.roberta.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+        return embedded
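
For context, here is a minimal, self-contained sketch (not part of the commit) of the data flow inside `encode_with_transformers`: padding positions are masked out and the conditioning is read from the `projection_state` entry of the model's output. The pad id, the token values and the `fake_wrapped_model` stand-in are placeholders; only the mask construction and the dictionary lookup mirror the code above.

```python
import torch

id_pad = 1  # placeholder pad token id; the real value comes from wrapped.config.pad_token_id


def fake_wrapped_model(input_ids, attention_mask):
    # Stand-in for the wrapped XLM-R conditioner: returns a dict whose 'projection_state'
    # entry is the conditioning tensor; zeros here, with the 768 width the diff's comment
    # mentions and a per-token shape chosen purely for illustration.
    batch, seq_len = input_ids.shape
    return {'projection_state': torch.zeros(batch, seq_len, 768)}


tokens = torch.tensor([[0, 23, 44, 2, id_pad, id_pad, id_pad]])  # one padded sequence
attention_mask = (tokens != id_pad).to(device=tokens.device, dtype=torch.int64)
print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0, 0]])

features = fake_wrapped_model(input_ids=tokens, attention_mask=attention_mask)
z = features['projection_state']
print(z.shape)  # torch.Size([1, 7, 768])
```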
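And a similarly rough, self-contained sketch of the tensor flow in `encode_embedding_init_text`: tokenizer ids go through the underlying embedding table, and the batch dimension is squeezed away. The `nn.Embedding` stand-in, its sizes and the ids below are invented for illustration; only the embed-then-`squeeze(0)` step mirrors the method above.

```python
import torch
import torch.nn as nn

# Stand-in for embedding_layer.token_embedding.wrapped, presumably the roberta embedding
# table that the hijack wraps. Vocabulary size, width and ids are made-up placeholders.
vocab_size, embed_dim = 1000, 1024
token_embedding = nn.Embedding(vocab_size, embed_dim)

# Stand-in for wrapped.tokenizer(init_text, ...)["input_ids"]: shape (1, num_tokens)
ids = torch.tensor([[5, 42, 7]])

embedded = token_embedding(ids).squeeze(0)  # shape (num_tokens, embed_dim)
print(embedded.shape)  # torch.Size([3, 1024])
```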