diff options
author | Fampai <unknown> | 2022-10-08 20:32:05 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2022-10-09 19:31:23 +0000 |
commit | e59c66c0088422b27f64b401ef42c242f836725a (patch) | |
tree | 4de6bb6c9ca31b007fb55526d218bddf40561287 /modules/sd_hijack.py | |
parent | 6c383d2e82045fc4475d665f83bdeeac8fd844d9 (diff) | |
download | stable-diffusion-webui-gfx803-e59c66c0088422b27f64b401ef42c242f836725a.tar.gz stable-diffusion-webui-gfx803-e59c66c0088422b27f64b401ef42c242f836725a.tar.bz2 stable-diffusion-webui-gfx803-e59c66c0088422b27f64b401ef42c242f836725a.zip |
Optimized code for Ignoring last CLIP layers
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 12 |
1 files changed, 4 insertions, 8 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f12a9696..4a2d2153 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
- tmp = -opts.CLIP_ignore_last_layers
- if (opts.CLIP_ignore_last_layers == 0):
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
- z = outputs.last_hidden_state
- else:
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
- z = outputs.hidden_states[tmp]
- z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ tmp = -opts.CLIP_stop_at_last_layers
+ outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
+ z = outputs.hidden_states[tmp]
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
|