aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
authorcatboxanon <122327233+catboxanon@users.noreply.github.com>2023-08-13 08:16:48 +0000
committercatboxanon <122327233+catboxanon@users.noreply.github.com>2023-08-13 08:16:48 +0000
commit822597db49218de17e105e62075096284dfcfd41 (patch)
tree9d83c9389d99f262d897ce28ebcac6a0e3089415 /modules/sd_samplers_common.py
parentda80d649fd6a6083be02aca5695367bd25abf0d5 (diff)
downloadstable-diffusion-webui-gfx803-822597db49218de17e105e62075096284dfcfd41.tar.gz
stable-diffusion-webui-gfx803-822597db49218de17e105e62075096284dfcfd41.tar.bz2
stable-diffusion-webui-gfx803-822597db49218de17e105e62075096284dfcfd41.zip
Encode batches separately
Significantly reduces VRAM. This makes encoding more inline with how decoding currently functions.
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 09d1e11e..f9d034ca 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -92,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ if len(image) > 1:
+ x_latent = torch.stack([
+ model.get_first_stage_encoding(
+ model.encode_first_stage(torch.unsqueeze(img, 0))
+ )[0]
+ for img in image
+ ])
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent