aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
authorJairo Correa <jn.j41r0@gmail.com>2022-10-04 22:53:52 +0000
committerJairo Correa <jn.j41r0@gmail.com>2022-10-04 22:53:52 +0000
commit1f50971fb8c83c255c2819dd0b3f29a46b74f7d9 (patch)
treefd57f40a1ffa2b28105ec0bb3f7f3ab4a742681a /modules
parentad0cc85d1f0bd52877963f296eb1257a0c2b012b (diff)
parentef40e4cd4d383a3405e03f1da3f5b5a1820a8f53 (diff)
downloadstable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.tar.gz
stable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.tar.bz2
stable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.zip
Merge branch 'master' into fix-vram
Diffstat (limited to 'modules')
-rw-r--r--modules/bsrgan_model.py6
-rw-r--r--modules/codeformer_model.py16
-rw-r--r--modules/devices.py15
-rw-r--r--modules/esrgan_model.py9
-rw-r--r--modules/gfpgan_model.py19
-rw-r--r--modules/images.py41
-rw-r--r--modules/img2img.py10
-rw-r--r--modules/processing.py53
-rw-r--r--modules/prompt_parser.py212
-rw-r--r--modules/scunet_model.py8
-rw-r--r--modules/sd_samplers.py2
-rw-r--r--modules/shared.py17
-rw-r--r--modules/textual_inversion/dataset.py7
-rw-r--r--modules/textual_inversion/preprocess.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py2
-rw-r--r--modules/txt2img.py3
-rw-r--r--modules/ui.py25
17 files changed, 275 insertions, 174 deletions
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
index e62c6657..3bd80791 100644
--- a/modules/bsrgan_model.py
+++ b/modules/bsrgan_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
@@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_bsrgan)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8769e1db..e6d9fa4f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net
self.face_helper = face_helper
- self.net.to(devices.device_codeformer)
return net, face_helper
+ def send_model_to(self, device):
+ self.net.to(device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1]
@@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None:
return np_image
+ self.send_model_to(devices.device_codeformer)
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -97,7 +103,7 @@ def setup_model(dirname):
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
- devices.torch_gc()
+ torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -113,10 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ self.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- self.net = None
- self.face_helper = None
- devices.torch_gc()
+ self.send_model_to(devices.cpu)
return restored_img
diff --git a/modules/devices.py b/modules/devices.py
index ebf40082..6db4e57c 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,9 +1,11 @@
+import contextlib
+
import torch
import gc
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors
+# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
@@ -33,8 +35,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = get_optimal_device()
-device_codeformer = cpu if has_mps else device
+device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
def randn(seed, shape):
@@ -58,3 +59,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
+
+def autocast():
+ from modules import shared
+
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 4aed9283..d17e730f 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -6,8 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-from modules import shared, modelloader, images
-from modules.devices import has_mps
+from modules import shared, modelloader, images, devices
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
@@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
@@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index a5fd9632..a9452dce 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(shared.device)
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
if gfpgan_constructor is None:
@@ -37,25 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
- model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
return model
+def send_model_to(model, device):
+ model.gfpgan.to(device)
+ model.face_helper.face_det.to(device)
+ model.face_helper.face_parse.to(device)
+
+
def gfpgan_fix_faces(np_image):
- global loaded_gfpgan_model
model = gfpgann()
if model is None:
return np_image
+
+ send_model_to(model, devices.device_gfpgan)
+
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
+ model.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- del model
- loaded_gfpgan_model = None
- devices.torch_gc()
+ send_model_to(model, devices.cpu)
return np_image
diff --git a/modules/images.py b/modules/images.py
index d7563244..c2fadab9 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -287,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None:
x = x.replace("[seed]", str(seed))
+ if p is not None:
+ x = x.replace("[steps]", str(p.steps))
+ x = x.replace("[cfg]", str(p.cfg_scale))
+ x = x.replace("[width]", str(p.width))
+ x = x.replace("[height]", str(p.height))
+
+ #currently disabled if using the save button, will work otherwise
+ # if enabled it will cause a bug because styles is not included in the save_files data dictionary
+ if hasattr(p, "styles"):
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
+
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+
+ x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
+ x = x.replace("[date]", datetime.date.today().isoformat())
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
+ x = x.replace("[job_timestamp]", shared.state.job_timestamp)
+
+ # Apply [prompt] at last. Because it may contain any replacement word.^M
if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
@@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
@@ -306,24 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
- if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
-
- #currently disabled if using the save button, will work otherwise
- # if enabled it will cause a bug because styles is not included in the save_files data dictionary
- if hasattr(p, "styles"):
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
-
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
-
- x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
- x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
- x = x.replace("[job_timestamp]", shared.state.job_timestamp)
-
if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
@@ -379,7 +380,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
diff --git a/modules/img2img.py b/modules/img2img.py
index f4455c90..da212d72 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+ save_normally = output_dir == ''
+
p.do_not_save_grid = True
- p.do_not_save_samples = True
+ p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
@@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
- processed_image.save(os.path.join(output_dir, filename))
+ if not save_normally:
+ processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@@ -126,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout:
print(generation_info_js)
+ if opts.do_not_show_images:
+ processed.images = []
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
diff --git a/modules/processing.py b/modules/processing.py
index 2f5a2967..e7f9c85e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,3 @@
-import contextlib
import json
import math
import os
@@ -85,7 +84,7 @@ class StableDiffusionProcessing:
self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise
-
+
if not seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
@@ -249,9 +248,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def get_fixed_seed(seed):
+ if seed is None or seed == '' or seed == -1:
+ return int(random.randrange(4294967294))
+
+ return seed
+
+
def fix_seed(p):
- p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
- p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
+ p.seed = get_fixed_seed(p.seed)
+ p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
@@ -290,10 +296,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
-
+
devices.torch_gc()
- fix_seed(p)
+ seed = get_fixed_seed(p.seed)
+ subseed = get_fixed_seed(p.subseed)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)
@@ -312,15 +319,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
all_prompts = p.batch_size * p.n_iter * [p.prompt]
- if type(p.seed) == list:
- all_seeds = p.seed
+ if type(seed) == list:
+ all_seeds = seed
else:
- all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
- if type(p.subseed) == list:
- all_subseeds = p.subseed
+ if type(subseed) == list:
+ all_subseeds = subseed
else:
- all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]
+ all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
@@ -330,10 +337,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
- precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
- ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
- with torch.no_grad(), precision_scope("cuda"), ema_scope():
- p.init(all_prompts, all_seeds, all_subseeds)
+
+ with torch.no_grad():
+ with devices.autocast():
+ p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -352,8 +359,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
- uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
- c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+ with devices.autocast():
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -362,13 +370,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ with devices.autocast():
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
if state.interrupted:
# if we are interruped, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
+ samples_ddim = samples_ddim.to(devices.dtype)
+
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -394,6 +406,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
devices.torch_gc()
@@ -530,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
-
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
return samples
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index e811eb9e..a3b12421 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,19 +1,7 @@
import re
from collections import namedtuple
-import torch
-import modules.shared as shared
-
-re_prompt = re.compile(r'''
-(.*?)
-\[
- ([^]:]+):
- (?:([^]:]*):)?
- ([0-9]*\.?[0-9]+)
-]
-|
-(.+)
-''', re.X)
+import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
@@ -23,71 +11,96 @@ re_prompt = re.compile(r'''
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+schedule_parser = lark.Lark(r"""
+!start: (prompt | /[][():]/+)*
+prompt: (emphasized | scheduled | plain | WHITESPACE)*
+!emphasized: "(" prompt ")"
+ | "(" prompt ":" prompt ")"
+ | "[" prompt "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+WHITESPACE: /\s+/
+plain: /([^\\\[\]():]|\\.)+/
+%import common.SIGNED_NUMBER -> NUMBER
+""")
def get_learned_conditioning_prompt_schedules(prompts, steps):
- res = []
- cache = {}
-
- for prompt in prompts:
- prompt_schedule: list[list[str | int]] = [[steps, ""]]
-
- cached = cache.get(prompt, None)
- if cached is not None:
- res.append(cached)
- continue
-
- for m in re_prompt.finditer(prompt):
- plaintext = m.group(1) if m.group(5) is None else m.group(5)
- concept_from = m.group(2)
- concept_to = m.group(3)
- if concept_to is None:
- concept_to = concept_from
- concept_from = ""
- swap_position = float(m.group(4)) if m.group(4) is not None else None
-
- if swap_position is not None:
- if swap_position < 1:
- swap_position = swap_position * steps
- swap_position = int(min(swap_position, steps))
-
- swap_index = None
- found_exact_index = False
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- prompt_schedule[i][1] += plaintext
-
- if swap_position is not None and swap_index is None:
- if swap_position == end_step:
- swap_index = i
- found_exact_index = True
-
- if swap_position < end_step:
- swap_index = i
-
- if swap_index is not None:
- if not found_exact_index:
- prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
-
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- must_replace = swap_position < end_step
-
- prompt_schedule[i][1] += concept_to if must_replace else concept_from
-
- res.append(prompt_schedule)
- cache[prompt] = prompt_schedule
- #for t in prompt_schedule:
- # print(t)
+ """
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
+ >>> g("test")
+ [[10, 'test']]
+ >>> g("a [b:3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [b: 3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [[[b]]:2]")
+ [[2, 'a '], [10, 'a [[b]]']]
+ >>> g("[(a:2):3]")
+ [[3, ''], [10, '(a:2)']]
+ >>> g("a [b : c : 1] d")
+ [[1, 'a b d'], [10, 'a c d']]
+ >>> g("a[b:[c:d:2]:1]e")
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
+ >>> g("a [unbalanced")
+ [[10, 'a [unbalanced']]
+ >>> g("a [b:.5] c")
+ [[5, 'a c'], [10, 'a b c']]
+ >>> g("a [{b|d{:.5] c") # not handling this right now
+ [[5, 'a c'], [10, 'a {b|d{ c']]
+ >>> g("((a][:b:c [d:3]")
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ """
- return res
+ def collect_steps(steps, tree):
+ l = [steps]
+ class CollectSteps(lark.Visitor):
+ def scheduled(self, tree):
+ tree.children[-1] = float(tree.children[-1])
+ if tree.children[-1] < 1:
+ tree.children[-1] *= steps
+ tree.children[-1] = min(steps, int(tree.children[-1]))
+ l.append(tree.children[-1])
+ CollectSteps().visit(tree)
+ return sorted(set(l))
+
+ def at_step(step, tree):
+ class AtStep(lark.Transformer):
+ def scheduled(self, args):
+ before, after, _, when = args
+ yield before or () if step <= when else after
+ def start(self, args):
+ def flatten(x):
+ if type(x) == str:
+ yield x
+ else:
+ for gen in x:
+ yield from flatten(gen)
+ return ''.join(flatten(args))
+ def plain(self, args):
+ yield args[0].value
+ def __default__(self, data, children, meta):
+ for child in children:
+ yield from child
+ return AtStep().transform(tree)
+
+ def get_schedule(prompt):
+ try:
+ tree = schedule_parser.parse(prompt)
+ except lark.exceptions.LarkError as e:
+ if 0:
+ import traceback
+ traceback.print_exc()
+ return [[steps, prompt]]
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
+ return [promptdict[prompt] for prompt in prompts]
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
-def get_learned_conditioning(prompts, steps):
-
+def get_learned_conditioning(model, prompts, steps):
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -101,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
continue
texts = [x[1] for x in prompt_schedule]
- conds = shared.sd_model.get_learned_conditioning(texts)
+ conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
@@ -114,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
- res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
+ param = c.schedules[0][0].cond
+ res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c.schedules):
target_index = 0
- for curret_index, (end_at, cond) in enumerate(cond_schedule):
+ for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at:
- target_index = curret_index
+ target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -157,23 +171,26 @@ def parse_prompt_attention(text):
\\ - literal character '\'
anything else - just text
- Example:
-
- 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
-
- produces:
-
- [
- ['a ', 1.0],
- ['house', 1.5730000000000004],
- [' ', 1.1],
- ['on', 1.0],
- [' a ', 1.1],
- ['hill', 0.55],
- [', sun, ', 1.1],
- ['sky', 1.4641000000000006],
- ['.', 1.1]
- ]
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
"""
res = []
@@ -215,4 +232,19 @@ def parse_prompt_attention(text):
if len(res) == 0:
res = [["", 1.0]]
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
return res
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
+else:
+ import torch # doctest faster
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 7987ac14..fb64b740 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
@@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None:
return img
- device = shared.device
+ device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
@@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
- device = shared.device
+ device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 9316875a..dbf570d2 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -127,7 +127,7 @@ class VanillaStableDiffusionSampler:
return res
def initialize(self, p):
- self.eta = p.eta or opts.eta_ddim
+ self.eta = p.eta if p.eta is not None else opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
diff --git a/modules/shared.py b/modules/shared.py
index 7246eadc..e52c9b1d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,7 +12,7 @@ import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
-from modules.devices import get_optimal_device
+import modules.devices as devices
from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -54,6 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -63,7 +65,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
cmd_opts = parser.parse_args()
-device = get_optimal_device()
+
+devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+
+device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
@@ -183,7 +189,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
- "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -195,7 +201,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
- "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
@@ -204,7 +210,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
@@ -224,6 +230,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"return_grid": OptionInfo(True, "Show grid in results for web"),
+ "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index e8394ff6..7c44ea5b 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -9,6 +9,9 @@ from torchvision import transforms
import random
import tqdm
from modules import devices
+import re
+
+re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
@@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
- filename_tokens = [token for token in filename_tokens if token.isalpha()]
+ filename_tokens = os.path.splitext(filename)[0]
+ filename_tokens = re_tag.findall(filename_tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 209e928f..f545a993 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
else:
- caption = ""
+ caption = filename
+ caption = os.path.splitext(caption)[0]
+ caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8686f534..cd9f3498 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d4406c3c..e985242b 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -48,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.samples_log_stdout:
print(generation_info_js)
+ if opts.do_not_show_images:
+ processed.images = []
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
diff --git a/modules/ui.py b/modules/ui.py
index d9d02ece..de6342a4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -69,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
-folder_symbol = '\uD83D\uDCC2'
+folder_symbol = '\U0001f4c2' # 📂
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -196,6 +196,11 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"{elapsed_s:.2f}s"
+ if (elapsed_m > 0):
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -210,7 +215,7 @@ def wrap_gradio_call(func, extra_outputs=None):
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.interrupted = False
shared.state.job_count = 0
@@ -386,14 +391,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component]
)
+
def update_token_counter(text, steps):
- prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+ try:
+ prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step,prompt_text in flat_prompts]
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
+
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
@@ -636,7 +649,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil")
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
@@ -658,7 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>")
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)