From bc607686065b8c7751d1af7c05b960378fa256de Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Tue, 1 Nov 2022 23:26:55 +0800 Subject: Enable override_settings to take effect for hypernetworks --- modules/processing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 57d3a523..86d015af 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,13 +422,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible + opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible + if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not res = process_images_inner(p) - finally: + finally: # restore opts to original state for k, v in stored_opts.items(): opts.data[k] = v + if k == 'sd_hypernetwork': shared.reload_hypernetworks() return res -- cgit v1.2.3 From 55ca04095845b41bf66333b3b7343e3ea0babed1 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sun, 6 Nov 2022 16:31:44 +0800 Subject: Resolve conflict --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 86d015af..db35983b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,14 +422,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible + setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not res = process_images_inner(p) finally: # restore opts to original state for k, v in stored_opts.items(): - opts.data[k] = v + setattr(opts, k, v) if k == 'sd_hypernetwork': shared.reload_hypernetworks() return res -- cgit v1.2.3 From 6fa891b934ba854efa87315baffc4ff458ab2539 Mon Sep 17 00:00:00 2001 From: KEV Date: Mon, 14 Nov 2022 00:25:38 +1000 Subject: Add 'Inpainting strength' to the 'generation_params' dictionary of 'infotext' which is saved into the 'params.txt' or png chunks. Value appears only if 'Denoising strength' appears too. --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 03c9143d..01d7cbdc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -399,6 +399,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), + "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else shared.opts.inpainting_mask_weight), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, -- cgit v1.2.3 From 40ae95d53218b3b8f12fca50b5e4e98a1e50af4b Mon Sep 17 00:00:00 2001 From: KEV Date: Mon, 14 Nov 2022 18:05:59 +1000 Subject: Fix retrieving value for 'x/y plot' script. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 01d7cbdc..2fc9fe13 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -399,7 +399,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else shared.opts.inpainting_mask_weight), + "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, -- cgit v1.2.3 From 9bbe1e3c2e54f64283bb333ebb648d8f40f5d4ee Mon Sep 17 00:00:00 2001 From: Llewellyn Pritchard Date: Wed, 16 Nov 2022 19:19:00 +0200 Subject: Fix unbounded prompt growth scripts that loop --- modules/processing.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 03c9143d..2fd12288 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -450,6 +450,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: modules.sd_hijack.model_hijack.clear_comments() comments = {} + prompt_tmp = p.prompt + negative_prompt_tmp = p.negative_prompt shared.prompt_styles.apply_styles(p) @@ -596,6 +598,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) + p.prompt = prompt_tmp + p.negative_prompt = negative_prompt_tmp + return res -- cgit v1.2.3 From cdc8020d13c5eef099c609b0a911ccf3568afc0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:01:51 +0300 Subject: change StableDiffusionProcessing to internally use sampler name instead of sampler index --- modules/processing.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 03c9143d..be2edf48 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -2,6 +2,7 @@ import json import math import os import sys +import warnings import torch import numpy as np @@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays): return image -def get_correct_sampler(p): - if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): - return sd_samplers.samplers - elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): - return sd_samplers.samplers_for_img2img - elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI): - return sd_samplers.samplers class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None): + if sampler_index is not None: + warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name") + self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -91,7 +88,7 @@ class StableDiffusionProcessing(): self.subseed_strength: float = subseed_strength self.seed_resize_from_h: int = seed_resize_from_h self.seed_resize_from_w: int = seed_resize_from_w - self.sampler_index: int = sampler_index + self.sampler_name: str = sampler_name self.batch_size: int = batch_size self.n_iter: int = n_iter self.steps: int = steps @@ -210,8 +207,7 @@ class Processed: self.info = info self.width = p.width self.height = p.height - self.sampler_index = p.sampler_index - self.sampler = sd_samplers.samplers[p.sampler_index].name + self.sampler_name = p.sampler_name self.cfg_scale = p.cfg_scale self.steps = p.steps self.batch_size = p.batch_size @@ -256,8 +252,7 @@ class Processed: "subseed_strength": self.subseed_strength, "width": self.width, "height": self.height, - "sampler_index": self.sampler_index, - "sampler": self.sampler, + "sampler_name": self.sampler_name, "cfg_scale": self.cfg_scale, "steps": self.steps, "batch_size": self.batch_size, @@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": get_correct_sampler(p)[p.sampler_index].name, + "Sampler": p.sampler_name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), @@ -645,7 +640,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -706,7 +701,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -743,7 +738,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) crop_region = None if self.image_mask is not None: -- cgit v1.2.3 From 0d702930b068ca8da8eb0117613053a480d9439e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:47:52 +0300 Subject: renamed Inpainting strength infotext to Conditional mask weight, made it only appear if using inpainting model, made it possible to read the setting from it using the blue arrow button --- modules/processing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index fb30aa81..def95846 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -113,6 +113,7 @@ class StableDiffusionProcessing(): self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_noise = s_noise or opts.s_noise self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} + self.is_using_inpainting_conditioning = False if not seed_enable_extras: self.subseed = -1 @@ -133,6 +134,8 @@ class StableDiffusionProcessing(): # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. return x.new_zeros(x.shape[0], 5, 1, 1) + self.is_using_inpainting_conditioning = True + height = height or self.height width = width or self.width @@ -151,6 +154,8 @@ class StableDiffusionProcessing(): # Dummy zero conditioning if we're not using inpainting model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) + self.is_using_inpainting_conditioning = True + # Handle the different mask inputs if image_mask is not None: if torch.is_tensor(image_mask): @@ -234,6 +239,7 @@ class Processed: self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -268,6 +274,7 @@ class Processed: "styles": self.styles, "job_timestamp": self.job_timestamp, "clip_skip": self.clip_skip, + "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, } return json.dumps(obj) @@ -394,7 +401,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)), + "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, -- cgit v1.2.3 From 617c5b486f42aa73062ee7699ee1147eb995c899 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 13:23:25 +0300 Subject: make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch --- modules/processing.py | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 604d822a..bc7e5311 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -124,6 +124,7 @@ class StableDiffusionProcessing(): self.scripts = None self.script_args = None self.all_prompts = None + self.all_negative_prompts = None self.all_seeds = None self.all_subseeds = None @@ -202,7 +203,7 @@ class StableDiffusionProcessing(): class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -241,16 +242,18 @@ class Processed: self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning - self.all_prompts = all_prompts or [self.prompt] - self.all_seeds = all_seeds or [self.seed] - self.all_subseeds = all_subseeds or [self.subseed] + self.all_prompts = all_prompts or p.all_prompts or [self.prompt] + self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt] + self.all_seeds = all_seeds or p.all_seeds or [self.seed] + self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.infotexts = infotexts or [info] def js(self): obj = { - "prompt": self.prompt, + "prompt": self.all_prompts[0], "all_prompts": self.all_prompts, - "negative_prompt": self.negative_prompt, + "negative_prompt": self.all_negative_prompts[0], + "all_negative_prompts": self.all_negative_prompts, "seed": self.seed, "all_seeds": self.all_seeds, "subseed": self.subseed, @@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" + negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() @@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: assert p.prompt is not None - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - devices.torch_gc() seed = get_fixed_seed(p.seed) @@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: modules.sd_hijack.model_hijack.clear_comments() comments = {} - prompt_tmp = p.prompt - negative_prompt_tmp = p.negative_prompt - - shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - p.all_prompts = p.prompt + p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt] + else: + p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)] + + if type(p.negative_prompt) == list: + p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt] else: - p.all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)] if type(seed) == list: p.all_seeds = seed @@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: break prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] @@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) if p.scripts is not None: p.scripts.postprocess(p, res) - p.prompt = prompt_tmp - p.negative_prompt = negative_prompt_tmp - return res -- cgit v1.2.3 From 413c077969d35bc90a8b3218ab0db7e35c8c46f2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 13:47:37 +0300 Subject: prevent StableDiffusionProcessingImg2Img changing image_mask field as an alternative solution to #4765 --- modules/processing.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index bc7e5311..accb31d1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -740,7 +740,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.denoising_strength: float = denoising_strength self.init_latent = None self.image_mask = mask - #self.image_unblurred_mask = None self.latent_mask = None self.mask_for_overlay = None self.mask_blur = mask_blur @@ -756,36 +755,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) crop_region = None - if self.image_mask is not None: - self.image_mask = self.image_mask.convert('L') + image_mask = self.image_mask - if self.inpainting_mask_invert: - self.image_mask = ImageOps.invert(self.image_mask) + if image_mask is not None: + image_mask = image_mask.convert('L') - #self.image_unblurred_mask = self.image_mask + if self.inpainting_mask_invert: + image_mask = ImageOps.invert(image_mask) if self.mask_blur > 0: - self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) + image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) if self.inpaint_full_res: - self.mask_for_overlay = self.image_mask - mask = self.image_mask.convert('L') + self.mask_for_overlay = image_mask + mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region mask = mask.crop(crop_region) - self.image_mask = images.resize_image(2, mask, self.width, self.height) + image_mask = images.resize_image(2, mask, self.width, self.height) self.paste_to = (x1, y1, x2-x1, y2-y1) else: - self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) - np_mask = np.array(self.image_mask) + image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] - latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask + latent_mask = self.latent_mask if self.latent_mask is not None else image_mask add_color_corrections = opts.img2img_color_correction and self.color_corrections is None if add_color_corrections: @@ -797,7 +796,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if crop_region is None: image = images.resize_image(self.resize_mode, image, self.width, self.height) - if self.image_mask is not None: + if image_mask is not None: image_masked = Image.new('RGBa', (image.width, image.height)) image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) @@ -807,7 +806,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = image.crop(crop_region) image = images.resize_image(2, image, self.width, self.height) - if self.image_mask is not None: + if image_mask is not None: if self.inpainting_fill != 1: image = masking.fill(image, latent_mask) @@ -839,7 +838,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) - if self.image_mask is not None: + if image_mask is not None: init_mask = latent_mask latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 @@ -856,7 +855,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) + self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) -- cgit v1.2.3 From 40ca34b837b5068ec35b8d5681bae32cf28f5816 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 13:17:39 +0300 Subject: fix for broken sampler selection in img2img and xy plot #4860 #4909 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index c310df6a..edceb532 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -74,7 +74,7 @@ class StableDiffusionProcessing(): """ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None): if sampler_index is not None: - warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name") + print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) self.sd_model = sd_model self.outpath_samples: str = outpath_samples -- cgit v1.2.3 From 67efee33a6c65e58b3f6c788993d0e68a33e4fd0 Mon Sep 17 00:00:00 2001 From: klimaleksus Date: Mon, 28 Nov 2022 16:29:43 +0500 Subject: Make VAE step sequential to prevent VRAM spikes --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index edceb532..fd995b8a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) - samples_ddim = samples_ddim.to(devices.dtype_vae) - x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) + x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim -- cgit v1.2.3 From 9a8678f61eff172811498a682c171399b7216e12 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Tue, 29 Nov 2022 11:11:29 +0800 Subject: Support changing checkpoint and vae through override_settings --- modules/processing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index edceb532..a5c72e3d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,6 +20,8 @@ import modules.shared as shared import modules.face_restoration import modules.images as images import modules.styles +import modules.sd_models as sd_models +import modules.sd_vae as sd_vae import logging @@ -424,8 +426,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible - if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not + setattr(opts, k, v) + if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet + if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model + if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE res = process_images_inner(p) @@ -433,6 +437,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for k, v in stored_opts.items(): setattr(opts, k, v) if k == 'sd_hypernetwork': shared.reload_hypernetworks() + if k == 'sd_model_checkpoint': sd_models.reload_model_weights() + if k == 'sd_vae': sd_vae.reload_vae_weights() return res -- cgit v1.2.3 From a44994e2c926fc1f8479281e5b1e08d7fe9db2bb Mon Sep 17 00:00:00 2001 From: Adi Eyal Date: Wed, 30 Nov 2022 15:23:53 +0200 Subject: Fixed incorrect negative prompt text in infotext Previously only the first negative prompt in all_negative_prompts was being used for infotext. This fixes that by selecting the index-th negative prompt --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index edceb532..0a73ccbb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -414,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else "" + negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -- cgit v1.2.3 From cf3e844d1d31d64f3234a0fbdfcac91cc5834657 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 3 Dec 2022 18:05:47 +0300 Subject: add noise strength parameter similar to NAI --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 3d2c4dc9..b9cb6d32 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -861,6 +861,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = x*shared.opts.initial_noise_multiplier samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) -- cgit v1.2.3 From 358a8628f6abb4ca1e1bfddf122687c6fb13be0c Mon Sep 17 00:00:00 2001 From: Andrew Ryan Date: Thu, 8 Dec 2022 07:09:09 +0000 Subject: Add latent upscale option to img2img Recently, the option to do latent upscale was added to txt2img highres fix. This feature runs by scaling the latent sample of the image, and then running a second pass of img2img. But, in this edition of highres fix, the image and parameters cannot be changed between the first pass and second pass. We might want to do a fixup in img2img before doing the second pass, or might want to run the second pass at a different resolution. This change adds the option for img2img to perform its upscale in latent space, rather than image space, giving very similar results to highres fix with latent upscale. The result is not exactly the same because there is an additional latent -> decoder -> image -> encoder -> latent that won't happen in highres fix, but this conversion has relatively small losses --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 3d2c4dc9..ab5a34d0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -795,7 +795,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): for img in self.init_images: image = img.convert("RGB") - if crop_region is None: + if crop_region is None and self.resize_mode != 3: image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: @@ -804,6 +804,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.overlay_images.append(image_masked.convert('RGBA')) + # crop_region is not none iif we are doing inpaint full res if crop_region is not None: image = image.crop(crop_region) image = images.resize_image(2, image, self.width, self.height) @@ -840,6 +841,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) + if self.resize_mode == 3: + self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + if image_mask is not None: init_mask = latent_mask latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) -- cgit v1.2.3 From 1ed4f0e22807f3afef925210182cbbee51f0cb2c Mon Sep 17 00:00:00 2001 From: Jay Smith Date: Thu, 8 Dec 2022 18:14:35 -0600 Subject: Depth2img model support --- modules/processing.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 3d2c4dc9..0417ffc5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -21,7 +21,10 @@ import modules.face_restoration import modules.images as images import modules.styles import logging +from ldm.data.util import AddMiDaS +from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion +from einops import repeat, rearrange # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -150,11 +153,26 @@ class StableDiffusionProcessing(): return image_conditioning - def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): - if self.sampler.conditioning_key not in {'hybrid', 'concat'}: - # Dummy zero conditioning if we're not using inpainting model. - return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) + def depth2img_image_conditioning(self, source_image): + # Use the AddMiDaS helper to Format our source image to suit the MiDaS model + transformer = AddMiDaS(model_type="dpt_hybrid") + transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")}) + midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) + midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) + + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning = torch.nn.functional.interpolate( + self.sd_model.depth_model(midas_in), + size=conditioning_image.shape[2:], + mode="bicubic", + align_corners=False, + ) + + (depth_min, depth_max) = torch.aminmax(conditioning) + conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1. + return conditioning + def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -191,6 +209,18 @@ class StableDiffusionProcessing(): return image_conditioning + def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely + # identify itself with a field common to all models. The conditioning_key is also hybrid. + if isinstance(self.sd_model, LatentDepth2ImageDiffusion): + return self.depth2img_image_conditioning(source_image) + + if self.sampler.conditioning_key in {'hybrid', 'concat'}: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + + # Dummy zero conditioning if we're not using inpainting or depth model. + return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) + def init(self, all_prompts, all_seeds, all_subseeds): pass -- cgit v1.2.3 From bab91b12798f67c19a2b14dab13a08d5d3e3de26 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:51:26 +0300 Subject: add Noise multiplier option to infotext --- modules/processing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index dd22a2fa..81400d14 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -764,7 +764,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): + def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -779,6 +779,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.inpaint_full_res = inpaint_full_res self.inpaint_full_res_padding = inpaint_full_res_padding self.inpainting_mask_invert = inpainting_mask_invert + self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier self.mask = None self.nmask = None self.image_conditioning = None @@ -891,7 +892,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - x = x*shared.opts.initial_noise_multiplier + + if self.initial_noise_multiplier != 1.0: + self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier + x *= self.initial_noise_multiplier samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) -- cgit v1.2.3 From 991e2dcee9d6baa66b5c0b1969c4c07407be933a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 14:54:02 +0300 Subject: remove NSFW filter and its dependency; if you still want it, find it in the extensions section --- modules/processing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 81400d14..056c9322 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -571,9 +571,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - if opts.filter_nsfw: - import modules.safety as safety - x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) + if p.scripts is not None: + p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) -- cgit v1.2.3 From 2e8b5418e3cd4e9212f2fcdb36305d7a40f97916 Mon Sep 17 00:00:00 2001 From: ThereforeGames <95403634+ThereforeGames@users.noreply.github.com> Date: Sun, 11 Dec 2022 18:03:36 -0500 Subject: Improve color correction with luminosity blend --- modules/processing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 24c537d1..bc841837 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -27,6 +27,7 @@ from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange +from blendmodes.blend import blendLayers, BlendType # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -39,17 +40,19 @@ def setup_color_correction(image): return correction_target -def apply_color_correction(correction, image): +def apply_color_correction(correction, original_image): logging.info("Applying color correction.") image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( cv2.cvtColor( - np.asarray(image), + np.asarray(original_image), cv2.COLOR_RGB2LAB ), correction, channel_axis=2 ), cv2.COLOR_LAB2RGB).astype("uint8")) - + + image = blendLayers(image, original_image, BlendType.LUMINOSITY) + return image -- cgit v1.2.3 From 7077428209cd02f7da23ef843e5027e960f6aa39 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 13 Dec 2022 13:05:40 -0800 Subject: Save hypernetwork hash in infotext --- modules/processing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 24c537d1..6dd7491b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -314,7 +314,7 @@ class Processed: return json.dumps(obj) - def infotext(self, p: StableDiffusionProcessing, index): + def infotext(self, p: StableDiffusionProcessing, index): return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) @@ -429,6 +429,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), + "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), @@ -446,7 +447,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" + negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -- cgit v1.2.3 From c0355caefe3d82e304e6d832699d581fc8f9fbf9 Mon Sep 17 00:00:00 2001 From: Jim Hays Date: Wed, 14 Dec 2022 21:01:32 -0500 Subject: Fix various typos --- modules/processing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 24c537d1..fe7f4faf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -147,11 +147,11 @@ class StableDiffusionProcessing(): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + image_conditioning = image_conditioning.to(x.dtype) return image_conditioning @@ -199,7 +199,7 @@ class StableDiffusionProcessing(): source_image * (1.0 - conditioning_mask), getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) ) - + # Encode the new masked image using first stage of network. conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) @@ -537,7 +537,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -612,7 +612,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -704,7 +704,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] - """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images""" + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -720,7 +720,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - # Avoid making the inpainting conditioning unless necessary as + # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) -- cgit v1.2.3 From 22f1527fa79a03dbc8b1a4eec3b22369a877f4bd Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 20 Dec 2022 20:36:49 +1100 Subject: feat(api): add override_settings_restore_afterwards --- modules/processing.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 24c537d1..f7335da2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -77,7 +77,7 @@ class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -118,6 +118,7 @@ class StableDiffusionProcessing(): self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_noise = s_noise or opts.s_noise self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} + self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False if not seed_enable_extras: @@ -147,11 +148,11 @@ class StableDiffusionProcessing(): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + image_conditioning = image_conditioning.to(x.dtype) return image_conditioning @@ -199,7 +200,7 @@ class StableDiffusionProcessing(): source_image * (1.0 - conditioning_mask), getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) ) - + # Encode the new masked image using first stage of network. conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) @@ -463,12 +464,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed: res = process_images_inner(p) - finally: # restore opts to original state - for k, v in stored_opts.items(): - setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + finally: + # restore opts to original state + if p.override_settings_restore_afterwards: + for k, v in stored_opts.items(): + setattr(opts, k, v) + if k == 'sd_hypernetwork': shared.reload_hypernetworks() + if k == 'sd_model_checkpoint': sd_models.reload_model_weights() + if k == 'sd_vae': sd_vae.reload_vae_weights() return res @@ -537,7 +540,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -612,7 +615,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -720,7 +723,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - # Avoid making the inpainting conditioning unless necessary as + # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) -- cgit v1.2.3 From 9441c28c947588d756e279a8cd5db6c0b4a8d2e4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 24 Dec 2022 09:46:35 +0300 Subject: add an option for img2img background color --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index bc841837..7c4bcd74 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -832,7 +832,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.color_corrections = [] imgs = [] for img in self.init_images: - image = img.convert("RGB") + image = images.flatten(img, opts.img2img_background_color) if crop_region is None: image = images.resize_image(self.resize_mode, image, self.width, self.height) -- cgit v1.2.3 From c0a8401b5a8368d03bb14fc63abbdedb1e802d8d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 24 Dec 2022 11:12:17 +0300 Subject: rename the option for img2img latent upscale --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 75b0067c..d2288f26 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -846,7 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.overlay_images.append(image_masked.convert('RGBA')) - # crop_region is not none iif we are doing inpaint full res + # crop_region is not None if we are doing inpaint full res if crop_region is not None: image = image.crop(crop_region) image = images.resize_image(2, image, self.width, self.height)