diff options
Diffstat (limited to 'modules/ui.py')
-rw-r--r-- | modules/ui.py | 201 |
1 files changed, 132 insertions, 69 deletions
diff --git a/modules/ui.py b/modules/ui.py index b6be713b..028eb4e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,47 +5,56 @@ import json import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
+
from modules.shared import opts, cmd_opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
+txt2img_paste_fields = []
+img2img_paste_fields = []
if not cmd_opts.share and not cmd_opts.listen:
@@ -268,8 +277,13 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ if eta_relative > 3600:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ elif eta_relative > 60:
+ return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+ else:
+ return label + time.strftime('%Ss', time.gmtime(eta_relative))
else:
return ""
@@ -285,7 +299,7 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+ time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
if time_left != "":
shared.state.time_left_force_display = True
@@ -293,7 +307,7 @@ def check_progress_call(id_part): progressbar = ""
if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -302,7 +316,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+ if opts.show_progress_grid:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
@@ -477,14 +494,14 @@ def create_toprow(is_img2img): with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
@@ -561,6 +578,9 @@ def apply_setting(key, value): if value is None:
return gr.update()
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
@@ -587,27 +607,29 @@ def apply_setting(key, value): return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- for k, v in args.items():
- setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
- return gr.update(**(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -705,6 +727,7 @@ def create_ui(wrap_gradio_gpu_call): firstphase_width,
firstphase_height,
] + custom_inputs,
+
outputs=[
txt2img_gallery,
generation_info,
@@ -761,6 +784,7 @@ def create_ui(wrap_gradio_gpu_call): ]
)
+ global txt2img_paste_fields
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -781,6 +805,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
@@ -848,8 +873,8 @@ def create_ui(wrap_gradio_gpu_call): sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1030,6 +1055,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[prompt, negative_prompt, style1, style2],
)
+ global img2img_paste_fields
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1046,6 +1072,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1077,9 +1104,9 @@ def create_ui(wrap_gradio_gpu_call): upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
-
+
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -1166,15 +1193,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image],
outputs=[html, generation_info, html2],
)
- #images history
- images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
- }
-
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
-
+
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1206,6 +1225,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
@@ -1217,6 +1237,11 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
with gr.Row():
with gr.Column(scale=3):
@@ -1230,14 +1255,19 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
+ process_split = gr.Checkbox(label='Split oversized images')
process_entropy_focus = gr.Checkbox(label='Create auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1245,15 +1275,24 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
with gr.Tab(label="Train"):
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1287,6 +1326,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1300,6 +1340,11 @@ def create_ui(wrap_gradio_gpu_call): inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,
@@ -1316,11 +1361,14 @@ def create_ui(wrap_gradio_gpu_call): process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
process_caption_deepbooru,
- process_entropy_focus
+ process_split_threshold,
+ process_overlap_ratio,
+ process_entropy_focus,
],
outputs=[
ti_output,
@@ -1333,7 +1381,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1358,7 +1406,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1422,6 +1470,9 @@ def create_ui(wrap_gradio_gpu_call): components = []
component_dict = {}
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
def open_folder(f):
if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
@@ -1447,6 +1498,8 @@ Requested path was: {f} def run_settings(*args):
changed = 0
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1476,13 +1529,15 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
- oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
@@ -1525,9 +1580,10 @@ Requested path was: {f} previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
+ elem_id, text = item.section
+ gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
- if k in quicksettings_names:
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1560,7 +1616,7 @@ Requested path was: {f} def reload_scripts():
modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
@@ -1588,19 +1644,26 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
@@ -1823,9 +1886,10 @@ def load_javascript(raw_response): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
@@ -1843,6 +1907,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
|