diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-12 20:52:43 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-12 20:52:43 +0000 |
commit | da464a3fb39ecc6ea7b22fe87271194480d8501c (patch) | |
tree | fd67d92762d0490d9d4784aaae3f2a3c2f31c6ca /modules/prompt_parser.py | |
parent | af081211ee93622473ee575de30fed2fd8263c09 (diff) | |
download | stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.tar.gz stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.tar.bz2 stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.zip |
SDXL support
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r-- | modules/prompt_parser.py | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index d7f9e9a9..33810669 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+ Can also specify width and height of created image - SDXL needs it.
+ """
+ def __init__(self, prompts, width=None, height=None):
+ super().__init__()
+ self.extend(prompts)
+ self.width = width or getattr(prompts, 'width', None)
+ self.height = height or getattr(prompts, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|