From c938679de7b87b4f14894d9f57fe0f40dd6e3c06 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Wed, 28 Sep 2022 22:14:13 -0300 Subject: Fix memory leak and reduce memory usage --- modules/processing.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 4ecdfcd2..de5cda79 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, lowvram from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): + for n in range(p.n_iter): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if state.interrupted: break @@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + del samples_ddim + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + + devices.torch_gc() + if opts.filter_nsfw: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): + for i, x_sample in enumerate(x_samples_ddim): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: + if p.restore_faces: + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - devices.torch_gc() - x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - state.nextjob() + del x_samples_ddim + devices.torch_gc() + + state.nextjob() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 @@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask + del x + devices.torch_gc() + return samples -- cgit v1.2.3 From 82380d9ac18614c87bebba1b4cfd4b147cc76a18 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 4 Oct 2022 22:28:50 -0300 Subject: Removing parts no longer needed to fix vram --- modules/processing.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index e7f9c85e..f666ba81 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -345,8 +345,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for n in range(p.n_iter): if state.interrupted: break @@ -395,22 +394,19 @@ def process_images(p: StableDiffusionProcessing) -> Processed: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - x_sample = modules.face_restoration.restore_faces(x_sample) devices.torch_gc() - devices.torch_gc() + x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -438,13 +434,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - del x_samples_ddim + del x_samples_ddim - devices.torch_gc() + devices.torch_gc() - state.nextjob() + state.nextjob() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 -- cgit v1.2.3