-rw-r--r--  modules/extras.py                    2
-rw-r--r--  modules/processing.py               16
-rw-r--r--  modules/sd_hijack.py                 4
-rw-r--r--  modules/sd_hijack_optimizations.py   8
4 files changed, 20 insertions, 10 deletions
diff --git a/modules/extras.py b/modules/extras.py
index 6a0d5cb0..1d9e64e5 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
         outputs.append(image)
 
+    devices.torch_gc()
+
     return outputs, plaintext_to_html(info), ''
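
Note: devices.torch_gc() is called here so memory used during postprocessing is actually handed back between runs. A minimal sketch of what such a helper typically does (the body below is an assumption for illustration, not taken from this diff):

    import torch

    def torch_gc():
        # dropping Python references alone leaves the memory in PyTorch's
        # caching allocator; these calls return cached blocks to the driver
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
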
diff --git a/modules/processing.py b/modules/processing.py
index e567956c..de818d5b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -11,7 +11,7 @@ import cv2
 from skimage import exposure
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -382,6 +382,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
+            del samples_ddim
+
+            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+                lowvram.send_everything_to_cpu()
+
+            devices.torch_gc()
+
             if opts.filter_nsfw:
                 import modules.safety as safety
                 x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
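
Note: under --lowvram/--medvram, lowvram.send_everything_to_cpu() clears the GPU before the decoded samples are postprocessed. A rough sketch of the idea behind such an offload helper, assuming the module tracks which submodule currently lives on the GPU (names are illustrative, not from this diff):

    import torch

    module_in_gpu = None  # the one model part currently resident on the GPU

    def send_everything_to_cpu():
        global module_in_gpu
        if module_in_gpu is not None:
            module_in_gpu.to(torch.device('cpu'))  # move its weights off the GPU
        module_in_gpu = None
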
@@ -426,6 +433,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 infotexts.append(infotext(n, i))
                 output_images.append(image)
 
+            del x_samples_ddim
+
+            devices.torch_gc()
+
             state.nextjob()
 
     p.color_corrections = None
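
Note on the del-before-gc pattern used above: CPython frees a tensor as soon as its last reference dies, so `del samples_ddim` / `del x_samples_ddim` must come before devices.torch_gc() for the cache flush to have anything to return. An illustrative check (assumes a CUDA device):

    import torch

    t = torch.empty(1024, 1024, 256, device='cuda')  # ~1 GiB of float32
    del t                                  # last reference gone; allocator marks it free
    print(torch.cuda.memory_allocated())   # drops by ~1 GiB already
    torch.cuda.empty_cache()               # cached block goes back to the driver
    print(torch.cuda.memory_reserved())    # reserved (cached) memory drops too
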
@@ -663,4 +674,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         if self.mask is not None:
             samples = samples * self.nmask + self.init_latent * self.mask
 
+        del x
+        devices.torch_gc()
+
         return samples
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 3fa06242..a6fa890c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,6 +5,7 @@ import traceback
 import torch
 import numpy as np
 from torch import einsum
+from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared
@@ -19,11 +20,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
 def apply_optimizations():
+    ldm.modules.diffusionmodules.model.nonlinearity = silu
+
     if cmd_opts.opt_split_attention_v1:
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
-        ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
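
Note: the activation override now happens unconditionally at the top of apply_optimizations() instead of only inside the elif branch, and it uses PyTorch's built-in op. SiLU is exactly the "swish" the removed hijack computed, x * sigmoid(x); a quick illustrative check:

    import torch
    from torch.nn.functional import silu

    x = torch.randn(8)
    assert torch.allclose(silu(x), x * torch.sigmoid(x))
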
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 9c079e57..ea4cfdfc 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -92,14 +92,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)
 
-def nonlinearity_hijack(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-
-    return x
-
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
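
Note: the removed nonlinearity_hijack computed swish in place (`x *= t`), mutating the tensor passed to it, whereas torch.nn.functional.silu returns a fresh tensor (with an optional inplace= flag) and dispatches to a dedicated kernel. Illustration of the behavioral difference:

    import torch

    x = torch.randn(3)
    out = torch.nn.functional.silu(x)   # out-of-place: x is left untouched
    x *= torch.sigmoid(x)               # what the removed hijack did: mutates x
    assert torch.allclose(x, out)       # same values, different side effects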