diff options
-rw-r--r-- | modules/devices.py | 3 | ||||
-rw-r--r-- | modules/extras.py | 2 | ||||
-rw-r--r-- | modules/processing.py | 33 | ||||
-rw-r--r-- | modules/sd_hijack.py | 4 | ||||
-rw-r--r-- | modules/sd_hijack_optimizations.py | 8 |
5 files changed, 33 insertions, 17 deletions
diff --git a/modules/devices.py b/modules/devices.py index 0158b11f..6db4e57c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,6 +1,7 @@ import contextlib import torch +import gc from modules import errors @@ -19,8 +20,8 @@ def get_optimal_device(): return cpu - def torch_gc(): + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/extras.py b/modules/extras.py index 6a0d5cb0..1d9e64e5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs.append(image)
+ devices.torch_gc()
+
return outputs, plaintext_to_html(info), ''
diff --git a/modules/processing.py b/modules/processing.py index bb94033b..e7f9c85e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,7 +11,7 @@ import cv2 from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, lowvram
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -345,7 +345,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1:
state.job_count = p.n_iter
- for n in range(p.n_iter):
+ for n in range(p.n_iter):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if state.interrupted:
break
@@ -383,23 +384,33 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+
+ devices.torch_gc()
+
if opts.filter_nsfw:
import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
- for i, x_sample in enumerate(x_samples_ddim):
+ for i, x_sample in enumerate(x_samples_ddim):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
- if p.restore_faces:
+ if p.restore_faces:
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
- devices.torch_gc()
-
x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc()
+ devices.torch_gc()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
image = Image.fromarray(x_sample)
if p.color_corrections is not None and i < len(p.color_corrections):
@@ -427,8 +438,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i))
output_images.append(image)
- state.nextjob()
+ del x_samples_ddim
+ devices.torch_gc()
+
+ state.nextjob()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
p.color_corrections = None
index_of_first_image = 0
@@ -663,4 +679,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+
return samples
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3fa06242..a6fa890c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,6 +5,7 @@ import traceback import torch
import numpy as np
from torch import einsum
+from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
@@ -19,11 +20,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations():
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
+
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 9c079e57..ea4cfdfc 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -92,14 +92,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2)
-def nonlinearity_hijack(x):
- # swish
- t = torch.sigmoid(x)
- x *= t
- del t
-
- return x
-
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
|