diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-08 12:43:25 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-08 12:43:25 +0000 |
commit | 7001bffe0247804793dfabb69ac96d832572ccd0 (patch) | |
tree | 74db8920aff51b8dbb01cb801267c84bca2af162 /modules/prompt_parser.py | |
parent | 77f4237d1c3af1756e7dab2699e3dcebad5619d6 (diff) | |
download | stable-diffusion-webui-gfx803-7001bffe0247804793dfabb69ac96d832572ccd0.tar.gz stable-diffusion-webui-gfx803-7001bffe0247804793dfabb69ac96d832572ccd0.tar.bz2 stable-diffusion-webui-gfx803-7001bffe0247804793dfabb69ac96d832572ccd0.zip |
fix AND broken for long prompts
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r-- | modules/prompt_parser.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|