From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 11 Jul 2023 21:16:43 +0300
Subject: getting SD2.1 to run on SDXL repo

---
 modules/prompt_parser.py | 64 +++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 12 deletions(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0..d7f9e9a9 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps):
         cond_schedule = []
         for i, (end_at_step, _) in enumerate(prompt_schedule):
-            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+            if isinstance(conds, dict):
+                cond = {k: v[i] for k, v in conds.items()}
+            else:
+                cond = conds[i]
+
+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
 
         cache[prompt] = cond_schedule
         res.append(cond_schedule)
 
@@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
 
 
+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
-        res[i] = cond_schedule[target_index].cond
+
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond
 
     return res
 
 
+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
 
@@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
 
         conds_list.append(conds_for_batch)
 
-    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked
 
 
 re_attention = re.compile(r"""
--
cgit v1.2.3
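Editor's note, not part of the patch: a minimal sketch of how the new DictWithShape wrapper is expected to behave. The class body is copied verbatim from the diff above; the tensor shapes and the "crossattn"/"vector" key layout are illustrative assumptions about SDXL-style conds, not values taken from the repo.

import torch

class DictWithShape(dict):
    def __init__(self, x, shape):
        super().__init__()
        self.update(x)

    @property
    def shape(self):
        # callers that only read .shape see the crossattn tensor's shape;
        # the shape argument passed to __init__ is accepted but unused
        return self["crossattn"].shape

# illustrative shapes (assumed): batch of 2, 77 tokens, 2048-dim crossattn
cond = DictWithShape({"crossattn": torch.randn(2, 77, 2048),
                      "vector": torch.randn(2, 1280)}, None)
print(cond.shape)            # torch.Size([2, 77, 2048])
print(cond["vector"].shape)  # torch.Size([2, 1280])

This lets dict-valued conds pass through existing code paths that only inspect .shape, while new code can still index the individual tensors by key.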
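Likewise, a quick sketch of the padding behaviour in the extracted stack_conds helper (function body copied from the diff; the 77/154-token and 1024-dim shapes are just an example):

import torch

def stack_conds(tensors):
    # pad shorter conds by repeating their last token vector,
    # so torch.stack sees uniform shapes
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

    return torch.stack(tensors)

a = torch.randn(77, 1024)   # shorter prompt, one chunk
b = torch.randn(154, 1024)  # longer prompt, two chunks
print(stack_conds([a, b]).shape)  # torch.Size([2, 154, 1024]); 'a' padded with its last vector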