aboutsummaryrefslogtreecommitdiffstats
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py64
1 files changed, 52 insertions, 12 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0..d7f9e9a9 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps):
cond_schedule = []
for i, (end_at_step, _) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ if isinstance(conds, dict):
+ cond = {k: v[i] for k, v in conds.items()}
+ else:
+ cond = conds[i]
+
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+class DictWithShape(dict):
+ def __init__(self, x, shape):
+ super().__init__()
+ self.update(x)
+
+ @property
+ def shape(self):
+ return self["crossattn"].shape
+
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ is_dict = isinstance(param, dict)
+
+ if is_dict:
+ dict_cond = param
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+ else:
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
for i, cond_schedule in enumerate(c):
target_index = 0
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
- res[i] = cond_schedule[target_index].cond
+
+ if is_dict:
+ for k, param in cond_schedule[target_index].cond.items():
+ res[k][i] = param
+ else:
+ res[i] = cond_schedule[target_index].cond
return res
+def stack_conds(tensors):
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return torch.stack(tensors)
+
+
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
@@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ if isinstance(tensors[0], dict):
+ keys = list(tensors[0].keys())
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+ else:
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+ return conds_list, stacked
re_attention = re.compile(r"""