diff options
Diffstat (limited to 'modules/ui.py')
-rw-r--r-- | modules/ui.py | 59 |
1 files changed, 51 insertions, 8 deletions
diff --git a/modules/ui.py b/modules/ui.py index 1df74070..8e7a3ee4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -9,6 +9,8 @@ import sys import time
import traceback
+import numpy as np
+import torch
from PIL import Image
import gradio as gr
@@ -119,6 +121,9 @@ def wrap_gradio_call(func): print("Arguments:", args, kwargs, file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ shared.state.job = ""
+ shared.state.job_count = 0
+
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
@@ -134,11 +139,9 @@ def wrap_gradio_call(func): def check_progress_call():
- if not opts.show_progressbar:
- return ""
if shared.state.job_count == 0:
- return ""
+ return "", gr_show(False), gr_show(False)
progress = 0
@@ -149,9 +152,29 @@ def check_progress_call(): progress = min(progress, 1)
- progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+ progressbar = ""
+ if opts.show_progressbar:
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+
+ image = gr_show(False)
+ preview_visibility = gr_show(False)
+
+ if opts.show_progress_every_n_steps > 0:
+ if (shared.state.sampling_step-1) % opts.show_progress_every_n_steps == 0 and shared.state.current_latent is not None:
+ x_sample = shared.sd_model.decode_first_stage(shared.state.current_latent[0:1].type(shared.sd_model.dtype))[0]
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ shared.state.current_image = Image.fromarray(x_sample)
- return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>"
+ image = shared.state.current_image
+
+ if image is None or progress >= 1:
+ image = gr.update(value=None)
+ else:
+ preview_visibility = gr_show(True)
+
+ return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
def roll_artist(prompt):
@@ -204,6 +227,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Column(variant='panel'):
with gr.Group():
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery')
@@ -251,8 +275,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): check_progress.click(
fn=check_progress_call,
+ show_progress=False,
inputs=[],
- outputs=[progressbar],
+ outputs=[progressbar, txt2img_preview, txt2img_preview],
)
@@ -337,13 +362,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Column(variant='panel'):
with gr.Group():
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery')
with gr.Group():
with gr.Row():
- interrupt = gr.Button('Interrupt')
save = gr.Button('Save')
+ img2img_send_to_img2img = gr.Button('Send to img2img')
+ img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras')
+ interrupt = gr.Button('Interrupt')
progressbar = gr.HTML(elem_id="progressbar")
@@ -426,8 +454,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): check_progress.click(
fn=check_progress_call,
+ show_progress=False,
inputs=[],
- outputs=[progressbar],
+ outputs=[progressbar, img2img_preview, img2img_preview],
)
interrupt.click(
@@ -463,6 +492,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): outputs=[init_img_with_mask],
)
+ img2img_send_to_img2img.click(
+ fn=lambda x: image_from_url_text(x),
+ _js="extract_image_from_gallery",
+ inputs=[img2img_gallery],
+ outputs=[init_img],
+ )
+
+ img2img_send_to_inpaint.click(
+ fn=lambda x: image_from_url_text(x),
+ _js="extract_image_from_gallery",
+ inputs=[img2img_gallery],
+ outputs=[init_img_with_mask],
+ )
+
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
|