aboutsummaryrefslogtreecommitdiffstats
path: root/modules/ui.py
diff options
context:
space:
mode:
author不会画画的中医不是好程序员 <yfszzx@gmail.com>2022-10-16 02:04:05 +0000
committerGitHub <noreply@github.com>2022-10-16 02:04:05 +0000
commitd41ac174e24e1e7cdcf7b42f2a03cbc6394eb5e5 (patch)
treeda39175c109598f17e89fafe57aac9b9597ff616 /modules/ui.py
parent6e4f5566b58e36aede83427df6c69eba8517af28 (diff)
parentbe1596ce30b1ead6998da0c62003003dcce5eb2c (diff)
downloadstable-diffusion-webui-gfx803-d41ac174e24e1e7cdcf7b42f2a03cbc6394eb5e5.tar.gz
stable-diffusion-webui-gfx803-d41ac174e24e1e7cdcf7b42f2a03cbc6394eb5e5.tar.bz2
stable-diffusion-webui-gfx803-d41ac174e24e1e7cdcf7b42f2a03cbc6394eb5e5.zip
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py107
1 files changed, 69 insertions, 38 deletions
diff --git a/modules/ui.py b/modules/ui.py
index 1bc919c7..b867d40f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -7,6 +7,7 @@ import mimetypes
import os
import random
import sys
+import tempfile
import time
import traceback
import platform
@@ -80,6 +81,8 @@ art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
+save_style_symbol = '\U0001f4be' # 💾
+apply_style_symbol = '\U0001f4cb' # 📋
def plaintext_to_html(text):
@@ -88,6 +91,14 @@ def plaintext_to_html(text):
def image_from_url_text(filedata):
+ if type(filedata) == dict and filedata["is_file"]:
+ filename = filedata["name"]
+ tempdir = os.path.normpath(tempfile.gettempdir())
+ normfn = os.path.normpath(filename)
+ assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
+
+ return Image.open(filename)
+
if type(filedata) == list:
if len(filedata) == 0:
return None
@@ -143,10 +154,7 @@ def save_files(js_data, images, do_make_zip, index):
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
for image_index, filedata in enumerate(images, start_index):
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
+ image = image_from_url_text(filedata)
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
@@ -176,6 +184,23 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+def save_pil_to_file(pil_image, dir=None):
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
@@ -304,7 +329,7 @@ def visit(x, func, path=""):
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
- return [gr_show(), gr_show()]
+ return [gr_show() for x in range(4)]
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
shared.prompt_styles.styles[style.name] = style
@@ -429,29 +454,38 @@ def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id="toprow"):
- with gr.Column(scale=4):
+ with gr.Column(scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
- with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- with gr.Column(scale=10, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
with gr.Row():
- with gr.Column(scale=8):
+ with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
- with gr.Column(scale=1, elem_id="roll_col"):
- sh = gr.Button(elem_id="sh", visible=True)
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ with gr.Column(scale=1, elem_id="roll_col"):
+ roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+
+ if cmd_opts.deepdanbooru:
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
with gr.Row():
@@ -471,20 +505,14 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row(scale=1):
- if is_img2img:
- interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- if cmd_opts.deepdanbooru:
- deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- else:
- deepbooru = None
- else:
- interrogate = None
- deepbooru = None
- prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
- save_style = gr.Button('Create style', elem_id="style_create")
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -588,7 +616,7 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
send_to_img2img = gr.Button('Send to img2img')
@@ -744,10 +772,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
@@ -803,7 +831,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
img2img_send_to_img2img = gr.Button('Send to img2img')
@@ -1166,6 +1194,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1244,6 +1273,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_embedding_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
training_width,
@@ -1268,6 +1298,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_hypernetwork_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
steps,