From dbca512154341bb13e1b15d207176f2d403aff30 Mon Sep 17 00:00:00 2001
From: siutin
Date: Fri, 3 Feb 2023 03:13:03 +0800
Subject: add an internal API for obtaining current task id

---
 modules/progress.py | 8 ++++++++
 1 file changed, 8 insertions(+)

(limited to 'modules/progress.py')

diff --git a/modules/progress.py b/modules/progress.py
index c69ecf3d..05032ac5 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -4,6 +4,7 @@ import time
 
 import gradio as gr
 from pydantic import BaseModel, Field
+from typing import List
 
 from modules.shared import opts
 
@@ -37,6 +38,9 @@ def add_task_to_queue(id_job):
     pending_tasks[id_job] = time.time()
 
 
+class CurrentTaskResponse(BaseModel):
+    current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
+
 class ProgressRequest(BaseModel):
     id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
     id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
@@ -56,6 +60,8 @@ class ProgressResponse(BaseModel):
 def setup_progress_api(app):
     return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
 
+def setup_current_task_api(app):
+    return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
 
 def progressapi(req: ProgressRequest):
     active = req.id_task == current_task
@@ -97,3 +103,5 @@ def progressapi(req: ProgressRequest):
 
     return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
 
+def current_task_api():
+    return CurrentTaskResponse(current_task=current_task)
\ No newline at end of file
-- cgit v1.2.3

From 9407f1731aa8c112ffc0efaa611a76f7fead3d0c Mon Sep 17 00:00:00 2001
From: siutin
Date: Mon, 6 Feb 2023 03:53:05 +0800
Subject: store the last generated result

---
 modules/call_queue.py |  1 +
 modules/progress.py   | 10 ++++++++++
 2 files changed, 11 insertions(+)

(limited to 'modules/progress.py')

diff --git a/modules/call_queue.py b/modules/call_queue.py
index 92097c15..30ac26bc 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -37,6 +37,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
                 res = func(*args, **kwargs)
             finally:
                 progress.finish_task(id_task)
+                progress.set_last_task_result(id_task, res)
 
             shared.state.end()
 
diff --git a/modules/progress.py b/modules/progress.py
index 05032ac5..27a336ad 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -37,6 +37,16 @@ def finish_task(id_task):
 def add_task_to_queue(id_job):
     pending_tasks[id_job] = time.time()
 
+last_task_id = None
+last_task_result = None
+
+def set_last_task_result(id_job, result):
+    global last_task_id
+    global last_task_result
+
+    last_task_id = id_job
+    last_task_result = result
+
 
 class CurrentTaskResponse(BaseModel):
     current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
-- cgit v1.2.3
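
The two patches above give the backend a polling surface: a GET route at /internal/current_task returning a CurrentTaskResponse whose single field, current_task, holds the id of the job currently being processed (or null when idle), plus a record of the most recent job's result. A minimal sketch of how an external script could poll that route; the host, port and the use of the standard-library urllib client are assumptions made for illustration, only the route path and the response field come from the patches:

    # Sketch: poll the webui's internal endpoint for the id of the currently running task.
    # The base URL is an assumption (the usual local default); adjust it for your setup.
    import json
    import time
    import urllib.request

    def poll_current_task(base_url="http://127.0.0.1:7860", interval=2.0):
        while True:
            with urllib.request.urlopen(f"{base_url}/internal/current_task") as resp:
                payload = json.loads(resp.read().decode("utf-8"))
            task_id = payload.get("current_task")  # e.g. "task(abc12def34ghi45)" or None
            print("current task:", task_id)
            if task_id is None:
                break
            time.sleep(interval)

    if __name__ == "__main__":
        poll_current_task()
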
From 4242e194e417ec5008d09ec6d756594ac65f77bd Mon Sep 17 00:00:00 2001
From: siutin
Date: Mon, 6 Feb 2023 03:55:31 +0800
Subject: add a button to restore the current progress

---
 javascript/progressbar.js |  4 ++--
 javascript/ui.js          | 36 ++++++++++++++++++++++++++++++++++--
 modules/progress.py       | 14 ++++++++++++++
 modules/ui.py             | 34 ++++++++++++++++++++++++++++++----
 4 files changed, 80 insertions(+), 8 deletions(-)

(limited to 'modules/progress.py')

diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 4ac9b8db..7ba14192 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -59,8 +59,8 @@ function setTitle(progress){
 }
 
 
-function randomId(){
-    return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+function randomId(prefix=null){
+    return "task(" + (prefix == null ? "" : prefix + "_") + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
 }
 
 // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
diff --git a/javascript/ui.js b/javascript/ui.js
index 4a440193..9fe884c0 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -163,7 +163,7 @@ function submit(){
     rememberGallerySelection('txt2img_gallery')
    showSubmitButtons('txt2img', false)
 
-    var id = randomId()
+    var id = randomId("txt2img")
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
         showSubmitButtons('txt2img', true)
 
@@ -180,7 +180,7 @@ function submit_img2img(){
     showSubmitButtons('img2img', false)
 
-    var id = randomId()
+    var id = randomId("img2img")
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
         showSubmitButtons('img2img', true)
     })
 
@@ -361,3 +361,35 @@ function selectCheckpoint(name){
     desiredCheckpointName = name;
     gradioApp().getElementById('change_checkpoint').click()
 }
+
+function restoreProgress (task_tag) {
+
+    if (task_tag) {
+        let successHandler = ({ current_task }) => {
+            if (current_task) {
+                let _task_tag = ["txt2img", "img2img"].find(t => current_task.startsWith(`task(${t}_`) && current_task.endsWith(")"))
+                if (!_task_tag) {
+                    console.warn(`task tag ${current_task} not implemented yet`)
+                    return
+                }
+                if (task_tag != _task_tag) return
+                showSubmitButtons(task_tag, false)
+                requestProgress(current_task, gradioApp().getElementById(`${task_tag}_gallery_container`), gradioApp().getElementById(`${task_tag}_gallery`), function(){
+                    showSubmitButtons(task_tag, true)
+                })
+            }
+        }
+
+        let errorHandler = e => window.alert(`invalid internal api respsonse. 
message: ${e}`) + + fetch("./internal/current_task") + .then(res => res.json()) + .then(successHandler) + .catch(errorHandler) + } + + var res = create_submit_args(arguments) + res[0] = 0 + return res + +} \ No newline at end of file diff --git a/modules/progress.py b/modules/progress.py index 27a336ad..36963c92 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -48,6 +48,20 @@ def set_last_task_result(id_job, result): last_task_result = result +def restore_progress_call(task_tag): + if current_task is None or not current_task[5:-1].startswith(task_tag): + + # image, generation_info, html_info, html_log + return tuple(list([None, None, None, None])) + + else: + + t_task = current_task + while t_task != last_task_id: + time.sleep(2.5) + return last_task_result + + class CurrentTaskResponse(BaseModel): current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") diff --git a/modules/ui.py b/modules/ui.py index 627fbe0b..0133ee12 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text import modules.extras +from modules.progress import restore_progress_call warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) @@ -293,6 +294,7 @@ def create_toprow(is_img2img): interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + restore_progress = gr.Button('Restore Progress', elem_id=f"{id_part}_restore_progress") skip.click( fn=lambda: shared.state.skip(), @@ -329,7 +331,7 @@ def create_toprow(is_img2img): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button + return prompt, prompt_styles, negative_prompt, submit, restore_progress, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -446,7 +448,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, restore_progress, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = 
gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) @@ -578,6 +580,18 @@ def create_ui(): res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) + restore_progress.click( + fn=lambda: restore_progress_call('txt2img'), + _js="() => restoreProgress('txt2img')", + inputs=[], + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ] + ) + txt_prompt_img.change( fn=modules.images.image_data, inputs=[ @@ -646,7 +660,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, restore_progress, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) @@ -898,6 +912,18 @@ def create_ui(): submit.click(**img2img_args) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) + restore_progress.click( + fn=lambda: restore_progress_call('img2img'), + _js="() => restoreProgress('img2img')", + inputs=[], + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ] + ) + img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, @@ -1491,7 +1517,7 @@ def create_ui(): gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - + def unload_sd_weights(): modules.sd_models.unload_model_weights() -- cgit v1.2.3 From e0b58527ff040f9c547ea45b5fcf1bfb7ab23cdd Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 6 Feb 2023 15:57:26 +0800 Subject: use condition to wait for result --- modules/call_queue.py | 2 +- modules/progress.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/call_queue.py b/modules/call_queue.py index 30ac26bc..9888109e 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -7,7 +7,7 @@ import time from modules import shared, progress queue_lock = threading.Lock() - +queue_lock_condition = threading.Condition(lock=queue_lock) def wrap_queued_call(func): def f(*args, **kwargs): diff --git a/modules/progress.py b/modules/progress.py index 36963c92..1947c0fd 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -6,6 +6,7 @@ import gradio as gr from pydantic import BaseModel, Field from typing import List +from modules import call_queue from modules.shared import opts import modules.shared as shared @@ -57,8 +58,9 @@ def restore_progress_call(task_tag): else: t_task = current_task - while t_task != last_task_id: - time.sleep(2.5) + with call_queue.queue_lock_condition: + call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) + return last_task_result -- cgit v1.2.3 From 70ab21e67d128b953fbf4a360e02ac783f40dd55 Mon Sep 17 00:00:00 
2001 From: siutin Date: Wed, 29 Mar 2023 00:17:19 +0800 Subject: keep randomId simpler --- javascript/progressbar.js | 4 ++-- javascript/ui.js | 10 ++-------- modules/progress.py | 4 ++-- 3 files changed, 6 insertions(+), 12 deletions(-) (limited to 'modules/progress.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 7ba14192..4ac9b8db 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -59,8 +59,8 @@ function setTitle(progress){ } -function randomId(prefix=null){ - return "task(" + (prefix == null ? "" : prefix + "_") + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" +function randomId(){ + return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" } // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and diff --git a/javascript/ui.js b/javascript/ui.js index 9fe884c0..c9df066d 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -163,7 +163,7 @@ function submit(){ rememberGallerySelection('txt2img_gallery') showSubmitButtons('txt2img', false) - var id = randomId("txt2img") + var id = randomId() requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){ showSubmitButtons('txt2img', true) @@ -180,7 +180,7 @@ function submit_img2img(){ rememberGallerySelection('img2img_gallery') showSubmitButtons('img2img', false) - var id = randomId("img2img") + var id = randomId() requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ showSubmitButtons('img2img', true) }) @@ -367,12 +367,6 @@ function restoreProgress (task_tag) { if (task_tag) { let successHandler = ({ current_task }) => { if (current_task) { - let _task_tag = ["txt2img", "img2img"].find(t => current_task.startsWith(`task(${t}_`) && current_task.endsWith(")")) - if (!_task_tag) { - console.warn(`task tag ${current_task} not implemented yet`) - return - } - if (task_tag != _task_tag) return showSubmitButtons(task_tag, false) requestProgress(current_task, gradioApp().getElementById(`${task_tag}_gallery_container`), gradioApp().getElementById(`${task_tag}_gallery`), function(){ showSubmitButtons(task_tag, true) diff --git a/modules/progress.py b/modules/progress.py index 1947c0fd..e99267f5 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -49,8 +49,8 @@ def set_last_task_result(id_job, result): last_task_result = result -def restore_progress_call(task_tag): - if current_task is None or not current_task[5:-1].startswith(task_tag): +def restore_progress_call(): + if current_task is None: # image, generation_info, html_info, html_log return tuple(list([None, None, None, None])) -- cgit v1.2.3 From 984970068c2bdc14cff266129ca25a26fbccbf2e Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 17 Apr 2023 01:06:28 +0800 Subject: multi users support --- modules/call_queue.py | 23 ++++++++++++-------- modules/progress.py | 60 ++++++++++++++++++++++++++++++++++++--------------- modules/ui.py | 4 ++-- 3 files changed, 59 insertions(+), 28 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/call_queue.py b/modules/call_queue.py index 9888109e..632afcdd 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,6 +4,7 @@ import threading import traceback import time +import gradio as gr 
from modules import shared, progress queue_lock = threading.Lock() @@ -20,41 +21,45 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): - def f(*args, **kwargs): + def f(request: gr.Request, *args, **kwargs): + user = request.username # if the first argument is a string that says "task(...)", it is treated as a job id if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": id_task = args[0] - progress.add_task_to_queue(id_task) + progress.add_task_to_queue(user, id_task) else: id_task = None with queue_lock: shared.state.begin() - progress.start_task(id_task) + progress.start_task(user, id_task) try: res = func(*args, **kwargs) finally: - progress.finish_task(id_task) - progress.set_last_task_result(id_task, res) + progress.finish_task(user, id_task) + progress.set_last_task_result(user, id_task, res) shared.state.end() return res - return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) + return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True) -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False): + def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() try: - res = list(func(*args, **kwargs)) + if add_request: + res = list(func(request, *args, **kwargs)) + else: + res = list(func(*args, **kwargs)) except Exception as e: # When printing out our debug argument list, do not print out more than a MB of text max_debug_str_len = 131072 # (1024*1024)/8 diff --git a/modules/progress.py b/modules/progress.py index e99267f5..13568701 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -4,7 +4,9 @@ import time import gradio as gr from pydantic import BaseModel, Field -from typing import List +from typing import Optional +from fastapi import Depends, Security +from fastapi.security import APIKeyCookie from modules import call_queue from modules.shared import opts @@ -12,57 +14,71 @@ from modules.shared import opts import modules.shared as shared +current_task_user = None current_task = None pending_tasks = {} finished_tasks = [] -def start_task(id_task): +def start_task(user, id_task): global current_task + global current_task_user + current_task_user = user current_task = id_task - pending_tasks.pop(id_task, None) + pending_tasks.pop((user, id_task), None) -def finish_task(id_task): +def finish_task(user, id_task): global current_task + global current_task_user if current_task == id_task: current_task = None - finished_tasks.append(id_task) + if current_task_user == user: + current_task_user = None + + finished_tasks.append((user, id_task)) if len(finished_tasks) > 16: finished_tasks.pop(0) -def add_task_to_queue(id_job): - pending_tasks[id_job] = time.time() +def add_task_to_queue(user, id_job): + pending_tasks[(user, id_job)] = time.time() last_task_id = None last_task_result = None +last_task_user = None + +def set_last_task_result(user, id_job, result): -def set_last_task_result(id_job, result): global last_task_id global last_task_result + global last_task_user last_task_id = id_job last_task_result = result + last_task_user = user -def restore_progress_call(): +def restore_progress_call(request: gr.Request): if current_task is None: # 
image, generation_info, html_info, html_log return tuple(list([None, None, None, None])) else: + user = request.username - t_task = current_task - with call_queue.queue_lock_condition: - call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) + if current_task_user == user: + t_task = current_task + with call_queue.queue_lock_condition: + call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) - return last_task_result + return last_task_result + return tuple(list([None, None, None, None])) class CurrentTaskResponse(BaseModel): current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") @@ -87,6 +103,19 @@ def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) def setup_current_task_api(app): + + def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))): + return None if token is None else app.tokens.get(token) + + def current_task_api(current_user: str = Depends(get_current_user)): + + if app.auth is None or current_task_user == current_user: + current_user_task = current_task + else: + current_user_task = None + + return CurrentTaskResponse(current_task=current_user_task) + return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse) def progressapi(req: ProgressRequest): @@ -127,7 +156,4 @@ def progressapi(req: ProgressRequest): else: live_preview = None - return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) - -def current_task_api(): - return CurrentTaskResponse(current_task=current_task) \ No newline at end of file + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index 8fc17ce7..a7b3cccb 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -582,7 +582,7 @@ def create_ui(): res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) restore_progress_button.click( - fn=lambda: restore_progress_call(), + fn=restore_progress_call, _js="() => restoreProgress('txt2img')", inputs=[], outputs=[ @@ -914,7 +914,7 @@ def create_ui(): res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) restore_progress_button.click( - fn=lambda: restore_progress_call(), + fn=restore_progress_call, _js="() => restoreProgress('img2img')", inputs=[], outputs=[ -- cgit v1.2.3 From bd9700405a0686769b58437fd87d9106d3cd1346 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 22:15:20 +0300 Subject: Revert "Merge pull request #7595 from siutin/feature/restore-progress" This reverts commit 80987c36f9bfa33092ac7c75624b25d839cb2a06, reversing changes made to 2e78e65a22bfa6b116ae18d12fdcb85ec8acd727. 
--- javascript/hints.js | 2 +- javascript/ui.js | 27 ------------------ modules/call_queue.py | 18 ++++-------- modules/progress.py | 76 ++++++--------------------------------------------- modules/ui.py | 35 +++--------------------- webui.py | 1 - 6 files changed, 18 insertions(+), 141 deletions(-) (limited to 'modules/progress.py') diff --git a/javascript/hints.js b/javascript/hints.js index 1a3130f8..23d85710 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -22,7 +22,7 @@ titles = { "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", "\u{1f3b4}": "Show/hide extra networks", - "\u{1F300}": "Restore progress", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index e50b44ee..0ba92ef8 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -362,32 +362,6 @@ function selectCheckpoint(name){ gradioApp().getElementById('change_checkpoint').click() } -function restoreProgress (task_tag) { - - if (task_tag) { - let successHandler = ({ current_task }) => { - if (current_task) { - showSubmitButtons(task_tag, false) - requestProgress(current_task, gradioApp().getElementById(`${task_tag}_gallery_container`), gradioApp().getElementById(`${task_tag}_gallery`), function(){ - showSubmitButtons(task_tag, true) - }) - } - } - - let errorHandler = e => window.alert(`invalid internal api respsonse. message: ${e}`) - - fetch("./internal/current_task") - .then(res => res.json()) - .then(successHandler) - .catch(errorHandler) - } - - var res = create_submit_args(arguments) - res[0] = 0 - return res - -} - function currentImg2imgSourceResolution(_, _, scaleBy){ var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img') return img ? 
[img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy] @@ -403,4 +377,3 @@ function updateImg2imgResizeToTextAfterChangingImage(){ return [] } - diff --git a/modules/call_queue.py b/modules/call_queue.py index 43f6ebe0..92097c15 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,16 +4,10 @@ import threading import traceback import time -import gradio as gr from modules import shared, progress queue_lock = threading.Lock() -queue_lock_condition = threading.Condition(lock=queue_lock) -def wrap_session_call(func): - def f(request: gr.Request, *args, **kwargs): - return func(request, *args, **kwargs) - return f def wrap_queued_call(func): def f(*args, **kwargs): @@ -26,31 +20,29 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): - def f(request: gr.Request, *args, **kwargs): - user = request.username + def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": id_task = args[0] - progress.add_task_to_queue(user, id_task) + progress.add_task_to_queue(id_task) else: id_task = None with queue_lock: shared.state.begin() - progress.start_task(user, id_task) + progress.start_task(id_task) try: res = func(*args, **kwargs) finally: - progress.finish_task(user, id_task) - progress.set_last_task_result(user, id_task, res) + progress.finish_task(id_task) shared.state.end() return res - return wrap_session_call(wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)) + return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) def wrap_gradio_call(func, extra_outputs=None, add_stats=False): diff --git a/modules/progress.py b/modules/progress.py index 13568701..c69ecf3d 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -4,84 +4,38 @@ import time import gradio as gr from pydantic import BaseModel, Field -from typing import Optional -from fastapi import Depends, Security -from fastapi.security import APIKeyCookie -from modules import call_queue from modules.shared import opts import modules.shared as shared -current_task_user = None current_task = None pending_tasks = {} finished_tasks = [] -def start_task(user, id_task): +def start_task(id_task): global current_task - global current_task_user - current_task_user = user current_task = id_task - pending_tasks.pop((user, id_task), None) + pending_tasks.pop(id_task, None) -def finish_task(user, id_task): +def finish_task(id_task): global current_task - global current_task_user if current_task == id_task: current_task = None - if current_task_user == user: - current_task_user = None - - finished_tasks.append((user, id_task)) + finished_tasks.append(id_task) if len(finished_tasks) > 16: finished_tasks.pop(0) -def add_task_to_queue(user, id_job): - pending_tasks[(user, id_job)] = time.time() - -last_task_id = None -last_task_result = None -last_task_user = None - -def set_last_task_result(user, id_job, result): - - global last_task_id - global last_task_result - global last_task_user - - last_task_id = id_job - last_task_result = result - last_task_user = user - - -def restore_progress_call(request: gr.Request): - if current_task is None: - - # image, generation_info, html_info, html_log - return tuple(list([None, None, None, None])) - - else: - user = request.username +def add_task_to_queue(id_job): + pending_tasks[id_job] = time.time() - if current_task_user == user: - t_task = current_task - with 
call_queue.queue_lock_condition: - call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) - - return last_task_result - - return tuple(list([None, None, None, None])) - -class CurrentTaskResponse(BaseModel): - current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") @@ -102,21 +56,6 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) -def setup_current_task_api(app): - - def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))): - return None if token is None else app.tokens.get(token) - - def current_task_api(current_user: str = Depends(get_current_user)): - - if app.auth is None or current_task_user == current_user: - current_user_task = current_task - else: - current_user_task = None - - return CurrentTaskResponse(current_task=current_user_task) - - return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse) def progressapi(req: ProgressRequest): active = req.id_task == current_task @@ -156,4 +95,5 @@ def progressapi(req: ProgressRequest): else: live_preview = None - return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) \ No newline at end of file + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) + diff --git a/modules/ui.py b/modules/ui.py index cc3c8d35..a32500d1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,7 +41,6 @@ from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text import modules.extras -from modules.progress import restore_progress_call warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) @@ -82,7 +81,6 @@ apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ -restore_progress_symbol = '\U0001F300' # 🌀 def plaintext_to_html(text): @@ -327,7 +325,6 @@ def create_toprow(is_img2img): extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") - restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress") token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -345,7 +342,7 @@ def create_toprow(is_img2img): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, 
f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -462,7 +459,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) @@ -594,18 +591,6 @@ def create_ui(): res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) - restore_progress_button.click( - fn=restore_progress_call, - _js="() => restoreProgress('txt2img')", - inputs=[], - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ] - ) - txt_prompt_img.change( fn=modules.images.image_data, inputs=[ @@ -674,7 +659,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) @@ -966,18 +951,6 @@ def create_ui(): submit.click(**img2img_args) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) - restore_progress_button.click( - fn=restore_progress_call, - _js="() => restoreProgress('img2img')", - inputs=[], - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ] - ) - img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, @@ -1574,7 +1547,7 @@ def create_ui(): gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - + def unload_sd_weights(): modules.sd_models.unload_model_weights() diff --git a/webui.py 
b/webui.py
index 7f8d4f84..357bf4c1 100644
--- a/webui.py
+++ b/webui.py
@@ -339,7 +339,6 @@ def webui():
     setup_middleware(app)
 
     modules.progress.setup_progress_api(app)
-    modules.progress.setup_current_task_api(app)
 
     if launch_api:
         create_api(app)
-- cgit v1.2.3

From c48ab36cb9e0120c6f1779bee9e875bee8f903f5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Apr 2023 22:16:54 +0300
Subject: alternate restore progress button implementation

---
 javascript/hints.js       |  1 +
 javascript/progressbar.js |  4 ++--
 javascript/ui.js          | 46 +++++++++++++++++++++++++++++++++++++++++++++-
 modules/call_queue.py     |  1 +
 modules/progress.py       | 18 ++++++++++++++++++
 modules/ui.py             | 36 ++++++++++++++++++++++++++++++++++----
 6 files changed, 99 insertions(+), 7 deletions(-)

(limited to 'modules/progress.py')

diff --git a/javascript/hints.js b/javascript/hints.js
index 23d85710..e7d17d36 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -22,6 +22,7 @@ titles = {
     "\u{1f4cb}": "Apply selected styles to current prompt",
     "\u{1f4d2}": "Paste available values into the field",
     "\u{1f3b4}": "Show/hide extra networks",
+    "\u{1f300}": "Restore progress",
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 8df3f569..23bbf298 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -66,7 +66,7 @@ function randomId(){
 // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
 // preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
 // calls onProgress every time there is a progress update
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
     var dateStart = new Date()
     var wasEverActive = false
     var parentProgressbar = progressbarContainer.parentNode
@@ -138,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
             return
         }
 
-        if(elapsedFromStart > 40 && !res.queued && !res.active){
+        if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
             removeProgressBar()
             return
         }
diff --git a/javascript/ui.js b/javascript/ui.js
index 0ba92ef8..e14b33f5 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -159,14 +159,24 @@ function showSubmitButtons(tabname, show){
     gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
 }
 
+function showRestoreProgressButton(tabname, show){
+    button = gradioApp().getElementById(tabname + "_restore_progress")
+    if(! button) return
+
+    button.style.display = show ? "flex" : "none"
+}
+
 function submit(){
     rememberGallerySelection('txt2img_gallery')
     showSubmitButtons('txt2img', false)
 
     var id = randomId()
+    localStorage.setItem("txt2img_task_id", id);
+
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
         showSubmitButtons('txt2img', true)
-
+        localStorage.removeItem("txt2img_task_id")
+        showRestoreProgressButton('txt2img', false)
     })
 
     var res = create_submit_args(arguments)
@@ -181,8 +191,12 @@ function submit_img2img(){
     showSubmitButtons('img2img', false)
 
     var id = randomId()
+    localStorage.setItem("img2img_task_id", id);
+
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
         showSubmitButtons('img2img', true)
+        localStorage.removeItem("img2img_task_id")
+        showRestoreProgressButton('img2img', false)
     })
 
     var res = create_submit_args(arguments)
@@ -193,6 +207,36 @@ function submit_img2img(){
     return res
 }
 
+function restoreProgressTxt2img(x){
+    id = localStorage.getItem("txt2img_task_id")
+
+    if(id) {
+        requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
+            showSubmitButtons('txt2img', true)
+        }, null, 0)
+    }
+
+    return [id]
+}
+function restoreProgressImg2img(x){
+    id = localStorage.getItem("img2img_task_id")
+
+    if(id) {
+        requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
+            showSubmitButtons('img2img', true)
+        }, null, 0)
+    }
+
+    return [id]
+}
+
+
+onUiLoaded(function () {
+    showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
+    showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
+});
+
+
 function modelmerger(){
     var id = randomId()
     requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 92097c15..1829f3a6 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
             try:
                 res = func(*args, **kwargs)
+                progress.record_results(id_task, res)
             finally:
                 progress.finish_task(id_task)
 
diff --git a/modules/progress.py b/modules/progress.py
index c69ecf3d..5655346b 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -13,6 +13,8 @@ import modules.shared as shared
 current_task = None
 pending_tasks = {}
 finished_tasks = []
+recorded_results = []
+recorded_results_limit = 2
 
 
 def start_task(id_task):
@@ -33,6 +35,12 @@ def finish_task(id_task):
         finished_tasks.pop(0)
 
 
+def record_results(id_task, res):
+    recorded_results.append((id_task, res))
+    if len(recorded_results) > recorded_results_limit:
+        recorded_results.pop(0)
+
+
 def add_task_to_queue(id_job):
     pending_tasks[id_job] = time.time()
 
@@ -97,3 +105,13 @@ def progressapi(req: ProgressRequest):
 
     return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
+def restore_progress(id_task):
+    while id_task == current_task or id_task in pending_tasks:
+        time.sleep(0.1)
+
+    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
+    if res is not None:
+        return res
+
+    return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
diff --git a/modules/ui.py b/modules/ui.py index a32500d1..9ff4bcd9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -81,6 +81,7 @@ apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ +restore_progress_symbol = '\U0001F300' # 🌀 def plaintext_to_html(text): @@ -325,6 +326,7 @@ def create_toprow(is_img2img): extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") + restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -342,7 +344,7 @@ def create_toprow(is_img2img): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button def setup_progressbar(*args, **kwargs): @@ -459,7 +461,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) @@ -591,6 +593,19 @@ def create_ui(): 
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) + restore_progress_button.click( + fn=progress.restore_progress, + _js="restoreProgressTxt2img", + inputs=[dummy_component], + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + txt_prompt_img.change( fn=modules.images.image_data, inputs=[ @@ -659,7 +674,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) @@ -951,6 +966,19 @@ def create_ui(): submit.click(**img2img_args) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) + restore_progress_button.click( + fn=progress.restore_progress, + _js="restoreProgressImg2img", + inputs=[dummy_component], + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, -- cgit v1.2.3 From 3ba6c3c83c0983a025c7bddc08bb7f49481b3cbb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 9 May 2023 22:17:58 +0300 Subject: Fix up string formatting/concatenation to f-strings where feasible --- modules/api/api.py | 22 ++++++------ modules/call_queue.py | 5 +-- modules/esrgan_model.py | 11 +++--- modules/esrgan_model_arch.py | 16 ++++----- modules/extra_networks_hypernet.py | 3 +- modules/generation_parameters_copypaste.py | 4 +-- modules/hashes.py | 4 +-- modules/images.py | 8 ++--- modules/interrogate.py | 4 +-- modules/models/diffusion/ddpm_edit.py | 4 +-- modules/models/diffusion/uni_pc/uni_pc.py | 4 +-- modules/ngrok.py | 4 +-- modules/paths.py | 2 +- modules/processing.py | 13 ++++++-- modules/progress.py | 3 +- modules/realesrgan_model.py | 8 ++--- modules/scripts.py | 5 +-- modules/sd_hijack_clip_old.py | 3 +- modules/sd_hijack_unet.py | 2 +- modules/sd_models.py | 4 +-- modules/sd_models_config.py | 2 +- modules/sd_samplers_kdiffusion.py | 2 +- modules/sd_vae.py | 2 +- modules/styles.py | 2 +- modules/textual_inversion/autocrop.py | 6 ++-- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 6 ++-- modules/textual_inversion/textual_inversion.py | 12 +++---- modules/ui.py | 46 +++++++++++++------------- modules/ui_extensions.py | 3 +- modules/ui_extra_networks.py | 4 ++- scripts/custom_code.py | 2 +- scripts/loopback.py | 2 +- scripts/xyz_grid.py | 2 +- 34 files changed, 121 insertions(+), 101 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/api/api.py b/modules/api/api.py index cdbdce32..9bb95dfd 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -570,20 +570,20 @@ class 
Api: filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used shared.state.end() - return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename)) + return CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: shared.state.end() - return TrainResponse(info = "create embedding error: {error}".format(error = e)) + return TrainResponse(info=f"create embedding error: {e}") def create_hypernetwork(self, args: dict): try: shared.state.begin() filename = create_hypernetwork(**args) # create empty embedding shared.state.end() - return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename)) + return CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: shared.state.end() - return TrainResponse(info = "create hypernetwork error: {error}".format(error = e)) + return TrainResponse(info=f"create hypernetwork error: {e}") def preprocess(self, args: dict): try: @@ -593,13 +593,13 @@ class Api: return PreprocessResponse(info = 'preprocess complete') except KeyError as e: shared.state.end() - return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e)) + return PreprocessResponse(info=f"preprocess error: invalid token: {e}") except AssertionError as e: shared.state.end() - return PreprocessResponse(info = "preprocess error: {error}".format(error = e)) + return PreprocessResponse(info=f"preprocess error: {e}") except FileNotFoundError as e: shared.state.end() - return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) + return PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: @@ -617,10 +617,10 @@ class Api: if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {msg}".format(msg = msg)) + return TrainResponse(info=f"train embedding error: {msg}") def train_hypernetwork(self, args: dict): try: @@ -641,10 +641,10 @@ class Api: if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) + return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") except AssertionError as msg: shared.state.end() - return TrainResponse(info="train embedding error: {error}".format(error=error)) + return TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: diff --git a/modules/call_queue.py b/modules/call_queue.py index 1829f3a6..447bb764 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): max_debug_str_len = 131072 # (1024*1024)/8 print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" + argStr = f"Arguments: {args} {kwargs}" print(argStr[:max_debug_str_len], file=sys.stderr) if len(argStr) > max_debug_str_len: print(f"(Argument list truncated at 
{max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) @@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): if extra_outputs_array is None: extra_outputs_array = [None, ''] - res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"] + error_message = f'{type(e).__name__}: {e}' + res = extra_outputs_array + [f"
{html.escape(error_message)}
"] shared.state.skipped = False shared.state.interrupted = False diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 9a9c38f1..f4369257 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -156,13 +156,16 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, - file_name="%s.pth" % self.model_name, - progress=True) + filename = load_file_from_url( + url=self.model_url, + model_dir=self.model_path, + file_name=f"{self.model_name}.pth", + progress=True, + ) else: filename = path if not os.path.exists(filename) or filename is None: - print("Unable to load %s from %s" % (self.model_path, filename)) + print(f"Unable to load {self.model_path} from {filename}") return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py index 1b52b0f5..6071fea7 100644 --- a/modules/esrgan_model_arch.py +++ b/modules/esrgan_model_arch.py @@ -38,7 +38,7 @@ class RRDBNet(nn.Module): elif upsample_mode == 'pixelshuffle': upsample_block = pixelshuffle_block else: - raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') if upscale == 3: upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) else: @@ -261,10 +261,10 @@ class Upsample(nn.Module): def extra_repr(self): if self.scale_factor is not None: - info = 'scale_factor=' + str(self.scale_factor) + info = f'scale_factor={self.scale_factor}' else: - info = 'size=' + str(self.size) - info += ', mode=' + self.mode + info = f'size={self.size}' + info += f', mode={self.mode}' return info @@ -350,7 +350,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): elif act_type == 'sigmoid': # [0, 1] range output layer = nn.Sigmoid() else: - raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) + raise NotImplementedError(f'activation layer [{act_type}] is not found') return layer @@ -372,7 +372,7 @@ def norm(norm_type, nc): elif norm_type == 'none': def norm_layer(x): return Identity() else: - raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) + raise NotImplementedError(f'normalization layer [{norm_type}] is not found') return layer @@ -388,7 +388,7 @@ def pad(pad_type, padding): elif pad_type == 'zero': layer = nn.ZeroPad2d(padding) else: - raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) + raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') return layer @@ -432,7 +432,7 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias= pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', spectral_norm=False): """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) + assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None padding = padding if pad_type == 'zero' else 0 diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index 33d100dd..04f27c9f 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -10,7 +10,8 @@ 
class ExtraNetworkHypernet(extra_networks.ExtraNetwork): additional = shared.opts.sd_hypernetwork if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: - p.all_prompts = [x + f"" for x in p.all_prompts] + hypernet_prompt_text = f"" + p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 78248ed2..fe8b18b2 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -269,8 +269,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v m = re_imagesize.match(v) if m is not None: - res[k+"-1"] = m.group(1) - res[k+"-2"] = m.group(2) + res[f"{k}-1"] = m.group(1) + res[f"{k}-2"] = m.group(2) else: res[k] = v diff --git a/modules/hashes.py b/modules/hashes.py index 83272a07..032120f4 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -13,7 +13,7 @@ cache_data = None def dump_cache(): - with filelock.FileLock(cache_filename+".lock"): + with filelock.FileLock(f"{cache_filename}.lock"): with open(cache_filename, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) @@ -22,7 +22,7 @@ def cache(subsection): global cache_data if cache_data is None: - with filelock.FileLock(cache_filename+".lock"): + with filelock.FileLock(f"{cache_filename}.lock"): if not os.path.isfile(cache_filename): cache_data = {} else: diff --git a/modules/images.py b/modules/images.py index 6ceb7c7c..a41965ab 100644 --- a/modules/images.py +++ b/modules/images.py @@ -467,7 +467,7 @@ def get_next_sequence_number(path, basename): """ result = -1 if basename != '': - basename = basename + "-" + basename = f"{basename}-" prefix_length = len(basename) for p in os.listdir(path): @@ -536,7 +536,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i add_number = opts.save_images_add_number or file_decoration == '' if file_decoration != "" and add_number: - file_decoration = "-" + file_decoration + file_decoration = f"-{file_decoration}" file_decoration = namegen.apply(file_decoration) + suffix @@ -566,7 +566,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i def _atomically_save_image(image_to_save, filename_without_extension, extension): # save image with .tmp extension to avoid race condition when another process detects new image in the directory - temp_file_path = filename_without_extension + ".tmp" + temp_file_path = f"{filename_without_extension}.tmp" image_format = Image.registered_extensions()[extension] if extension.lower() == '.png': @@ -626,7 +626,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if opts.save_txt and info is not None: txt_fullfn = f"{fullfn_without_extension}.txt" with open(txt_fullfn, "w", encoding="utf8") as file: - file.write(info + "\n") + file.write(f"{info}\n") else: txt_fullfn = None diff --git a/modules/interrogate.py b/modules/interrogate.py index e1665708..9f7d657f 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -28,7 +28,7 @@ def category_types(): def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") - tmpdir = content_dir + "_tmp" + tmpdir = 
f"{content_dir}_tmp" category_types = ["artists", "flavors", "mediums", "movements"] try: @@ -214,7 +214,7 @@ class InterrogateModels: if shared.opts.interrogate_return_ranks: res += f", ({match}:{score/100:.3f})" else: - res += ", " + match + res += f", {match}" except Exception: print("Error interrogating", file=sys.stderr) diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index f3d49c44..f880bc3c 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -223,7 +223,7 @@ class DDPM(pl.LightningModule): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) @@ -386,7 +386,7 @@ class DDPM(pl.LightningModule): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index eb5f4e76..11b330bc 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -94,7 +94,7 @@ class NoiseScheduleVP: """ if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'") self.schedule = schedule if schedule == 'discrete': @@ -469,7 +469,7 @@ class UniPC: t = torch.linspace(t_T**(1. / t_order), t_0**(1. 
/ t_order), N + 1).pow(t_order).to(device) return t else: - raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'") def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ diff --git a/modules/ngrok.py b/modules/ngrok.py index 1ad7989b..7a7b4b26 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -7,8 +7,8 @@ def connect(token, port, region): else: if ':' in token: # token = authtoken:username:password - account = token.split(':')[1] + ':' + token.split(':')[-1] - token = token.split(':')[0] + token, username, password = token.split(':', 2) + account = f"{username}:{password}" config = conf.PyngrokConfig( auth_token=token, region=region diff --git a/modules/paths.py b/modules/paths.py index 0e1e00e7..acf1894b 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths: sd_path = os.path.abspath(possible_sd_path) break -assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) +assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), diff --git a/modules/processing.py b/modules/processing.py index e786791a..1a76e552 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -500,7 +500,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" + negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() @@ -780,7 +780,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) + res = Processed( + p, + images_list=output_images, + seed=p.all_seeds[0], + info=infotext(), + comments="".join(f"\n\n{comment}" for comment in comments), + subseed=p.all_subseeds[0], + index_of_first_image=index_of_first_image, + infotexts=infotexts, + ) if p.scripts is not None: p.scripts.postprocess(p, res) diff --git a/modules/progress.py b/modules/progress.py index 5655346b..948e6f00 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -96,7 +96,8 @@ def progressapi(req: ProgressRequest): if image is not None: buffered = io.BytesIO() image.save(buffered, format="png") - live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") + base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') + live_preview = f"data:image/png;base64,{base64_image}" id_live_preview = shared.state.id_live_preview else: live_preview = None diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index d6079433..efd7fca5 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler): for scaler in scalers: 
if scaler.local_data_path.startswith("http"): filename = modelloader.friendly_name(scaler.local_data_path) - local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None) - if local: - scaler.local_data_path = local + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] if scaler.name in opts.realesrgan_enabled_models: self.scalers.append(scaler) @@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler): info = self.load_model(path) if not os.path.exists(info.local_data_path): - print("Unable to load RealESRGAN model: %s" % info.name) + print(f"Unable to load RealESRGAN model: {info.name}") return img upsampler = RealESRGANer( diff --git a/modules/scripts.py b/modules/scripts.py index 4d0bbd66..d945b89f 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -163,7 +163,8 @@ class Script: """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" need_tabname = self.show(True) == self.show(False) - tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" + tabkind = 'img2img' if self.is_img2img else 'txt2txt' + tabname = f"{tabkind}_" if need_tabname else "" title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) return f'script_{tabname}{title}_{item_id}' @@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp): this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others """ - comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])] + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] if getattr(comp, 'multiselect', False): comp.elem_classes.append('multiselect') diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index 6d9fbbe6..a3476e95 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) + self.hijack.comments.append(f"Used embeddings: {embedding_names}") self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 15858263..ca1daf45 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -18,7 +18,7 @@ class TorchHijackForUnet: if hasattr(torch, item): return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") def cat(self, tensors, *args, **kwargs): if len(tensors) == 2: diff --git a/modules/sd_models.py b/modules/sd_models.py index 59adc7cc..36f643e1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -47,7 +47,7 @@ class CheckpointInfo: self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) + self.sha256 
= hashes.sha256_from_cache(self.filename, f"checkpoint/{name}") self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' @@ -69,7 +69,7 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) + self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") if self.sha256 is None: return diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9398f528..7a79925a 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -111,7 +111,7 @@ def find_checkpoint_config_near_filename(info): if info is None: return None - config = os.path.splitext(info.filename)[0] + ".yaml" + config = f"{os.path.splitext(info.filename)[0]}.yaml" if os.path.exists(config): return config diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index eb98e599..0fc9f456 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -198,7 +198,7 @@ class TorchHijack: if hasattr(torch, item): return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") def randn_like(self, x): if self.sampler_noises: diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9b00f76e..521e485a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -89,7 +89,7 @@ def refresh_vae_list(): def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.splitext(checkpoint_file)[0] - for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]: if os.path.isfile(vae_location): return vae_location diff --git a/modules/styles.py b/modules/styles.py index 9ed85991..11642075 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -74,7 +74,7 @@ class StyleDatabase: def save_styles(self, path: str) -> None: # Always keep a backup file around if os.path.exists(path): - shutil.copy(path, path + ".bak") + shutil.copy(path, f"{path}.bak") fd = os.open(path, os.O_RDWR|os.O_CREAT) with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 68e1103c..ba1bdcd4 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -111,7 +111,7 @@ def focal_point(im, settings): if corner_centroid is not None: color = BLUE box = corner_centroid.bounding(max_size * corner_centroid.weight) - d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) + d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color) d.ellipse(box, outline=color) if len(corner_points) > 1: for f in corner_points: @@ -119,7 +119,7 @@ def focal_point(im, settings): if entropy_centroid is not None: color = "#ff0" box = entropy_centroid.bounding(max_size * entropy_centroid.weight) - d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) + d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) d.ellipse(box, outline=color) if len(entropy_points) > 1: for f in entropy_points: @@ -127,7 +127,7 @@ def focal_point(im, settings): if face_centroid is not 
None: color = RED box = face_centroid.bounding(max_size * face_centroid.weight) - d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) + d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color) d.ellipse(box, outline=color) if len(face_points) > 1: for f in face_points: diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index af9fbcf2..41610e03 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -72,7 +72,7 @@ class PersonalizedBase(Dataset): except Exception: continue - text_filename = os.path.splitext(path)[0] + ".txt" + text_filename = f"{os.path.splitext(path)[0]}.txt" filename = os.path.basename(path) if os.path.exists(text_filename): diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 4a29151d..da0bcb26 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -63,9 +63,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti image.save(os.path.join(params.dstdir, f"{basename}.png")) if params.preprocess_txt_action == 'prepend' and existing_caption: - caption = existing_caption + ' ' + caption + caption = f"{existing_caption} {caption}" elif params.preprocess_txt_action == 'append' and existing_caption: - caption = caption + ' ' + existing_caption + caption = f"{caption} {existing_caption}" elif params.preprocess_txt_action == 'copy' and existing_caption: caption = existing_caption @@ -174,7 +174,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.src = filename existing_caption = None - existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt" if os.path.exists(existing_caption_filename): with open(existing_caption_filename, 'r', encoding="utf8") as file: existing_caption = file.read() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 379df243..4368eb63 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -69,7 +69,7 @@ class Embedding: 'hash': self.checksum(), 'optimizer_state_dict': self.optimizer_state_dict, } - torch.save(optimizer_saved_dict, filename + '.optim') + torch.save(optimizer_saved_dict, f"{filename}.optim") def checksum(self): if self.cached_checksum is not None: @@ -437,8 +437,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) if shared.opts.save_optimizer_state: optimizer_state_dict = None - if os.path.exists(filename + '.optim'): - optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') + if os.path.exists(f"{filename}.optim"): + optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') if embedding.checksum() == optimizer_saved_dict.get('hash', None): optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) @@ -599,7 +599,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name', '???')) + title = f"<{data.get('name', '???')}>" try: vectorSize = list(data['string_to_param'].values())[0].shape[0] @@ -608,8 +608,8 @@ def 
train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.shorthash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + footer_mid = f'[{checkpoint.shorthash}]' + footer_right = f'{vectorSize}v {steps_done}s' captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) diff --git a/modules/ui.py b/modules/ui.py index 34b2aaff..d02f6e82 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -101,7 +101,7 @@ def visit(x, func, path=""): for c in x.children: visit(c, func, path) elif x.label is not None: - func(path + "/" + str(x.label), x) + func(f"{path}/{x.label}", x) def add_style(name: str, prompt: str, negative_prompt: str): @@ -166,7 +166,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di img = Image.open(image) filename = os.path.basename(image) left, _ = os.path.splitext(filename) - print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) + print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) return [gr.update(), None] @@ -182,29 +182,29 @@ def interrogate_deepbooru(image): def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row', variant="compact"): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed") seed.style(container=False) - random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed') - reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed') + random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed') + reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed') - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False) # Components to show/hide based on the 'Extra' checkbox seed_extras = [] - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1: seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed") subseed.style(container=False) - random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed") + reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed") + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, 
maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength") with FormRow(visible=False) as seed_extra_row_2: seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w") + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h") random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) @@ -765,7 +765,7 @@ def create_ui(): ) button.click( fn=lambda: None, - _js="switch_to_"+name.replace(" ", "_"), + _js=f"switch_to_{name.replace(' ', '_')}", inputs=[], outputs=[], ) @@ -1462,18 +1462,18 @@ def create_ui(): elif t == bool: comp = gr.Checkbox else: - raise Exception(f'bad options item type: {str(t)} for key {key}') + raise Exception(f'bad options item type: {t} for key {key}') - elem_id = "setting_"+key + elem_id = f"setting_{key}" if info.refresh is not None: if is_quicksettings: res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") else: with FormRow(): res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") else: res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) @@ -1545,7 +1545,7 @@ def create_ui(): current_tab.__exit__() gr.Group() - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) current_tab.__enter__() current_row = gr.Column(variant='compact') current_row.__enter__() @@ -1664,7 +1664,7 @@ def create_ui(): for interface, label, ifid in interfaces: if label in shared.opts.hidden_tabs: continue - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): interface.render() if os.path.exists(os.path.join(script_path, "notification.mp3")): @@ -1771,10 +1771,10 @@ def create_ui(): def loadsave(path, x): def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field + key = f"{path}/{field}" if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key + key = f"customscript/{obj.custom_script_source}/{key}" if getattr(obj, 'do_not_save_to_config', False): return diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 99ac8756..d9faf85a 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -61,7 +61,8 @@ def save_config_state(name): if not name: name = "Config" current_config_state["name"] = name - filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json") + timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S') + filename = 
os.path.join(config_states_dir, f"{timestamp}_{name}.json") print(f"Saving backup of webui/extension state to {filename}.") with open(filename, "w", encoding="utf-8") as f: json.dump(current_config_state, f) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 86c05a55..8c3dea56 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -69,7 +69,9 @@ class ExtraNetworksPage: pass def link_preview(self, filename): - return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) + mtime = os.path.getmtime(filename) + return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 4071d86d..f36a3675 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -77,7 +77,7 @@ return process_images(p) module.display = display indent = " " * indent_level - indented = code.replace('\n', '\n' + indent) + indented = code.replace('\n', f"\n{indent}") body = f"""def __webuitemp__(): {indent}{indented} __webuitemp__()""" diff --git a/scripts/loopback.py b/scripts/loopback.py index d3065fe6..ad6609be 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -84,7 +84,7 @@ class Script(scripts.Script): p.color_corrections = initial_color_corrections if append_interrogation != "None": - p.prompt = original_prompt + ", " if original_prompt != "" else "" + p.prompt = f"{original_prompt}, " if original_prompt else "" if append_interrogation == "CLIP": p.prompt += shared.interrogator.interrogate(p.init_images[0]) elif append_interrogation == "DeepBooru": diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 01d97791..a725d74a 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -439,7 +439,7 @@ class Script(scripts.Script): z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown]) def get_dropdown_update_from_params(axis,params): - val_key = axis + " Values" + val_key = f"{axis} Values" vals = params.get(val_key,"") valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] return gr.update(value = valslist) -- cgit v1.2.3 From b7e160a87d07b2fd1c12812c43786e141cc86bd5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 11 May 2023 08:14:45 +0300 Subject: change live preview format to jpeg to prevent unreasonably slow previews for large images, and add an option to let user select the format --- modules/progress.py | 4 ++-- modules/shared.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 948e6f00..289dd311 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -95,9 +95,9 @@ def progressapi(req: ProgressRequest): image = shared.state.current_image if image is not None: buffered = io.BytesIO() - image.save(buffered, format="png") + image.save(buffered, format=opts.live_previews_format) base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') - live_preview = f"data:image/png;base64,{base64_image}" + live_preview = f"data:image/{opts.live_previews_format};base64,{base64_image}" id_live_preview = shared.state.id_live_preview else: live_preview = None diff --git a/modules/shared.py 
b/modules/shared.py index ac67adc0..fc39161e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -420,6 +420,7 @@ options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('ui', "Live previews"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "live_previews_format": OptionInfo("jpeg", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), -- cgit v1.2.3 From da10de022f69e7847bcc64a7914d56246d852e20 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 20:52:30 +0300 Subject: Make live previews use JPEG only when the image is lorge enough --- modules/progress.py | 12 ++++++++++-- modules/shared.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 289dd311..c2e37834 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -95,9 +95,17 @@ def progressapi(req: ProgressRequest): image = shared.state.current_image if image is not None: buffered = io.BytesIO() - image.save(buffered, format=opts.live_previews_format) + format = opts.live_previews_format + save_kwargs = {} + if format == "auto": + if max(*image.size) > 256: + format = "jpeg" + else: + format = "png" + save_kwargs = {"optimize": True} + image.save(buffered, format=format, **save_kwargs) base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') - live_preview = f"data:image/{opts.live_previews_format};base64,{base64_image}" + live_preview = f"data:image/{format};base64,{base64_image}" id_live_preview = shared.state.id_live_preview else: live_preview = None diff --git a/modules/shared.py b/modules/shared.py index f387b5ae..22b45618 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -420,7 +420,7 @@ options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('ui', "Live previews"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), - "live_previews_format": OptionInfo("jpeg", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), + "live_previews_format": OptionInfo("auto", "Live preview file format", gr.Radio, {"choices": ["auto", "jpeg", "png", "webp"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), -- cgit v1.2.3 From a58ae0b7174d9903fa426def2eda842dbbfcb53c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 14 May 2023 11:15:15 +0300 Subject: remove auto live previews format option, fix slow PNG generation --- modules/progress.py | 19 +++++++++---------- modules/shared.py | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index c2e37834..269863c9 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -95,17 +95,16 @@ def progressapi(req: ProgressRequest): image = shared.state.current_image if image is not None: buffered = io.BytesIO() - format = opts.live_previews_format - save_kwargs = {} - if format == "auto": - if max(*image.size) > 256: - format = "jpeg" - else: - format = "png" - save_kwargs = {"optimize": True} - image.save(buffered, format=format, **save_kwargs) + + if opts.live_previews_image_format == "png": + # using optimize for large images takes an enormous amount of time + save_kwargs = {"optimize": max(*image.size) > 256} + else: + save_kwargs = {} + + image.save(buffered, format=opts.live_previews_image_format, **save_kwargs) base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') - live_preview = f"data:image/{format};base64,{base64_image}" + live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}" id_live_preview = shared.state.id_live_preview else: live_preview = None diff --git a/modules/shared.py b/modules/shared.py index a0577644..07f18b1b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -445,7 +445,7 @@ options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('ui', "Live previews"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), - "live_previews_format": OptionInfo("auto", "Live preview file format", gr.Radio, {"choices": ["auto", "jpeg", "png", "webp"]}), + "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"), "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"), -- cgit v1.2.3 From 9fd6c1e3430f5947add23e2e94ac816c2546481c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 17 May 2023 20:22:38 +0300 Subject: move some settings to the new Optimization page add slider for token merging for img2img rework StableDiffusionProcessing to have the token_merging_ratio field fix a bug with applying png optimizations for live previews when they shouldn't be applied --- modules/processing.py | 52 +++++++++++++++++++++++++-------------------------- modules/progress.py | 6 
+++++- modules/sd_models.py | 36 +++++++++++++++++++---------------- modules/shared.py | 8 ++++++-- 4 files changed, 56 insertions(+), 46 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/processing.py b/modules/processing.py index cd63b9a6..2b8dd361 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -29,12 +29,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType -import tomesd - -# add a logger for the processing module -logger = logging.getLogger(__name__) -# manually set output level here since there is no option to do so yet through launch options -# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s') # some of those options should not be changed at all because they would break the model, so I removed them from options. @@ -156,6 +150,8 @@ class StableDiffusionProcessing: self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False self.disable_extra_networks = False + self.token_merging_ratio = 0 + self.token_merging_ratio_hr = 0 if not seed_enable_extras: self.subseed = -1 @@ -171,6 +167,7 @@ class StableDiffusionProcessing: self.all_subseeds = None self.iteration = 0 self.is_hr_pass = False + self.sampler = None @property @@ -280,6 +277,12 @@ class StableDiffusionProcessing: def close(self): self.sampler = None + def get_token_merging_ratio(self, for_hr=False): + if for_hr: + return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio + + return self.token_merging_ratio or opts.token_merging_ratio + class Processed: def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""): @@ -309,6 +312,8 @@ class Processed: self.styles = p.styles self.job_timestamp = state.job_timestamp self.clip_skip = opts.CLIP_stop_at_last_layers + self.token_merging_ratio = p.token_merging_ratio + self.token_merging_ratio_hr = p.token_merging_ratio_hr self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -367,6 +372,9 @@ class Processed: def infotext(self, p: StableDiffusionProcessing, index): return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) + def get_token_merging_ratio(self, for_hr=False): + return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio + # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 def slerp(val, low, high): @@ -480,6 +488,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) enable_hr = getattr(p, 'enable_hr', False) + token_merging_ratio = p.get_token_merging_ratio() + token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True) uses_ensd = opts.eta_noise_seed_delta != 0 if uses_ensd: @@ -502,8 +512,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": opts.eta_noise_seed_delta if 
uses_ensd else None, - "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio, - "Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr, + "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio, + "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), "RNG": opts.randn_source if opts.randn_source != "GPU" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, @@ -536,17 +546,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_vae': sd_vae.reload_vae_weights() - if opts.token_merging_ratio > 0: - sd_models.apply_token_merging(sd_model=p.sd_model, hr=False) - logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'") + sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio()) res = process_images_inner(p) finally: - # undo model optimizations made by tomesd - if opts.token_merging_ratio > 0: - tomesd.remove_patch(p.sd_model) - logger.debug('Token merging model optimizations removed') + sd_models.apply_token_merging(p.sd_model, 0) # restore opts to original state if p.override_settings_restore_afterwards: @@ -996,21 +1001,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - # apply token merging optimizations from tomesd for high-res pass - if opts.token_merging_ratio_hr > 0: - # in case the user has used separate merge ratios - if opts.token_merging_ratio > 0: - tomesd.remove_patch(self.sd_model) - logger.debug('Adjusting token merging ratio for high-res pass') - - sd_models.apply_token_merging(sd_model=self.sd_model, hr=True) - logger.debug(f"Applied token merging for high-res pass. 
Ratio: '{opts.token_merging_ratio_hr}'") + sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) - if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0: - tomesd.remove_patch(self.sd_model) - logger.debug('Removed token merging optimizations from model') + sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.is_hr_pass = False @@ -1173,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): devices.torch_gc() return samples + + def get_token_merging_ratio(self, for_hr=False): + return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio diff --git a/modules/progress.py b/modules/progress.py index 269863c9..f405f07f 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -98,7 +98,11 @@ def progressapi(req: ProgressRequest): if opts.live_previews_image_format == "png": # using optimize for large images takes an enormous amount of time - save_kwargs = {"optimize": max(*image.size) > 256} + if max(*image.size) <= 256: + save_kwargs = {"optimize": True} + else: + save_kwargs = {"optimize": False, "compress_level": 1} + else: save_kwargs = {} diff --git a/modules/sd_models.py b/modules/sd_models.py index e612be10..4bd8783e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -583,23 +583,27 @@ def unload_model_weights(sd_model=None, info=None): return sd_model -def apply_token_merging(sd_model, hr: bool): +def apply_token_merging(sd_model, token_merging_ratio): """ Applies speed and memory optimizations from tomesd. 
- - Args: - hr (bool): True if called in the context of a high-res pass """ - ratio = shared.opts.token_merging_ratio - if hr: - ratio = shared.opts.token_merging_ratio_hr - - tomesd.apply_patch( - sd_model, - ratio=ratio, - use_rand=False, # can cause issues with some samplers - merge_attn=True, - merge_crossattn=False, - merge_mlp=False - ) + current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0) + + if current_token_merging_ratio == token_merging_ratio: + return + + if current_token_merging_ratio > 0: + tomesd.remove_patch(sd_model) + + if token_merging_ratio > 0: + tomesd.apply_patch( + sd_model, + ratio=token_merging_ratio, + use_rand=False, # can cause issues with some samplers + merge_attn=True, + merge_crossattn=False, + merge_mlp=False + ) + + sd_model.applied_token_merged_ratio = token_merging_ratio diff --git a/modules/shared.py b/modules/shared.py index 47bc6d0e..76af8b9c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -413,8 +413,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"), +})) + +options_templates.update(options_section(('optimizations', "Optimizations"), { + "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), - "token_merging_ratio_hr": OptionInfo(0.0, "Togen merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}), + "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), + "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -498,7 +503,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", 
gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3