aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
authorDepFA <35278260+dfaker@users.noreply.github.com>2022-10-09 23:38:54 +0000
committerGitHub <noreply@github.com>2022-10-09 23:38:54 +0000
commit4117afff11c7b0a2162c73ea02be8cfa30d02640 (patch)
treeaf26f1b0c9eac8c024d2e51ec8fb5ca4a4d45ed3 /modules/sd_hijack.py
parente2c2925eb4d634b186de2c76798162ec56e2f869 (diff)
parent45fbd1c5fec887988ab555aac75a999d4f3aff40 (diff)
downloadstable-diffusion-webui-gfx803-4117afff11c7b0a2162c73ea02be8cfa30d02640.tar.gz
stable-diffusion-webui-gfx803-4117afff11c7b0a2162c73ea02be8cfa30d02640.tar.bz2
stable-diffusion-webui-gfx803-4117afff11c7b0a2162c73ea02be8cfa30d02640.zip
Merge branch 'master' into embed-embeddings-in-images
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f12a9696..437acce4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -282,14 +282,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
- tmp = -opts.CLIP_ignore_last_layers
- if (opts.CLIP_ignore_last_layers == 0):
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
- z = outputs.last_hidden_state
- else:
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
- z = outputs.hidden_states[tmp]
+ outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]