author    Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>    2023-05-14 04:42:44 +0000
committer Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>    2023-05-14 06:06:01 +0000
commit    e14b586d0494d6c5cc3cbc45b5fa00c03d052443
tree      807b3e771ef465654b672956d09d94af525d14ab
parent    b08500cec8a791ef20082628b49b17df833f5dda
Add Tiny AE live preview
Diffstat (limited to 'modules/sd_samplers_common.py')
 modules/sd_samplers_common.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)
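
In short, the patch adds "Tiny AE" (the TAESD tiny autoencoder) as a live-preview decode mode and renumbers the two existing approximations to make room for it. The mapping change, copied from the second hunk below:

    # before this commit
    approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
    # after this commit
    approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}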
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bc074238..d3dc130c 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,7 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd
 from modules.shared import opts, state
 import modules.shared as shared
 
@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc
 
 
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}
 
 
 def single_sample_to_image(sample, approximation=None):
     if approximation is None:
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
 
-    if approximation == 2:
-        x_sample = sd_vae_approx.cheap_approximation(sample)
-    elif approximation == 1:
-        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+    if approximation == 1:
+        x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
+        x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+        if approximation == 3:
+            x_sample = sd_vae_approx.cheap_approximation(sample)
+        elif approximation == 2:
+            x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        else:
+            x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
 
-    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
     return Image.fromarray(x_sample)
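
For readers tracing the new control flow outside the diff context, here is a minimal, self-contained sketch of the dispatch that single_sample_to_image performs after this patch. The function name decode_preview and the decoder arguments are illustrative stand-ins, not part of the commit (the real code calls sd_vae_taesd.decode(), sd_vae_approx.model(), sd_vae_approx.cheap_approximation and processing.decode_first_stage, and also moves the sample to the model's device and dtype); the scaling constants are taken verbatim from the hunk above.

    import torch

    def decode_preview(sample, approximation, taesd, unscale_latents,
                       approx_nn, approx_cheap, full_decode):
        """Illustrative stand-in for the patched dispatch.

        sample: one latent tensor of shape (4, H/8, W/8). The decoder
        arguments are hypothetical callables mirroring the modules used
        in the diff. Returns an RGB tensor with values in [0, 1].
        """
        if approximation == 1:  # "Tiny AE": TAESD decode
            x = taesd(sample.unsqueeze(0))[0].detach()
            x = unscale_latents(x)                   # as sd_vae_taesd.TAESD.unscale_latents
            x = torch.clamp((x * 0.25) + 0.5, 0, 1)  # TAESD output range -> [0, 1]
        else:
            if approximation == 3:    # "Approx cheap": fixed linear projection
                x = approx_cheap(sample)
            elif approximation == 2:  # "Approx NN": small learned decoder
                x = approx_nn(sample.unsqueeze(0))[0].detach()
            else:                     # "Full": the model's real VAE decoder
                x = full_decode(sample.unsqueeze(0))[0]
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)  # [-1, 1] -> [0, 1]
        return x

Note how the clamp to [0, 1] moves inside the branches: the TAESD path needs its own output mapping (x * 0.25 + 0.5), while the three pre-existing paths keep the original (x + 1) / 2 mapping. The shared tail of the real function then converts the [0, 1] tensor to a uint8 PIL image, unchanged by this commit.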