diff options
author | unknown <mcgpapu@gmail.com> | 2023-01-30 11:12:31 +0000 |
---|---|---|
committer | unknown <mcgpapu@gmail.com> | 2023-01-30 11:12:31 +0000 |
commit | 21766a0898fd9cc344c5d100396280fa8b0c4e74 (patch) | |
tree | 2a365bb6bc60c9916004e67c04f3816880366edf /modules/sd_samplers_common.py | |
parent | e79b7db4b47a33889551b9266ee3277879d4f560 (diff) | |
parent | aa4688eb8345de583070ca9ddb4c6f585f06762b (diff) | |
download | stable-diffusion-webui-gfx803-21766a0898fd9cc344c5d100396280fa8b0c4e74.tar.gz stable-diffusion-webui-gfx803-21766a0898fd9cc344c5d100396280fa8b0c4e74.tar.bz2 stable-diffusion-webui-gfx803-21766a0898fd9cc344c5d100396280fa8b0c4e74.zip |
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui into gamepad
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py new file mode 100644 index 00000000..3c03d442 --- /dev/null +++ b/modules/sd_samplers_common.py @@ -0,0 +1,78 @@ +from collections import namedtuple
+import numpy as np
+import torch
+from PIL import Image
+import torchsde._brownian.brownian_interval
+from modules import devices, processing, images, sd_vae_approx
+
+from modules.shared import opts, state
+import modules.shared as shared
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+def setup_img2img_steps(p, steps=None):
+ if opts.img2img_fix_steps or steps is not None:
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ return Image.fromarray(x_sample)
+
+
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def store_latent(decoded):
+ state.current_latent = decoded
+
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if not shared.parallel_processing_allowed:
+ shared.state.assign_current_image(sample_to_image(decoded))
+
+
+class InterruptedException(BaseException):
+ pass
+
+
+# MPS fix for randn in torchsde
+# XXX move this to separate file for MPS
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
+
|