aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py5
-rw-r--r--modules/processing.py26
-rw-r--r--modules/sd_samplers.py4
-rw-r--r--modules/shared.py20
-rw-r--r--modules/ui.py10
5 files changed, 45 insertions, 20 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index bb87d795..71c9c160 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -5,10 +5,9 @@ import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
-from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
+from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo
@@ -179,6 +178,8 @@ class Api:
progress = min(progress, 1)
+ shared.state.set_current_image()
+
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
diff --git a/modules/processing.py b/modules/processing.py
index b541ee2b..3a364b5f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -199,7 +199,7 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
@@ -521,7 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -649,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -662,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ def save_intermediate(image, index):
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index)
+
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -674,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
+
+ save_intermediate(image, i)
+
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -831,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -843,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples \ No newline at end of file
+ return samples
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 44d4c189..c7c414ef 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -93,8 +93,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample)
-def sample_to_image(samples):
- return single_sample_to_image(samples[0])
+def sample_to_image(samples, index=0):
+ return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):
diff --git a/modules/shared.py b/modules/shared.py
index 1ccb269a..01f47e38 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -4,6 +4,7 @@ import json
import os
import sys
from collections import OrderedDict
+import time
import gradio as gr
import tqdm
@@ -135,6 +136,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ time_start = None
need_restart = False
def skip(self):
@@ -172,6 +174,7 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
+ self.time_start = time.time()
devices.torch_gc()
@@ -181,6 +184,20 @@ class State:
devices.torch_gc()
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
+ if opts.show_progress_grid:
+ self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = sd_samplers.sample_to_image(self.current_latent)
+
+ self.current_image_sampling_step = self.sampling_step
+
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -238,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
@@ -304,7 +323,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
diff --git a/modules/ui.py b/modules/ui.py
index 45cd8c3f..784439ba 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -277,15 +277,7 @@ def check_progress_call(id_part):
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
- if shared.parallel_processing_allowed:
-
- if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if opts.show_progress_grid:
- shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
- else:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
- shared.state.current_image_sampling_step = shared.state.sampling_step
-
+ shared.state.set_current_image()
image = shared.state.current_image
if image is None: