diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-15 10:10:16 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-15 10:10:16 +0000 |
commit | f2693bec08d2c2e513cb35fa24402396505a01a9 (patch) | |
tree | e17b80c308467f9c1433bfc632ec7ecd42dae372 /modules | |
parent | b28cf84c3632df4a6d4c110f7c25d68445b64427 (diff) | |
download | stable-diffusion-webui-gfx803-f2693bec08d2c2e513cb35fa24402396505a01a9.tar.gz stable-diffusion-webui-gfx803-f2693bec08d2c2e513cb35fa24402396505a01a9.tar.bz2 stable-diffusion-webui-gfx803-f2693bec08d2c2e513cb35fa24402396505a01a9.zip |
prompt editing
Diffstat (limited to 'modules')
-rw-r--r-- | modules/processing.py | 8 | ||||
-rw-r--r-- | modules/prompt_parser.py | 128 | ||||
-rw-r--r-- | modules/sd_samplers.py | 44 |
3 files changed, 161 insertions, 19 deletions
diff --git a/modules/processing.py b/modules/processing.py index 93138e7c..9b53d210 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure
import modules.sd_hijack
-from modules import devices
+from modules import devices, prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -247,8 +247,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- c = p.sd_model.get_learned_conditioning(prompts)
+ #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ #c = p.sd_model.get_learned_conditioning(prompts)
+ uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_learned_conditioning(prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py new file mode 100644 index 00000000..e918fabf --- /dev/null +++ b/modules/prompt_parser.py @@ -0,0 +1,128 @@ +import re
+from collections import namedtuple
+import torch
+
+import modules.shared as shared
+
+re_prompt = re.compile(r'''
+(.*?)
+\[
+ ([^]:]+):
+ (?:([^]:]*):)?
+ ([0-9]*\.?[0-9]+)
+]
+|
+(.+)
+''', re.X)
+
+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
+# will be represented with prompt_schedule like this (assuming steps=100):
+# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
+# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
+# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
+# [75, 'fantasy landscape with a lake and an oak in background masterful']
+# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+
+
+def get_learned_conditioning_prompt_schedules(prompts, steps):
+ res = []
+ cache = {}
+
+ for prompt in prompts:
+ prompt_schedule: list[list[str | int]] = [[steps, ""]]
+
+ cached = cache.get(prompt, None)
+ if cached is not None:
+ res.append(cached)
+
+ for m in re_prompt.finditer(prompt):
+ plaintext = m.group(1) if m.group(5) is None else m.group(5)
+ concept_from = m.group(2)
+ concept_to = m.group(3)
+ if concept_to is None:
+ concept_to = concept_from
+ concept_from = ""
+ swap_position = float(m.group(4)) if m.group(4) is not None else None
+
+ if swap_position is not None:
+ if swap_position < 1:
+ swap_position = swap_position * steps
+ swap_position = int(min(swap_position, steps))
+
+ swap_index = None
+ found_exact_index = False
+ for i in range(len(prompt_schedule)):
+ end_step = prompt_schedule[i][0]
+ prompt_schedule[i][1] += plaintext
+
+ if swap_position is not None and swap_index is None:
+ if swap_position == end_step:
+ swap_index = i
+ found_exact_index = True
+
+ if swap_position < end_step:
+ swap_index = i
+
+ if swap_index is not None:
+ if not found_exact_index:
+ prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
+
+ for i in range(len(prompt_schedule)):
+ end_step = prompt_schedule[i][0]
+ must_replace = swap_position < end_step
+
+ prompt_schedule[i][1] += concept_to if must_replace else concept_from
+
+ res.append(prompt_schedule)
+ cache[prompt] = prompt_schedule
+ #for t in prompt_schedule:
+ # print(t)
+
+ return res
+
+
+ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
+ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
+
+
+def get_learned_conditioning(prompts, steps):
+
+ res = []
+
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+ cache = {}
+
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
+
+ cached = cache.get(prompt, None)
+ if cached is not None:
+ res.append(cached)
+
+ texts = [x[1] for x in prompt_schedule]
+ conds = shared.sd_model.get_learned_conditioning(texts)
+
+ cond_schedule = []
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+
+ cache[prompt] = cond_schedule
+ res.append(cond_schedule)
+
+ return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+
+
+def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
+ res = torch.zeros(c.shape)
+ for i, cond_schedule in enumerate(c.schedules):
+ target_index = 0
+ for curret_index, (end_at, cond) in enumerate(cond_schedule):
+ if current_step <= end_at:
+ target_index = curret_index
+ break
+ res[i] = cond_schedule[target_index].cond
+
+ return res.to(shared.device)
+
+
+
+#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7ef507f1..c042c5c3 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,6 +7,7 @@ from PIL import Image import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -53,20 +54,6 @@ def store_latent(decoded): shared.state.current_image = sample_to_image(decoded)
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
- if sampler_wrapper.mask is not None:
- img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
- x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
- res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
- if sampler_wrapper.mask is not None:
- store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
- else:
- store_latent(res[1])
-
- return res
-
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler: self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
+
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ else:
+ store_latent(res[1])
+
+ self.step += 1
+ return res
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler: x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
- self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler: def sample(self, p, x, conditioning, unconditional_conditioning):
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module): self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module): if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.step += 1
+
return denoised
|