diff options
author | Greendayle <Greendayle> | 2022-10-08 14:27:48 +0000 |
---|---|---|
committer | Greendayle <Greendayle> | 2022-10-08 14:27:48 +0000 |
commit | 2e8ba0fa478eb076760dc0fdfc526f6f5f1f98c5 (patch) | |
tree | 18bd69791e0eff3affd59876f8b39e9150aa8e2b /modules/prompt_parser.py | |
parent | 5f12e7efd92ad802742f96788b4be3249ad02829 (diff) | |
parent | 4f33289d0fc5aa3a197f4a4c926d03d44f0d597e (diff) | |
download | stable-diffusion-webui-gfx803-2e8ba0fa478eb076760dc0fdfc526f6f5f1f98c5.tar.gz stable-diffusion-webui-gfx803-2e8ba0fa478eb076760dc0fdfc526f6f5f1f98c5.tar.bz2 stable-diffusion-webui-gfx803-2e8ba0fa478eb076760dc0fdfc526f6f5f1f98c5.zip |
fix conflicts
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r-- | modules/prompt_parser.py | 9 |
1 file changed, 9 insertions, 0 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|