From c938679de7b87b4f14894d9f57fe0f40dd6e3c06 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Wed, 28 Sep 2022 22:14:13 -0300 Subject: Fix memory leak and reduce memory usage --- modules/processing.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 4ecdfcd2..de5cda79 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, lowvram from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): + for n in range(p.n_iter): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if state.interrupted: break @@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + del samples_ddim + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + + devices.torch_gc() + if opts.filter_nsfw: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): + for i, x_sample in enumerate(x_samples_ddim): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: + if p.restore_faces: + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - devices.torch_gc() - x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - state.nextjob() + del x_samples_ddim + devices.torch_gc() + + state.nextjob() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 @@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask + del x + devices.torch_gc() + return samples -- cgit v1.2.3 From 9de1e56e2dbb405213da9c221e0329d27f411691 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 30 Sep 2022 01:44:38 +0100 Subject: add sampler_noise_scheduler_override property --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..1da753a2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -79,7 +79,7 @@ class StableDiffusionProcessing: self.paste_to = None self.color_corrections = None self.denoising_strength: float = 0 - + self.sampler_noise_scheduler_override = None self.ddim_discretize = opts.ddim_discretize self.s_churn = opts.s_churn self.s_tmin = opts.s_tmin @@ -130,7 +130,7 @@ class Processed: self.s_tmin = p.s_tmin self.s_tmax = p.s_tmax self.s_noise = p.s_noise - + self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) -- cgit v1.2.3 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- .gitignore | 1 + javascript/progressbar.js | 1 + javascript/textualInversion.js | 8 + modules/devices.py | 3 +- modules/processing.py | 13 +- modules/sd_hijack.py | 324 ++++------------------ modules/sd_hijack_optimizations.py | 164 +++++++++++ modules/sd_models.py | 4 +- modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 76 +++++ modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++ modules/textual_inversion/ui.py | 32 +++ modules/ui.py | 139 ++++++++-- style.css | 10 +- textual_inversion_templates/style.txt | 19 ++ textual_inversion_templates/style_filewords.txt | 19 ++ textual_inversion_templates/subject.txt | 27 ++ textual_inversion_templates/subject_filewords.txt | 27 ++ webui.py | 15 +- 19 files changed, 828 insertions(+), 315 deletions(-) create mode 100644 javascript/textualInversion.js create mode 100644 modules/sd_hijack_optimizations.py create mode 100644 modules/textual_inversion/dataset.py create mode 100644 modules/textual_inversion/textual_inversion.py create mode 100644 modules/textual_inversion/ui.py create mode 100644 textual_inversion_templates/style.txt create mode 100644 textual_inversion_templates/style_filewords.txt create mode 100644 textual_inversion_templates/subject.txt create mode 100644 textual_inversion_templates/subject_filewords.txt (limited to 'modules/processing.py') diff --git a/.gitignore b/.gitignore index 3532dab3..7afc9395 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ __pycache__ /.idea notification.mp3 /SwinIR +/textual_inversion diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 21f25b38..1e297abb 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte onUiUpdate(function(){ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js new file mode 100644 index 00000000..8061be08 --- /dev/null +++ b/javascript/textualInversion.js @@ -0,0 +1,8 @@ + + +function start_training_textual_inversion(){ + requestProgress('ti') + gradioApp().querySelector('#ti_error').innerHTML='' + + return args_to_array(arguments) +} diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..ff82f2f6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_codeformer = cpu if has_mps else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..8223423a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") - self.styles: str = styles + self.styles: list = styles or [] self.seed: int = seed self.subseed: int = subseed self.subseed_strength: float = subseed_strength @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p) - os.makedirs(p.outpath_samples, exist_ok=True) - os.makedirs(p.outpath_grids, exist_ok=True) + if p.outpath_samples is not None: + os.makedirs(p.outpath_samples, exist_ok=True) + + if p.outpath_grids is not None: + os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) @@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir): - model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] output_images = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 00000000..9c079e57 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,164 @@ +import math +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) * self.scale + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + +def nonlinearity_hijack(x): + # swish + t = torch.sigmoid(x) + x *= t + del t + + return x + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 diff --git a/modules/sd_models.py b/modules/sd_models.py index 2539f14c..5b3dbdc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader +from modules import shared, modelloader, devices from modules.paths import models_path model_dir = "Stable-diffusion" @@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): if not shared.cmd_opts.no_half: model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ac968b2d..ac0bc480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + textinfo = None def interrupt(self): self.interrupted = True @@ -88,7 +89,7 @@ class State: self.current_image_sampling_step = 0 def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? state = State() diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 00000000..7e134a08 --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import tqdm + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + + self.placeholder_token = placeholder_token + + self.size = size + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + image = Image.open(path) + image = image.convert('RGB') + image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + + filename = os.path.basename(path) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) + torchdata = torch.moveaxis(torchdata, 2, 0) + + init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + + self.dataset.append((init_latent, filename_tokens)) + + self.length = len(self.dataset) * repeats + + self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.indexes = None + self.shuffle() + + def shuffle(self): + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i % len(self.dataset) == 0: + self.shuffle() + + index = self.indexes[i % len(self.indexes)] + x, filename_tokens = self.dataset[index] + + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + + return x, text diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 00000000..c0baaace --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,258 @@ +import os +import sys +import traceback + +import torch +import tqdm +import html +import datetime + +from modules import shared, devices, sd_hijack, processing +import modules.textual_inversion.dataset + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.cached_checksum = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + } + + torch.save(embedding_data, filename) + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + +class EmbeddingDatabase: + def __init__(self, embeddings_dir): + self.ids_lookup = {} + self.word_embeddings = {} + self.dir_mtime = None + self.embeddings_dir = embeddings_dir + + def register_embedding(self, embedding, model): + + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + self.ids_lookup[first_id].append((ids, embedding)) + + return embedding + + def load_textual_inversion_embeddings(self): + mt = os.path.getmtime(self.embeddings_dir) + if self.dir_mtime is not None and mt <= self.dir_mtime: + return + + self.dir_mtime = mt + self.ids_lookup.clear() + self.word_embeddings.clear() + + def process_file(path, filename): + name = os.path.splitext(filename)[0] + + data = torch.load(path, map_location="cpu") + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + self.register_embedding(embedding, shared.sd_model) + + for fn in os.listdir(self.embeddings_dir): + try: + fullfn = os.path.join(self.embeddings_dir, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading emedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding + + return None + + + +def create_embedding(name, num_vectors_per_token): + init_text = '*' + + cond_model = shared.sd_model.cond_stage_model + embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): + assert embedding_name, 'embedding not selected' + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + embedding.vec.requires_grad = True + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = embedding.step or 0 + if ititial_step > steps: + return embedding, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + embedding.step = i + ititial_step + + if embedding.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + embedding.save(last_saved_file) + + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + embedding.cached_checksum = None + embedding.save(filename) + + return embedding, filename + diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 00000000..ce3677a9 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,32 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion as ti +from modules import sd_hijack, shared + + +def create_embedding(name, nvpt): + filename = ti.create_embedding(name, nvpt) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def train_embedding(*args): + + try: + sd_hijack.undo_optimizations() + + embedding, filename = ti.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + sd_hijack.apply_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..57aef6ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,6 +21,7 @@ import gradio as gr import gradio.utils import gradio.routes +from modules import sd_hijack from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -142,8 +144,8 @@ def save_files(js_data, images, index): return '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func): - def f(*args, **kwargs): +def wrap_gradio_call(func, extra_outputs=None): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled if run_memmon: shared.mem_mon.monitor() @@ -159,7 +161,10 @@ def wrap_gradio_call(func): shared.state.job = "" shared.state.job_count = 0 - res = [None, '', f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t @@ -179,6 +184,7 @@ def wrap_gradio_call(func): res[-1] += f"

Time taken: {elapsed:.2f}s

{vram_html}
" shared.state.interrupted = False + shared.state.job_count = 0 return tuple(res) @@ -187,7 +193,7 @@ def wrap_gradio_call(func): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False) + return "", gr_show(False), gr_show(False), gr_show(False) progress = 0 @@ -219,13 +225,19 @@ def check_progress_call(id_part): else: preview_visibility = gr_show(True) - return f"

{progressbar}

", preview_visibility, image + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result def check_progress_call_initial(id_part): shared.state.job_count = -1 shared.state.current_latent = None shared.state.current_image = None + shared.state.textinfo = None return check_progress_call(id_part) @@ -399,13 +411,16 @@ def create_toprow(is_img2img): return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste -def setup_progressbar(progressbar, preview, id_part): +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress.click( fn=lambda: check_progress_call(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) @@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part): fn=lambda: check_progress_call_initial(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) -def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=txt2img, + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ txt2img_prompt, @@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) img2img_args = dict( - fn=img2img, + fn=wrap_gradio_gpu_call(modules.img2img.img2img), _js="submit_img2img", inputs=[ dummy_component, @@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( - fn=run_extras, + fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): pnginfo_send_to_img2img = gr.Button('Send to img2img') image.change( - fn=wrap_gradio_call(run_pnginfo), + fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) @@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - + with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") @@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Safe as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - + with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks() as textual_inversion_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + with gr.Group(): + gr.HTML(value="

Create a new embedding

") + + new_embedding_name = gr.Textbox(label="Name") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create", variant='primary') + + with gr.Group(): + gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_embedding = gr.Button(value="Train", variant='primary') + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + nvpt, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (textual_inversion_interface, "Textual inversion", "ti"), (settings_interface, "Settings", "settings"), ] @@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def modelmerger(*args): try: - results = run_modelmerger(*args) + results = modules.extras.run_modelmerger(*args) except Exception as e: print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() #To remove the potentially missing models from the list + modules.sd_models.list_models() # to remove the potentially missing models from the list return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return results diff --git a/style.css b/style.css index 79d6bb0d..39586bf1 100644 --- a/style.css +++ b/style.css @@ -157,7 +157,7 @@ button{ max-width: 10em; } -#txt2img_preview, #img2img_preview{ +#txt2img_preview, #img2img_preview, #ti_preview{ position: absolute; width: 320px; left: 0; @@ -172,18 +172,18 @@ button{ } @media screen and (min-width: 768px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: absolute; } } @media screen and (max-width: 767px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: relative; } } -#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{ +#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{ display: none; } @@ -247,7 +247,7 @@ input[type="range"]{ #txt2img_negative_prompt, #img2img_negative_prompt{ } -#txt2img_progressbar, #img2img_progressbar{ +#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ position: absolute; z-index: 1000; right: 0; diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt new file mode 100644 index 00000000..15af2d6b --- /dev/null +++ b/textual_inversion_templates/style.txt @@ -0,0 +1,19 @@ +a painting, art by [name] +a rendering, art by [name] +a cropped painting, art by [name] +the painting, art by [name] +a clean painting, art by [name] +a dirty painting, art by [name] +a dark painting, art by [name] +a picture, art by [name] +a cool painting, art by [name] +a close-up painting, art by [name] +a bright painting, art by [name] +a cropped painting, art by [name] +a good painting, art by [name] +a close-up painting, art by [name] +a rendition, art by [name] +a nice painting, art by [name] +a small painting, art by [name] +a weird painting, art by [name] +a large painting, art by [name] diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt new file mode 100644 index 00000000..b3a8159a --- /dev/null +++ b/textual_inversion_templates/style_filewords.txt @@ -0,0 +1,19 @@ +a painting of [filewords], art by [name] +a rendering of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +the painting of [filewords], art by [name] +a clean painting of [filewords], art by [name] +a dirty painting of [filewords], art by [name] +a dark painting of [filewords], art by [name] +a picture of [filewords], art by [name] +a cool painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a bright painting of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +a good painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a rendition of [filewords], art by [name] +a nice painting of [filewords], art by [name] +a small painting of [filewords], art by [name] +a weird painting of [filewords], art by [name] +a large painting of [filewords], art by [name] diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt new file mode 100644 index 00000000..79f36aa0 --- /dev/null +++ b/textual_inversion_templates/subject.txt @@ -0,0 +1,27 @@ +a photo of a [name] +a rendering of a [name] +a cropped photo of the [name] +the photo of a [name] +a photo of a clean [name] +a photo of a dirty [name] +a dark photo of the [name] +a photo of my [name] +a photo of the cool [name] +a close-up photo of a [name] +a bright photo of the [name] +a cropped photo of a [name] +a photo of the [name] +a good photo of the [name] +a photo of one [name] +a close-up photo of the [name] +a rendition of the [name] +a photo of the clean [name] +a rendition of a [name] +a photo of a nice [name] +a good photo of a [name] +a photo of the nice [name] +a photo of the small [name] +a photo of the weird [name] +a photo of the large [name] +a photo of a cool [name] +a photo of a small [name] diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt new file mode 100644 index 00000000..008652a6 --- /dev/null +++ b/textual_inversion_templates/subject_filewords.txt @@ -0,0 +1,27 @@ +a photo of a [name], [filewords] +a rendering of a [name], [filewords] +a cropped photo of the [name], [filewords] +the photo of a [name], [filewords] +a photo of a clean [name], [filewords] +a photo of a dirty [name], [filewords] +a dark photo of the [name], [filewords] +a photo of my [name], [filewords] +a photo of the cool [name], [filewords] +a close-up photo of a [name], [filewords] +a bright photo of the [name], [filewords] +a cropped photo of a [name], [filewords] +a photo of the [name], [filewords] +a good photo of the [name], [filewords] +a photo of one [name], [filewords] +a close-up photo of the [name], [filewords] +a rendition of the [name], [filewords] +a photo of the clean [name], [filewords] +a rendition of a [name], [filewords] +a photo of a nice [name], [filewords] +a good photo of a [name], [filewords] +a photo of the nice [name], [filewords] +a photo of the small [name], [filewords] +a photo of the weird [name], [filewords] +a photo of the large [name], [filewords] +a photo of a cool [name], [filewords] +a photo of a small [name], [filewords] diff --git a/webui.py b/webui.py index b8cccd54..19fdcdd4 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan -import modules.img2img import modules.ldsr_model as ldsr import modules.lowvram import modules.realesrgan_model as realesrgan @@ -21,7 +20,6 @@ import modules.sd_hijack import modules.sd_models import modules.shared as shared import modules.swinir_model as swinir -import modules.txt2img import modules.ui from modules import modelloader from modules.paths import script_path @@ -46,7 +44,7 @@ def wrap_queued_call(func): return f -def wrap_gradio_gpu_call(func): +def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): devices.torch_gc() @@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func): shared.state.current_image = None shared.state.current_image_sampling_step = 0 shared.state.interrupted = False + shared.state.textinfo = None with queue_lock: res = func(*args, **kwargs) @@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func): return res - return modules.ui.wrap_gradio_call(f) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) modules.scripts.load_scripts(os.path.join(script_path, "scripts")) @@ -86,13 +85,7 @@ def webui(): signal.signal(signal.SIGINT, sigint_handler) - demo = modules.ui.create_ui( - txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), - img2img=wrap_gradio_gpu_call(modules.img2img.img2img), - run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), - run_pnginfo=modules.extras.run_pnginfo, - run_modelmerger=modules.extras.run_modelmerger - ) + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo.launch( share=cmd_opts.share, -- cgit v1.2.3 From 6c6ae28bf5fd1e8bc3e8f64a3430b6f29f338f77 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 12:32:22 +0300 Subject: send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283 --- modules/codeformer_model.py | 12 ++++++++++-- modules/devices.py | 10 ++++++++++ modules/gfpgan_model.py | 14 ++++++++++++-- modules/processing.py | 16 +++++++++------- 4 files changed, 41 insertions(+), 11 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index a29f3855..e6d9fa4f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -69,10 +69,14 @@ def setup_model(dirname): self.net = net self.face_helper = face_helper - self.net.to(devices.device_codeformer) return net, face_helper + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + def restore(self, np_image, w=None): np_image = np_image[:, :, ::-1] @@ -82,6 +86,8 @@ def setup_model(dirname): if self.net is None or self.face_helper is None: return np_image + self.send_model_to(devices.device_codeformer) + self.face_helper.clean_all() self.face_helper.read_image(np_image) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) @@ -113,8 +119,10 @@ def setup_model(dirname): if original_resolution != restored_img.shape[0:2]: restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) + self.face_helper.clean_all() + if shared.opts.face_restoration_unload: - self.net.to(devices.cpu) + self.send_model_to(devices.cpu) return restored_img diff --git a/modules/devices.py b/modules/devices.py index ff82f2f6..12aab665 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,3 +1,5 @@ +import contextlib + import torch # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -57,3 +59,11 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) + +def autocast(): + from modules import shared + + if dtype == torch.float32 or shared.cmd_opts.precision == "full": + return contextlib.nullcontext() + + return torch.autocast("cuda") diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index dd3fbcab..5586b554 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,22 +37,32 @@ def gfpgann(): print("Unable to load gfpgan model!") return None model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) - model.gfpgan.to(shared.device) loaded_gfpgan_model = model return model +def send_model_to(model, device): + model.gfpgan.to(device) + model.face_helper.face_det.to(device) + model.face_helper.face_parse.to(device) + + def gfpgan_fix_faces(np_image): model = gfpgann() if model is None: return np_image + + send_model_to(model, devices.device) + np_image_bgr = np_image[:, :, ::-1] cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] + model.face_helper.clean_all() + if shared.opts.face_restoration_unload: - model.gfpgan.to(devices.cpu) + send_model_to(model, devices.cpu) return np_image diff --git a/modules/processing.py b/modules/processing.py index 0a4b6198..9cbecdd8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,3 @@ -import contextlib import json import math import os @@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext - ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + + with torch.no_grad(): p.init(all_prompts, all_seeds, all_subseeds) if state.job_count == -1: @@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + with devices.autocast(): + uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -361,7 +360,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + with devices.autocast(): + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype) + if state.interrupted: # if we are interruped, sample returns just noise @@ -386,6 +387,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() image = Image.fromarray(x_sample) -- cgit v1.2.3 From 61652461242951966e5b4cee83ce359cefa91c17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 14:23:22 +0300 Subject: support interrupting after the previous change --- modules/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 9cbecdd8..6f5599c7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -361,7 +361,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) if state.interrupted: @@ -369,6 +369,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent + samples_ddim = samples_ddim.to(devices.dtype) + x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) -- cgit v1.2.3 From 52cef36f6ba169a8e606ecdcaed73d47378f0e8e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 16:54:31 +0300 Subject: emergency fix for img2img --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 6f5599c7..e9c45394 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -331,7 +331,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: output_images = [] with torch.no_grad(): - p.init(all_prompts, all_seeds, all_subseeds) + with devices.autocast(): + p.init(all_prompts, all_seeds, all_subseeds) if state.job_count == -1: state.job_count = p.n_iter -- cgit v1.2.3 From e1b128d8e46bddb9c0b2fd3ee0eefd57e0527ee0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 17:36:39 +0300 Subject: do not touch p.seed/p.subseed during processing #1181 --- modules/processing.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index e9c45394..8180c63d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -248,9 +248,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def get_fixed_seed(seed): + if seed is None or seed == '' or seed == -1: + return int(random.randrange(4294967294)) + + return seed + + def fix_seed(p): - p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed - p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed + p.seed = get_fixed_seed(p.seed) + p.subseed = get_fixed_seed(p.subseed) def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): @@ -292,7 +299,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - fix_seed(p) + seed = get_fixed_seed(p.seed) + subseed = get_fixed_seed(p.subseed) if p.outpath_samples is not None: os.makedirs(p.outpath_samples, exist_ok=True) @@ -311,15 +319,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: all_prompts = p.batch_size * p.n_iter * [p.prompt] - if type(p.seed) == list: - all_seeds = p.seed + if type(seed) == list: + all_seeds = seed else: - all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] - if type(p.subseed) == list: - all_subseeds = p.subseed + if type(subseed) == list: + all_subseeds = subseed else: - all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))] + all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] def infotext(iteration=0, position_in_batch=0): return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) -- cgit v1.2.3 From 90e911fd546e76f879b38a764473569911a0f845 Mon Sep 17 00:00:00 2001 From: Rae Fu Date: Tue, 4 Oct 2022 09:49:51 -0600 Subject: prompt_parser: allow spaces in schedules, add test, log/ignore errors Only build the parser once (at import time) instead of for each step. doctest is run by simply executing modules/prompt_parser.py --- modules/processing.py | 10 ++-- modules/prompt_parser.py | 139 ++++++++++++++++++++++++++++++----------------- 2 files changed, 95 insertions(+), 54 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 8180c63d..bb94033b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -84,7 +84,7 @@ class StableDiffusionProcessing: self.s_tmin = opts.s_tmin self.s_tmax = float('inf') # not representable as a standard ui option self.s_noise = opts.s_noise - + if not seed_enable_extras: self.subseed = -1 self.subseed_strength = 0 @@ -296,7 +296,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: assert(len(p.prompt) > 0) else: assert p.prompt is not None - + devices.torch_gc() seed = get_fixed_seed(p.seed) @@ -359,8 +359,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -527,7 +527,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) return samples diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 5d58c4ed..a3b12421 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,10 +1,7 @@ import re from collections import namedtuple -import torch -from lark import Lark, Transformer, Visitor -import functools -import modules.shared as shared +import lark # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): @@ -14,25 +11,48 @@ import modules.shared as shared # [75, 'fantasy landscape with a lake and an oak in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] +schedule_parser = lark.Lark(r""" +!start: (prompt | /[][():]/+)* +prompt: (emphasized | scheduled | plain | WHITESPACE)* +!emphasized: "(" prompt ")" + | "(" prompt ":" prompt ")" + | "[" prompt "]" +scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +WHITESPACE: /\s+/ +plain: /([^\\\[\]():]|\\.)+/ +%import common.SIGNED_NUMBER -> NUMBER +""") def get_learned_conditioning_prompt_schedules(prompts, steps): - grammar = r""" - start: prompt - prompt: (emphasized | scheduled | weighted | plain)* - !emphasized: "(" prompt ")" - | "(" prompt ":" prompt ")" - | "[" prompt "]" - scheduled: "[" (prompt ":")? prompt ":" NUMBER "]" - !weighted: "{" weighted_item ("|" weighted_item)* "}" - !weighted_item: prompt (":" prompt)? - plain: /([^\\\[\](){}:|]|\\.)+/ - %import common.SIGNED_NUMBER -> NUMBER """ - parser = Lark(grammar, parser='lalr') + >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] + >>> g("test") + [[10, 'test']] + >>> g("a [b:3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [b: 3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [[[b]]:2]") + [[2, 'a '], [10, 'a [[b]]']] + >>> g("[(a:2):3]") + [[3, ''], [10, '(a:2)']] + >>> g("a [b : c : 1] d") + [[1, 'a b d'], [10, 'a c d']] + >>> g("a[b:[c:d:2]:1]e") + [[1, 'abe'], [2, 'ace'], [10, 'ade']] + >>> g("a [unbalanced") + [[10, 'a [unbalanced']] + >>> g("a [b:.5] c") + [[5, 'a c'], [10, 'a b c']] + >>> g("a [{b|d{:.5] c") # not handling this right now + [[5, 'a c'], [10, 'a {b|d{ c']] + >>> g("((a][:b:c [d:3]") + [[3, '((a][:b:c '], [10, '((a][:b:c d']] + """ def collect_steps(steps, tree): l = [steps] - class CollectSteps(Visitor): + class CollectSteps(lark.Visitor): def scheduled(self, tree): tree.children[-1] = float(tree.children[-1]) if tree.children[-1] < 1: @@ -43,13 +63,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): return sorted(set(l)) def at_step(step, tree): - class AtStep(Transformer): + class AtStep(lark.Transformer): def scheduled(self, args): - if len(args) == 2: - before, after, when = (), *args - else: - before, after, when = args - yield before if step <= when else after + before, after, _, when = args + yield before or () if step <= when else after def start(self, args): def flatten(x): if type(x) == str: @@ -57,16 +74,22 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): else: for gen in x: yield from flatten(gen) - return ''.join(flatten(args[0])) + return ''.join(flatten(args)) def plain(self, args): yield args[0].value def __default__(self, data, children, meta): for child in children: yield from child return AtStep().transform(tree) - + def get_schedule(prompt): - tree = parser.parse(prompt) + try: + tree = schedule_parser.parse(prompt) + except lark.exceptions.LarkError as e: + if 0: + import traceback + traceback.print_exc() + return [[steps, prompt]] return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)} @@ -77,8 +100,7 @@ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) -def get_learned_conditioning(prompts, steps): - +def get_learned_conditioning(model, prompts, steps): res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps): continue texts = [x[1] for x in prompt_schedule] - conds = shared.sd_model.get_learned_conditioning(texts) + conds = model.get_learned_conditioning(texts) cond_schedule = [] for i, (end_at_step, text) in enumerate(prompt_schedule): @@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps): def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype) + param = c.schedules[0][0].cond + res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c.schedules): target_index = 0 - for curret_index, (end_at, cond) in enumerate(cond_schedule): + for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: - target_index = curret_index + target_index = current break res[i] = cond_schedule[target_index].cond @@ -148,23 +171,26 @@ def parse_prompt_attention(text): \\ - literal character '\' anything else - just text - Example: - - 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' - - produces: - - [ - ['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1] - ] + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] """ res = [] @@ -206,4 +232,19 @@ def parse_prompt_attention(text): if len(res) == 0: res = [["", 1.0]] + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + return res + +if __name__ == "__main__": + import doctest + doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) +else: + import torch # doctest faster -- cgit v1.2.3 From 82380d9ac18614c87bebba1b4cfd4b147cc76a18 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 4 Oct 2022 22:28:50 -0300 Subject: Removing parts no longer needed to fix vram --- modules/devices.py | 3 +-- modules/processing.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 15 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/devices.py b/modules/devices.py index 6db4e57c..0158b11f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ import contextlib import torch -import gc from modules import errors @@ -20,8 +19,8 @@ def get_optimal_device(): return cpu + def torch_gc(): - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/processing.py b/modules/processing.py index e7f9c85e..f666ba81 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -345,8 +345,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for n in range(p.n_iter): if state.interrupted: break @@ -395,22 +394,19 @@ def process_images(p: StableDiffusionProcessing) -> Processed: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - x_sample = modules.face_restoration.restore_faces(x_sample) devices.torch_gc() - devices.torch_gc() + x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -438,13 +434,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - del x_samples_ddim + del x_samples_ddim - devices.torch_gc() + devices.torch_gc() - state.nextjob() + state.nextjob() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 -- cgit v1.2.3 From c26732fbee2a57e621ac22bf70decf7496daa4cd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 5 Oct 2022 23:16:27 +0300 Subject: added support for AND from https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ --- modules/processing.py | 2 +- modules/prompt_parser.py | 114 ++++++++++++++++++++++++++++++++++++++++++++--- modules/sd_samplers.py | 35 ++++++++++----- modules/ui.py | 6 ++- 4 files changed, 138 insertions(+), 19 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index bb94033b..d8c6b8d5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -360,7 +360,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) + c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a3b12421..f7420daf 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) def get_learned_conditioning(model, prompts, steps): + """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), + and the sampling step at which this condition is to be replaced by the next one. + + Input: + (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20) + + Output: + [ + [ + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) + ], + [ + ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')), + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')) + ] + ] + """ res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps): cache[prompt] = cond_schedule res.append(cond_schedule) - return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) + return res + + +re_AND = re.compile(r"\bAND\b") +re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$") + + +def get_multicond_prompt_list(prompts): + res_indexes = [] + + prompt_flat_list = [] + prompt_indexes = {} + + for prompt in prompts: + subprompts = re_AND.split(prompt) + + indexes = [] + for subprompt in subprompts: + text, weight = re_weight.search(subprompt).groups() + + weight = float(weight) if weight is not None else 1.0 + + index = prompt_indexes.get(text, None) + if index is None: + index = len(prompt_flat_list) + prompt_flat_list.append(text) + prompt_indexes[text] = index + + indexes.append((index, weight)) + + res_indexes.append(indexes) + + return res_indexes, prompt_flat_list, prompt_indexes + + +class ComposableScheduledPromptConditioning: + def __init__(self, schedules, weight=1.0): + self.schedules: list[ScheduledPromptConditioning] = schedules + self.weight: float = weight + + +class MulticondLearnedConditioning: + def __init__(self, shape, batch): + self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS + self.batch: list[list[ComposableScheduledPromptConditioning]] = batch -def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - param = c.schedules[0][0].cond - res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) - for i, cond_schedule in enumerate(c.schedules): +def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: + """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. + For each prompt, the list is obtained by splitting the prompt using the AND separator. + + https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ + """ + + res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) + + learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) + + res = [] + for indexes in res_indexes: + res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]) + + return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) + + +def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): + param = c[0][0].cond + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: @@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res +def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): + param = c.batch[0][0].schedules[0].cond + + tensors = [] + conds_list = [] + + for batch_no, composable_prompts in enumerate(c.batch): + conds_for_batch = [] + + for cond_index, composable_prompt in enumerate(composable_prompts): + target_index = 0 + for current, (end_at, cond) in enumerate(composable_prompt.schedules): + if current_step <= end_at: + target_index = current + break + + conds_for_batch.append((len(tensors), composable_prompt.weight)) + tensors.append(composable_prompt.schedules[target_index].cond) + + conds_list.append(conds_for_batch) + + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + + re_attention = re.compile(r""" \\\(| \\\)| diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index dbf570d2..d27c547b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -109,9 +109,12 @@ class VanillaStableDiffusionSampler: return 0 def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' + cond = tensor + if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec @@ -183,19 +186,31 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 def forward(self, x, sigma, uncond, cond, cond_scale): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + cond_in = torch.cat([tensor, uncond]) + if shared.batch_cond_uncond: - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale + x_out = self.inner_model(x_in, sigma_in, cond=cond_in) else: - uncond = self.inner_model(x, sigma, cond=uncond) - cond = self.inner_model(x, sigma, cond=cond) - denoised = uncond + (cond - uncond) * cond_scale + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + + denoised_uncond = x_out[-batch_size:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised diff --git a/modules/ui.py b/modules/ui.py index 523ab25b..9620350f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -34,7 +34,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste -from modules.prompt_parser import get_learned_conditioning_prompt_schedules +from modules import prompt_parser from modules.images import apply_filename_pattern, get_next_sequence_number import modules.textual_inversion.ui @@ -394,7 +394,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: - prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + except Exception: # a parsing error can happen here during typing, and we don't want to bother the user with # messages related to it in console -- cgit v1.2.3 From 5f24b7bcf4a074fbdec757617fcd1bc82e76551b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 12:08:48 +0300 Subject: option to let users select which samplers they want to hide --- modules/processing.py | 13 ++++++------- modules/sd_samplers.py | 19 +++++++++++++++++-- modules/shared.py | 15 +++++++++------ webui.py | 4 +++- 4 files changed, 35 insertions(+), 16 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index d8c6b8d5..e01c8b3f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,9 +11,8 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, sd_samplers from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.face_restoration @@ -110,7 +109,7 @@ class Processed: self.width = p.width self.height = p.height self.sampler_index = p.sampler_index - self.sampler = samplers[p.sampler_index].name + self.sampler = sd_samplers.samplers[p.sampler_index].name self.cfg_scale = p.cfg_scale self.steps = p.steps self.batch_size = p.batch_size @@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": samplers[p.sampler_index].name, + "Sampler": sd_samplers.samplers[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), @@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d27c547b..2e1f7715 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,12 +32,27 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -samplers = [ +all_samplers = [ *samplers_data_k_diffusion, SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), ] -samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']] + +samplers = [] +samplers_for_img2img = [] + + +def set_samplers(): + global samplers, samplers_for_img2img + + hidden = set(opts.hide_samplers) + hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive']) + + samplers = [x for x in all_samplers if x.name not in hidden] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + + +set_samplers() sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], diff --git a/modules/shared.py b/modules/shared.py index bab0fe6e..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,6 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices +from modules import sd_samplers from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -238,14 +239,16 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) + class Options: data = None data_labels = options_templates diff --git a/webui.py b/webui.py index 47848ba5..9ef12427 100644 --- a/webui.py +++ b/webui.py @@ -2,7 +2,7 @@ import os import threading import time import importlib -from modules import devices +from modules import devices, sd_samplers from modules.paths import script_path import signal import threading @@ -109,6 +109,8 @@ def webui(): time.sleep(0.5) break + sd_samplers.set_samplers() + print('Reloading Custom Scripts') modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) print('Reloading modules: modules.ui') -- cgit v1.2.3 From 5993df24a1026225cb8af89237547c1d9101ce69 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 14:12:52 +0300 Subject: integrate the new samplers PR --- modules/processing.py | 7 ++-- modules/sd_samplers.py | 59 +++++++++++++++------------- modules/shared.py | 1 - scripts/alternate_sampler_noise_schedules.py | 53 ------------------------- scripts/img2imgalt.py | 3 +- 5 files changed, 36 insertions(+), 87 deletions(-) delete mode 100644 scripts/alternate_sampler_noise_schedules.py (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index e01c8b3f..e567956c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -477,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -520,7 +520,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -555,7 +556,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8d6eb762..497df943 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -13,46 +13,46 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) +SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a']), - ('Euler', 'sample_euler', ['k_euler']), - ('LMS', 'sample_lms', ['k_lms']), - ('Heun', 'sample_heun', ['k_heun']), - ('DPM2', 'sample_dpm_2', ['k_dpm_2']), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), + ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}), + ('Euler', 'sample_euler', ['k_euler'], {}), + ('LMS', 'sample_lms', ['k_lms'], {}), + ('Heun', 'sample_heun', ['k_heun'], {}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), + ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ] -if opts.show_karras_scheduler_variants: - k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2 - k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral - k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms - samplers_k_diffusion_ka = [ - ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']), - ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']), - ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']), - ] - samplers_k_diffusion.extend(samplers_k_diffusion_ka) - samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) - for label, funcname, aliases in samplers_k_diffusion + SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion if hasattr(k_diffusion.sampling, funcname) ] all_samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), + SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), ] samplers = [] samplers_for_img2img = [] +def create_sampler_with_index(list_of_configs, index, model): + config = list_of_configs[index] + sampler = config.constructor(model) + sampler.config = config + + return sampler + + def set_samplers(): global samplers, samplers_for_img2img @@ -130,6 +130,7 @@ class VanillaStableDiffusionSampler: self.step = 0 self.eta = None self.default_eta = 0.0 + self.config = None def number_of_needed_noises(self, p): return 0 @@ -291,6 +292,7 @@ class KDiffusionSampler: self.stop_at = None self.eta = None self.default_eta = 1.0 + self.config = None def callback_state(self, d): store_latent(d["denoised"]) @@ -355,11 +357,12 @@ class KDiffusionSampler: steps = steps or p.steps if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.funcname.endswith('ka'): - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) + sigmas = p.sampler_noise_scheduler_override(steps) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = self.model_wrap.get_sigmas(steps) + x = x * sigmas[0] extra_params_kwargs = self.initialize(p) diff --git a/modules/shared.py b/modules/shared.py index 9e4860a2..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,6 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/scripts/alternate_sampler_noise_schedules.py b/scripts/alternate_sampler_noise_schedules.py deleted file mode 100644 index 4f3ed8fb..00000000 --- a/scripts/alternate_sampler_noise_schedules.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from modules.processing import Processed, process_images -import gradio as gr -import modules.scripts as scripts -import k_diffusion.sampling -import torch - - -class Script(scripts.Script): - - def title(self): - return "Alternate Sampler Noise Schedules" - - def ui(self, is_img2img): - noise_scheduler = gr.Dropdown(label="Noise Scheduler", choices=['Default','Karras','Exponential', 'Variance Preserving'], value='Default', type="index") - sched_smin = gr.Slider(value=0.1, label="Sigma min", minimum=0.0, maximum=100.0, step=0.5,) - sched_smax = gr.Slider(value=10.0, label="Sigma max", minimum=0.0, maximum=100.0, step=0.5) - sched_rho = gr.Slider(value=7.0, label="Sigma rho (Karras only)", minimum=7.0, maximum=100.0, step=0.5) - sched_beta_d = gr.Slider(value=19.9, label="Beta distribution (VP only)",minimum=0.0, maximum=40.0, step=0.5) - sched_beta_min = gr.Slider(value=0.1, label="Beta min (VP only)", minimum=0.0, maximum=40.0, step=0.1) - sched_eps_s = gr.Slider(value=0.001, label="Epsilon (VP only)", minimum=0.001, maximum=1.0, step=0.001) - - return [noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s] - - def run(self, p, noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s): - - noise_scheduler_func_name = ['-','get_sigmas_karras','get_sigmas_exponential','get_sigmas_vp'][noise_scheduler] - - base_params = { - "sigma_min":sched_smin, - "sigma_max":sched_smax, - "rho":sched_rho, - "beta_d":sched_beta_d, - "beta_min":sched_beta_min, - "eps_s":sched_eps_s, - "device":"cuda" if torch.cuda.is_available() else "cpu" - } - - if hasattr(k_diffusion.sampling,noise_scheduler_func_name): - - sigma_func = getattr(k_diffusion.sampling,noise_scheduler_func_name) - sigma_func_kwargs = {} - - for k,v in base_params.items(): - if k in inspect.signature(sigma_func).parameters: - sigma_func_kwargs[k] = v - - def substitute_noise_scheduler(n): - return sigma_func(n,**sigma_func_kwargs) - - p.sampler_noise_scheduler_override = substitute_noise_scheduler - - return process_images(p) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 0ef137f7..f9894cb0 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -8,7 +8,6 @@ import gradio as gr from modules import processing, shared, sd_samplers, prompt_parser from modules.processing import Processed -from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state import torch @@ -159,7 +158,7 @@ class Script(scripts.Script): combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) - sampler = samplers[p.sampler_index].constructor(p.sd_model) + sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model) sigmas = sampler.model_wrap.get_sigmas(p.steps) -- cgit v1.2.3 From dbc8a4d35129b08eab30776bbbaf3a2e7ac10a6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 20:27:50 +0300 Subject: add generation parameters to images shown in web ui --- modules/processing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index de818d5b..8faf9095 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -430,7 +430,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) - infotexts.append(infotext(n, i)) + text = infotext(n, i) + infotexts.append(text) + image.info["parameters"] = text output_images.append(image) del x_samples_ddim @@ -447,7 +449,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size) if opts.return_grid: - infotexts.insert(0, infotext()) + text = infotext() + infotexts.insert(0, text) + grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 -- cgit v1.2.3