aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py53
1 files changed, 26 insertions, 27 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 678c4468..2b8dd361 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -29,12 +29,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
-import tomesd
-
-# add a logger for the processing module
-logger = logging.getLogger(__name__)
-# manually set output level here since there is no option to do so yet through launch options
-# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -156,6 +150,8 @@ class StableDiffusionProcessing:
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
+ self.token_merging_ratio = 0
+ self.token_merging_ratio_hr = 0
if not seed_enable_extras:
self.subseed = -1
@@ -171,6 +167,7 @@ class StableDiffusionProcessing:
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
+ self.sampler = None
@property
@@ -280,6 +277,12 @@ class StableDiffusionProcessing:
def close(self):
self.sampler = None
+ def get_token_merging_ratio(self, for_hr=False):
+ if for_hr:
+ return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+ return self.token_merging_ratio or opts.token_merging_ratio
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -309,6 +312,8 @@ class Processed:
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
+ self.token_merging_ratio = p.token_merging_ratio
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -316,6 +321,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -366,6 +372,9 @@ class Processed:
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
@@ -479,6 +488,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
+ token_merging_ratio = p.get_token_merging_ratio()
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
uses_ensd = opts.eta_noise_seed_delta != 0
if uses_ensd:
@@ -501,8 +512,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
- "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
- "Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
+ "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+ "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
@@ -535,17 +546,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
- if opts.token_merging_ratio > 0:
- sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
- logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")
+ sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
res = process_images_inner(p)
finally:
- # undo model optimizations made by tomesd
- if opts.token_merging_ratio > 0:
- tomesd.remove_patch(p.sd_model)
- logger.debug('Token merging model optimizations removed')
+ sd_models.apply_token_merging(p.sd_model, 0)
# restore opts to original state
if p.override_settings_restore_afterwards:
@@ -995,21 +1001,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- # apply token merging optimizations from tomesd for high-res pass
- if opts.token_merging_ratio_hr > 0:
- # in case the user has used separate merge ratios
- if opts.token_merging_ratio > 0:
- tomesd.remove_patch(self.sd_model)
- logger.debug('Adjusting token merging ratio for high-res pass')
-
- sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
- logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
- if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
- tomesd.remove_patch(self.sd_model)
- logger.debug('Removed token merging optimizations from model')
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.is_hr_pass = False
@@ -1172,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
devices.torch_gc()
return samples
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio