diff options
author | InvincibleDude <81354513+InvincibleDude@users.noreply.github.com> | 2023-01-30 12:35:13 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-30 12:35:13 +0000 |
commit | 3ec2eb8bf12ae629c292ed0e96f199669040c5de (patch) | |
tree | fb46cb76c06f4c6a5ad4ad2ce8cd3a4577525be5 /modules/sd_samplers_common.py | |
parent | 0d834b9394bb1a9dbcbdc02a3d4d24d1e6511073 (diff) | |
parent | ee9fdf7f62984dc30770fb1a73e68736b319746f (diff) | |
download | stable-diffusion-webui-gfx803-3ec2eb8bf12ae629c292ed0e96f199669040c5de.tar.gz stable-diffusion-webui-gfx803-3ec2eb8bf12ae629c292ed0e96f199669040c5de.tar.bz2 stable-diffusion-webui-gfx803-3ec2eb8bf12ae629c292ed0e96f199669040c5de.zip |
Merge branch 'master' into improved-hr-conflict-test
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py new file mode 100644 index 00000000..3c03d442 --- /dev/null +++ b/modules/sd_samplers_common.py @@ -0,0 +1,78 @@ +from collections import namedtuple
+import numpy as np
+import torch
+from PIL import Image
+import torchsde._brownian.brownian_interval
+from modules import devices, processing, images, sd_vae_approx
+
+from modules.shared import opts, state
+import modules.shared as shared
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+def setup_img2img_steps(p, steps=None):
+ if opts.img2img_fix_steps or steps is not None:
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ return Image.fromarray(x_sample)
+
+
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def store_latent(decoded):
+ state.current_latent = decoded
+
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if not shared.parallel_processing_allowed:
+ shared.state.assign_current_image(sample_to_image(decoded))
+
+
+class InterruptedException(BaseException):
+ pass
+
+
+# MPS fix for randn in torchsde
+# XXX move this to separate file for MPS
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
+
|