aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py33
1 files changed, 26 insertions, 7 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 4ecdfcd2..de5cda79 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, lowvram
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
- for n in range(p.n_iter):
+ for n in range(p.n_iter):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if state.interrupted:
break
@@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+
+ devices.torch_gc()
+
if opts.filter_nsfw:
import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
- for i, x_sample in enumerate(x_samples_ddim):
+ for i, x_sample in enumerate(x_samples_ddim):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
- if p.restore_faces:
+ if p.restore_faces:
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
- devices.torch_gc()
-
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
image = Image.fromarray(x_sample)
if p.color_corrections is not None and i < len(p.color_corrections):
@@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts.append(infotext(n, i))
output_images.append(image)
- state.nextjob()
+ del x_samples_ddim
+ devices.torch_gc()
+
+ state.nextjob()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
p.color_corrections = None
index_of_first_image = 0
@@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+
return samples