From c5334fc56b3d44976425da2e6d0a303ae96836a1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:35:01 +0300 Subject: fix javascript duplication bug after pressing the restart UI button --- modules/ui.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 34c31ef1..67cf1d6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1752,7 +1752,7 @@ def create_ui(wrap_gradio_gpu_call): return demo -def load_javascript(raw_response): +def reload_javascript(): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' @@ -1768,7 +1768,7 @@ def load_javascript(raw_response): javascript += f"\n" def template_response(*args, **kwargs): - res = raw_response(*args, **kwargs) + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace( b'', f'{javascript}'.encode("utf8")) res.init_headers() @@ -1777,4 +1777,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse -- cgit v1.2.3 From 1610b3258458025025e9c4faae57d290e4519745 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:38:10 +0300 Subject: add callback for creating a tab in train UI --- modules/script_callbacks.py | 27 +++++++++++++++++++++++++-- modules/ui.py | 4 ++++ 2 files changed, 29 insertions(+), 2 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 74dfb880..f19e164c 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -7,6 +7,7 @@ from typing import Optional from fastapi import FastAPI from gradio import Blocks + def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -45,15 +46,21 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" +class UiTrainTabParams: + def __init__(self, txt2img_preview_params): + self.txt2img_preview_params = txt2img_preview_params + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], callbacks_ui_tabs=[], + callbacks_ui_train_tabs=[], callbacks_ui_settings=[], callbacks_before_image_saved=[], callbacks_image_saved=[], - callbacks_cfg_denoiser=[] + callbacks_cfg_denoiser=[], ) @@ -61,6 +68,7 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -79,7 +87,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - + for c in callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] @@ -89,6 +97,14 @@ def ui_tabs_callback(): return res +def ui_train_tabs_callback(params: UiTrainTabParams): + for c in callback_map['callbacks_ui_train_tabs']: + try: + c.callback(params) + except Exception: + report_exception(c, 'callbacks_ui_train_tabs') + + def ui_settings_callback(): for c in callback_map['callbacks_ui_settings']: try: @@ -169,6 +185,13 @@ def on_ui_tabs(callback): add_callback(callback_map['callbacks_ui_tabs'], callback) +def on_ui_train_tabs(callback): + """register a function to be called when the UI is creating new tabs for the train tab. + Create your new tabs with gr.Tab. + """ + add_callback(callback_map['callbacks_ui_train_tabs'], callback) + + def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ diff --git a/modules/ui.py b/modules/ui.py index 67cf1d6a..7ea1177f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') train_embedding = gr.Button(value="Train Embedding", variant='primary') + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) -- cgit v1.2.3 From 81f2575df91a50e4aa9ca816e02e3f77342eedc8 Mon Sep 17 00:00:00 2001 From: Liam Date: Wed, 9 Nov 2022 15:24:31 -0500 Subject: updating the displayed generation info when user clicks images in the gallery. feature request 4415 --- javascript/ui.js | 10 +++++++++- modules/ui.py | 20 ++++++++++++++++++++ scripts/prompt_matrix.py | 2 ++ scripts/prompts_from_file.py | 6 +++++- 4 files changed, 36 insertions(+), 2 deletions(-) (limited to 'modules/ui.py') diff --git a/javascript/ui.js b/javascript/ui.js index 95cfd106..443d1642 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -179,9 +179,17 @@ onUiUpdate(function(){ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); } + if (!txt2img_gallery) { + txt2img_gallery = gradioApp().querySelector('#txt2img_gallery') + txt2img_gallery?.addEventListener('click', () => gradioApp().getElementById("txt2img_generation_info_button").click()); + } + if (!img2img_gallery) { + img2img_gallery = gradioApp().querySelector('#img2img_gallery') + img2img_gallery?.addEventListener('click', () => gradioApp().getElementById("img2img_generation_info_button").click()); + } }) -let txt2img_textarea, img2img_textarea = undefined; +let txt2img_textarea, img2img_textarea, txt2img_gallery, img2img_gallery = undefined; let wait_time = 800 let token_timeout; diff --git a/modules/ui.py b/modules/ui.py index 7ea1177f..756499d1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -566,6 +566,17 @@ def apply_setting(key, value): return value +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -638,6 +649,15 @@ Requested path was: {f} with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) save.click( fn=wrap_gradio_call(save_files), diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index e49c9b20..4d1e152d 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -80,6 +80,8 @@ class Script(scripts.Script): grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) processed.images.insert(0, grid) + processed.index_of_first_image = 1 + processed.infotexts.insert(0, processed.infotexts[0]) if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 3388bc77..32fe6bdb 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -145,6 +145,8 @@ class Script(scripts.Script): state.job_count = job_count images = [] + all_prompts = [] + infotexts = [] for n, args in enumerate(jobs): state.job = f"{state.job_no + 1} out of {state.job_count}" @@ -157,5 +159,7 @@ class Script(scripts.Script): if checkbox_iterate: p.seed = p.seed + (p.batch_size * p.n_iter) + all_prompts += proc.all_prompts + infotexts += proc.infotexts - return Processed(p, images, p.seed, "") + return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts) -- cgit v1.2.3 From b98740129c435f04a060369bd071fc4bafe021f5 Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 10 Nov 2022 13:07:41 -0500 Subject: added event listener for the image gallery modal; moved js to separate file --- javascript/generationParams.js | 33 +++++++++++++++++++++++++++++++++ javascript/ui.js | 10 +--------- modules/ui.py | 2 ++ 3 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 javascript/generationParams.js (limited to 'modules/ui.py') diff --git a/javascript/generationParams.js b/javascript/generationParams.js new file mode 100644 index 00000000..95f05093 --- /dev/null +++ b/javascript/generationParams.js @@ -0,0 +1,33 @@ +// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes + +let txt2img_gallery, img2img_gallery, modal = undefined; +onUiUpdate(function(){ + if (!txt2img_gallery) { + txt2img_gallery = attachGalleryListeners("txt2img") + } + if (!img2img_gallery) { + img2img_gallery = attachGalleryListeners("img2img") + } + if (!modal) { + modal = gradioApp().getElementById('lightboxModal') + modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] }); + } +}); + +let modalObserver = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText + if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img') + gradioApp().getElementById(selectedTab+"_generation_info_button").click() + }); +}); + +function attachGalleryListeners(tab_name) { + gallery = gradioApp().querySelector('#'+tab_name+'_gallery') + gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); + gallery?.addEventListener('keydown', (e) => { + if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow + gradioApp().getElementById(tab_name+"_generation_info_button").click() + }); + return gallery; +} diff --git a/javascript/ui.js b/javascript/ui.js index 443d1642..95cfd106 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -179,17 +179,9 @@ onUiUpdate(function(){ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); } - if (!txt2img_gallery) { - txt2img_gallery = gradioApp().querySelector('#txt2img_gallery') - txt2img_gallery?.addEventListener('click', () => gradioApp().getElementById("txt2img_generation_info_button").click()); - } - if (!img2img_gallery) { - img2img_gallery = gradioApp().querySelector('#img2img_gallery') - img2img_gallery?.addEventListener('click', () => gradioApp().getElementById("img2img_generation_info_button").click()); - } }) -let txt2img_textarea, img2img_textarea, txt2img_gallery, img2img_gallery = undefined; +let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; diff --git a/modules/ui.py b/modules/ui.py index 756499d1..5dce7f3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -570,6 +570,8 @@ def update_generation_info(args): generation_info, html_info, img_index = args try: generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info return plaintext_to_html(generation_info["infotexts"][img_index]) except Exception: pass -- cgit v1.2.3 From 72b52fbb77360f848cfa296b0c79d2bc0a1060f2 Mon Sep 17 00:00:00 2001 From: dtlnor Date: Wed, 16 Nov 2022 13:08:03 +0900 Subject: add css override --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 5dce7f3b..5e2a992f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -69,8 +69,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None css_hide_progressbar = """ .wrap .m-12 svg { display:none!important; } .wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } .progress-bar { display:none!important; } .meta-text { display:none!important; } +.meta-text-center { display:none!important; } """ # Using constants for these since the variation selector isn't visible. -- cgit v1.2.3 From c8c40c8a643f2d20e3475e4d9ae7aae6d36c7e85 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 17 Nov 2022 18:03:57 -0800 Subject: Add interrupt button to preprocessing --- modules/textual_inversion/ui.py | 2 +- modules/ui.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index d679e6f4..35c4feef 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): def preprocess(*args): modules.textual_inversion.preprocess.preprocess(*args) - return "Preprocessing finished.", "" + return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" def train_embedding(*args): diff --git a/modules/ui.py b/modules/ui.py index 5dce7f3b..88e3c827 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1249,7 +1249,9 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - run_preprocess = gr.Button(value="Preprocess", variant='primary') + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt") + run_preprocess = gr.Button(value="Preprocess", variant='primary') process_split.change( fn=lambda show: gr_show(show), @@ -1422,6 +1424,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[], ) + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default -- cgit v1.2.3 From cdc8020d13c5eef099c609b0a911ccf3568afc0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:01:51 +0300 Subject: change StableDiffusionProcessing to internally use sampler name instead of sampler index --- modules/api/api.py | 26 ++++++++--------------- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/images.py | 2 +- modules/img2img.py | 4 ++-- modules/processing.py | 29 +++++++++++--------------- modules/sd_samplers.py | 13 +++++++++--- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/txt2img.py | 3 ++- modules/ui.py | 2 +- scripts/img2imgalt.py | 4 ++-- scripts/xy_grid.py | 12 +++++------ 11 files changed, 49 insertions(+), 54 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..0eccccbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,9 +6,9 @@ from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -25,8 +25,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -82,14 +86,9 @@ class Api: self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -109,12 +108,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -123,10 +116,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -272,7 +264,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712..fbb87dd1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, processing, sd_models, shared, sd_samplers from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_index = preview_sampler_index + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index ae705cbd..26d5b7a9 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width, 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), - 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), + 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime