")
--
cgit v1.2.3
From bdbe09827b39be63c9c0b3636132ca58da38ebf6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 31 Dec 2022 22:49:09 +0300
Subject: changed embedding accepted shape detection to use existing code and
support the new alt-diffusion model, and reformatted messages a bit #6149
---
modules/textual_inversion/textual_inversion.py | 30 ++++++--------------------
1 file changed, 6 insertions(+), 24 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 103ace60..66f40367 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -80,23 +80,8 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
- expected_shape = -1 # initialize with unknown
- idx = torch.tensor(0).to(shared.device)
- if expected_shape == -1:
- try: # matches sd15 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- try: # matches sd20 signature
- first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
- expected_shape = first_embedding.shape[0]
- except:
- pass
- if expected_shape == -1:
- print('Could not determine expected embeddings shape from model')
- return expected_shape
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
@@ -112,8 +97,6 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = []
-
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@@ -150,11 +133,10 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
- if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings.append(name)
- # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir):
try:
@@ -169,9 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr)
continue
- print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
- if (len(self.skipped_embeddings) > 0):
- print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
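For reference, a minimal standalone sketch of the idea behind the new get_expected_shape, with a plain nn.Embedding standing in for the model's token embedder (768 matches SD 1.x; SD 2.x would be 1024; the token id is illustrative):

import torch
import torch.nn as nn

token_embedding = nn.Embedding(49408, 768)  # stand-in for the CLIP text embedder

# Encoding any short init text (",") yields vectors whose last dimension is the
# per-token width the model expects; embedding files with a different width
# were trained for another model family and are skipped rather than registered.
vec = token_embedding(torch.tensor([[267]])).squeeze(0)
expected_shape = vec.shape[1]  # 768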
From f4535f6e4f001314bd155bc6e1b6908e02792b9a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 31 Dec 2022 23:40:55 +0300
Subject: make it so that memory/embeddings info is displayed in a separate UI
element from generation parameters, and is preserved when you change the
displayed infotext by clicking on gallery images
---
modules/img2img.py | 2 +-
modules/processing.py | 5 +++--
modules/txt2img.py | 2 +-
modules/ui.py | 31 +++++++++++++++++--------------
4 files changed, 22 insertions(+), 18 deletions(-)
diff --git a/modules/img2img.py b/modules/img2img.py
index 81da4b13..ca58b5d8 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/processing.py b/modules/processing.py
index 0a9a8f95..42dc19ea 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -239,7 +239,7 @@ class StableDiffusionProcessing():
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -247,6 +247,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
+ self.comments = comments
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
@@ -646,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index c8f81176..7f61e19a 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
index 397dd804..f550ad00 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -159,7 +159,7 @@ def save_files(js_data, images, do_make_zip, index):
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)
- return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
@@ -593,6 +593,8 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML()
+ html_log = gr.HTML()
+
generation_info = gr.Textbox(visible=False)
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
@@ -615,16 +617,16 @@ Requested path was: {f}
],
outputs=[
download_files,
- html_info,
- html_info,
- html_info,
+ html_log,
]
)
else:
html_info_x = gr.HTML()
html_info = gr.HTML()
+ html_log = gr.HTML()
+
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
def create_ui():
@@ -686,14 +688,14 @@ def create_ui():
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
- txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
txt2img_prompt,
@@ -720,7 +722,8 @@ def create_ui():
outputs=[
txt2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -799,7 +802,6 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
-
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -883,7 +885,7 @@ def create_ui():
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
- img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
@@ -915,7 +917,7 @@ def create_ui():
)
img2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.img2img.img2img),
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -954,7 +956,8 @@ def create_ui():
outputs=[
img2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -1078,10 +1081,10 @@ def create_ui():
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
- result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
+ result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras),
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
_js="get_extras_tab_index",
inputs=[
dummy_component,
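On the wrapper side, the new extra_outputs argument is what keeps the four-output contract intact on failure. A simplified stand-in for the wrapper's error path (the exact behavior of modules.call_queue is an assumption here): the placeholders fill the leading output slots and the error HTML lands in the trailing html_log slot.

def wrap_with_extra_outputs(func, extra_outputs):
    # Simplified sketch, not the real wrapper: on success pass results through,
    # on failure substitute placeholders plus an error message for the log slot.
    def wrapped(*args):
        try:
            return func(*args)
        except Exception as e:
            return tuple(extra_outputs) + (f"<div class='error'>{e}</div>",)
    return wrapped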
From 360feed9b55fb03060c236773867b08b4265645d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 00:38:58 +0300
Subject: HAPPY NEW YEAR
make save to zip into its own button instead of a checkbox
---
modules/ui.py | 30 ++++++++++++++++++++++--------
style.css | 6 ++++++
2 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index f550ad00..279b5110 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -570,13 +570,14 @@ Requested path was: {f}
generation_info = None
with gr.Column():
- with gr.Row():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
+
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
open_folder_button.click(
fn=lambda: open_folder(opts.outdir_samples or outdir),
@@ -585,9 +586,6 @@ Requested path was: {f}
)
if tabname != "extras":
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
@@ -608,11 +606,11 @@ Requested path was: {f}
save.click(
fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[
generation_info,
result_gallery,
- do_make_zip,
+ html_info,
html_info,
],
outputs=[
@@ -620,6 +618,22 @@ Requested path was: {f}
html_log,
]
)
+
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
else:
html_info_x = gr.HTML()
html_info = gr.HTML()
diff --git a/style.css b/style.css
index 3ad78006..f245f674 100644
--- a/style.css
+++ b/style.css
@@ -568,6 +568,12 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
font-size: 95%;
}
+#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
+ min-width: auto;
+ padding-left: 0.5em;
+ padding-right: 0.5em;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
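The _js shims above are what let one Python handler serve both buttons: the JS function receives the declared input values and returns the values actually passed to Python, so the old checkbox becomes a hard-coded boolean. A standalone sketch of the pattern for gradio 3.x, with all names illustrative:

import gradio as gr

def save_files(info, make_zip):
    return f"saved (zip={make_zip})"

with gr.Blocks() as demo:
    info = gr.Textbox(value="demo")
    log = gr.HTML()
    save = gr.Button("Save")
    save_zip = gr.Button("Zip")
    # The second input slot is overridden to a constant by the _js shim.
    save.click(fn=save_files, _js="(x, y) => [x, false]", inputs=[info, info], outputs=[log])
    save_zip.click(fn=save_files, _js="(x, y) => [x, true]", inputs=[info, info], outputs=[log])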
From 29a3a7eb13478297bc7093971b48827ab8246f45 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 01:19:10 +0300
Subject: show sampler selection in dropdown, add option selection to revert to
old radio group
---
modules/shared.py | 1 +
modules/ui.py | 22 +++++++++++++++-------
2 files changed, 16 insertions(+), 7 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 715b9169..948b9542 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -406,6 +406,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
diff --git a/modules/ui.py b/modules/ui.py
index 279b5110..c7b8ea5d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -643,6 +643,19 @@ Requested path was: {f}
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with gr.Row(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ else:
+ with gr.Group(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -660,9 +673,6 @@ def create_ui():
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
-
-
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
pass
@@ -674,8 +684,7 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel', elem_id="txt2img_settings"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
@@ -875,8 +884,7 @@ def create_ui():
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
From 210449b374d522c94a67fe54289a9eb515933a9f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 02:41:15 +0300
Subject: fix 'RuntimeError: Expected all tensors to be on the same device'
error preventing models from loading on lowvram/medvram.
---
modules/sd_hijack_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 6ec50cca..ca92b142 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -298,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
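A standalone illustration of the failure mode this one-line change avoids: under lowvram/medvram the embedding layer can sit on a different device than shared.device, so the ids must follow the layer's own weights rather than a global setting.

import torch
import torch.nn as nn

embedding_layer = nn.Embedding(49408, 768)  # may stay on CPU when offloaded
ids = torch.tensor([[320]])                 # token ids built elsewhere

# ids.to(shared.device) can mismatch the layer's actual placement;
# ids.to(embedding_layer.weight.device) always matches it.
embedded = embedding_layer(ids.to(embedding_layer.weight.device)).squeeze(0)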
From a939e82a0b982517aa212197a0e5f6d11daec7d0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 03:24:58 +0300
Subject: fix weird padding for sampler dropdown in chrome
---
style.css | 5 -----
1 file changed, 5 deletions(-)
diff --git a/style.css b/style.css
index f245f674..4b98b84d 100644
--- a/style.css
+++ b/style.css
@@ -245,11 +245,6 @@ input[type="range"]{
margin: 0.5em 0 -0.3em 0;
}
-#txt2img_sampling label{
- padding-left: 0.6em;
- padding-right: 0.6em;
-}
-
#mask_bug_info {
text-align: center;
display: block;
From 16b9661d2741b241c3964fcbd56559c078b84822 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 09:51:37 +0300
Subject: change karras scheduler sigmas to values recommended by SD from old
0.1 to 10 with an option to revert to old
---
modules/sd_samplers.py | 4 +++-
modules/shared.py | 6 +++++-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 177b5338..e904d860 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -465,7 +465,9 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
diff --git a/modules/shared.py b/modules/shared.py
index 948b9542..7f430b93 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -368,13 +368,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
- "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
+options_templates.update(options_section(('compatibility', "Compatibility"), {
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+}))
+
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
From 11d432d92d63660c516540dcb48faac87669b4f0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 10:35:38 +0300
Subject: add refresh buttons to checkpoint merger
---
modules/ui.py | 6 ++++++
style.css | 2 +-
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index c7b8ea5d..4cc2ce4f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1167,8 +1167,14 @@ def create_ui():
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+ create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
diff --git a/style.css b/style.css
index 4b98b84d..516ef7bf 100644
--- a/style.css
+++ b/style.css
@@ -496,7 +496,7 @@ input[type="range"]{
padding: 0;
}
-#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_checkpoint_A, #refresh_checkpoint_B, #refresh_checkpoint_C{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
From 76f256fe8f844641f4e9b41f35c7dd2cba5090d6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 11:08:39 +0300
Subject: Bump gradio version #YOLO
---
modules/ui_tempdir.py | 3 ++-
requirements.txt | 2 +-
requirements_versions.txt | 2 +-
3 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 07210d14..8d519310 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -15,7 +15,8 @@ Savedfile = namedtuple("Savedfile", ["name"])
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
- shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
+ shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)}
+
file_obj = Savedfile(already_saved_as)
return file_obj
diff --git a/requirements.txt b/requirements.txt
index 5bed694e..e2c3876b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
-gradio==3.9
+gradio==3.15.0
invisible-watermark
numpy
omegaconf
diff --git a/requirements_versions.txt b/requirements_versions.txt
index c126c8c4..836523ba 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.9
+gradio==3.15.0
numpy==1.23.3
Pillow==9.2.0
realesrgan==0.3.0
From b46b97fa297b3a4a654da77cf98a775a2bcab4c7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 11:38:17 +0300
Subject: more fixes for gradio update
---
modules/generation_parameters_copypaste.py | 2 +-
modules/ui_tempdir.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fbd91300..54b3372d 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -38,7 +38,7 @@ def quote(text):
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
- is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
+ is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets])
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 8d519310..363d449d 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -45,7 +45,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True)
- shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
+ shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)}
def cleanup_tmpdr():
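A standalone sketch of why both call sites change, assuming gradio 3.15's move from directory prefixes (temp_dirs) to explicit path sets (temp_file_sets):

import os

temp_file_sets = [set()]  # stand-in for shared.demo.temp_file_sets

# Registering a file now means adding its absolute path to the first set ...
temp_file_sets[0] |= {os.path.abspath("outputs/txt2img/00001.png")}

# ... so the is_in_right_dir check becomes plain set membership instead of
# resolving parent directories against temp_dirs.
filename = os.path.abspath("outputs/txt2img/00001.png")
is_in_right_dir = any(filename in fileset for fileset in temp_file_sets)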
From e5f1a37cb9b537d95b2df47c96b4a4f7242fd294 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 13:08:40 +0300
Subject: make refresh buttons look more nice
---
modules/ui.py | 6 +++---
modules/ui_components.py | 18 ++++++++++++++++++
style.css | 28 +++++++++++++++++++++-------
3 files changed, 42 insertions(+), 10 deletions(-)
create mode 100644 modules/ui_components.py
diff --git a/modules/ui.py b/modules/ui.py
index 4cc2ce4f..32fa80d1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, ui_components
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -532,7 +532,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return gr.update(**(args or {}))
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ui_components.ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
@@ -1476,7 +1476,7 @@ def create_ui():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- with gr.Row(variant="compact"):
+ with ui_components.FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
diff --git a/modules/ui_components.py b/modules/ui_components.py
new file mode 100644
index 00000000..d0519d2d
--- /dev/null
+++ b/modules/ui_components.py
@@ -0,0 +1,18 @@
+import gradio as gr
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class FormRow(gr.Row, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "row"
diff --git a/style.css b/style.css
index 516ef7bf..f168571e 100644
--- a/style.css
+++ b/style.css
@@ -496,13 +496,6 @@ input[type="range"]{
padding: 0;
}
-#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_checkpoint_A, #refresh_checkpoint_B, #refresh_checkpoint_C{
- max-width: 2.5em;
- min-width: 2.5em;
- height: 2.4em;
-}
-
-
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -569,6 +562,27 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
padding-right: 0.5em;
}
+.gr-form{
+ background-color: white;
+}
+
+.dark .gr-form{
+ background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+}
+
+.gr-button-tool{
+ max-width: 2.5em;
+ min-width: 2.5em !important;
+ height: 2.4em;
+ margin: 0.55em 0;
+}
+
+#quicksettings .gr-button-tool{
+ margin: 0;
+}
+
+
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
From 5f12b23b8bb7fca585a3a1e844881d06f171364e Mon Sep 17 00:00:00 2001
From: AlUlkesh <99896447+AlUlkesh@users.noreply.github.com>
Date: Wed, 28 Dec 2022 22:18:19 +0100
Subject: Adding image numbers on grids
New grid option in settings enables adding of image numbers on grids. This makes identifying the images, especially in larger batches, much easier.
Revert "Adding image numbers on grids"
This reverts commit 3530c283b4b1d3a3cab40efbffe4cf2697938b6f.
Implements Callback for image grid loop
Necessary to make "Add image's number to its picture in the grid" extension possible.
---
modules/images.py | 1 +
modules/script_callbacks.py | 20 ++++++++++++++++++++
2 files changed, 21 insertions(+)
diff --git a/modules/images.py b/modules/images.py
index 31d4528d..5afd3891 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -43,6 +43,7 @@ def image_grid(imgs, batch_size=1, rows=None):
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
for i, img in enumerate(imgs):
+ script_callbacks.image_grid_loop_callback(img)
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 8e22f875..0c854407 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -51,6 +51,11 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params
+class ImageGridLoopParams:
+ def __init__(self, img):
+ self.img = img
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
@@ -63,6 +68,7 @@ callback_map = dict(
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
+ callbacks_image_grid_loop=[],
)
@@ -154,6 +160,12 @@ def after_component_callback(component, **kwargs):
except Exception:
report_exception(c, 'after_component_callback')
+def image_grid_loop_callback(component, **kwargs):
+ for c in callback_map['callbacks_image_grid_loop']:
+ try:
+ c.callback(component, **kwargs)
+ except Exception:
+ report_exception(c, 'image_grid_loop')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
@@ -255,3 +267,11 @@ def on_before_component(callback):
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid_loop(callback):
+ """register a function to be called inside the image grid loop.
+ The callback is called with one argument:
+ - params: ImageGridLoopParams - parameters to be used inside the image grid loop.
+ """
+ add_callback(callback_map['callbacks_image_grid_loop'], callback)
From 524d532b387732d4d32f237e792c7f201a934400 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 14:07:40 +0300
Subject: moved roll artist to built-in extensions
---
.../roll-artist/scripts/roll-artist.py | 50 ++++++++++++++++++++++
modules/ui.py | 37 ++--------------
2 files changed, 53 insertions(+), 34 deletions(-)
create mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py
diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py
new file mode 100644
index 00000000..c3bc1fd0
--- /dev/null
+++ b/extensions-builtin/roll-artist/scripts/roll-artist.py
@@ -0,0 +1,50 @@
+import random
+
+from modules import script_callbacks, shared
+import gradio as gr
+
+art_symbol = '\U0001f3a8' # 🎨
+global_prompt = None
+related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
+
+
+def roll_artist(prompt):
+ allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
+ artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
+
+ return prompt + ", " + artist.name if prompt != '' else artist.name
+
+
+def add_roll_button(prompt):
+ roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+ roll.click(
+ fn=roll_artist,
+ _js="update_txt2img_tokens",
+ inputs=[
+ prompt,
+ ],
+ outputs=[
+ prompt,
+ ]
+ )
+
+
+def after_component(component, **kwargs):
+ global global_prompt
+
+ elem_id = kwargs.get('elem_id', None)
+ if elem_id not in related_ids:
+ return
+
+ if elem_id == "txt2img_prompt":
+ global_prompt = component
+ elif elem_id == "txt2img_clear_prompt":
+ add_roll_button(global_prompt)
+ elif elem_id == "img2img_prompt":
+ global_prompt = component
+ elif elem_id == "img2img_clear_prompt":
+ add_roll_button(global_prompt)
+
+
+script_callbacks.on_after_component(after_component)
diff --git a/modules/ui.py b/modules/ui.py
index 32fa80d1..27da2c2c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -80,7 +80,6 @@ css_hide_progressbar = """
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
-art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -234,13 +233,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part)
-def roll_artist(prompt):
- allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
- artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
- return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -403,7 +395,6 @@ def create_toprow(is_img2img):
)
with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
@@ -452,7 +443,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -668,7 +659,7 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -771,16 +762,6 @@ def create_ui():
outputs=[hr_options],
)
- roll.click(
- fn=roll_artist,
- _js="update_txt2img_tokens",
- inputs=[
- txt2img_prompt,
- ],
- outputs=[
- txt2img_prompt,
- ]
- )
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
@@ -823,7 +804,7 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -999,18 +980,6 @@ def create_ui():
outputs=[img2img_prompt],
)
-
- roll.click(
- fn=roll_artist,
- _js="update_img2img_tokens",
- inputs=[
- img2img_prompt,
- ],
- outputs=[
- img2img_prompt,
- ]
- )
-
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
From e672cfb07418a1a3130d3bf21c14a0d3819f81fb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 18:37:37 +0300
Subject: rework of callback for #6094
---
modules/images.py | 10 ++++++----
modules/script_callbacks.py | 26 +++++++++++++++-----------
2 files changed, 21 insertions(+), 15 deletions(-)
diff --git a/modules/images.py b/modules/images.py
index 719aaf3b..f84fd485 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -39,12 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows)
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+
w, h = imgs[0].size
- grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
- for i, img in enumerate(imgs):
- script_callbacks.image_grid_loop_callback(img)
- grid.paste(img, box=(i % cols * w, i // cols * h))
+ for i, img in enumerate(params.imgs):
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 0c854407..de69fd9f 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -52,8 +52,10 @@ class UiTrainTabParams:
class ImageGridLoopParams:
- def __init__(self, img):
- self.img = img
+ def __init__(self, imgs, cols, rows):
+ self.imgs = imgs
+ self.cols = cols
+ self.rows = rows
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
@@ -68,7 +70,7 @@ callback_map = dict(
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
- callbacks_image_grid_loop=[],
+ callbacks_image_grid=[],
)
@@ -160,12 +162,14 @@ def after_component_callback(component, **kwargs):
except Exception:
report_exception(c, 'after_component_callback')
-def image_grid_loop_callback(component, **kwargs):
- for c in callback_map['callbacks_image_grid_loop']:
+
+def image_grid_callback(params: ImageGridLoopParams):
+ for c in callback_map['callbacks_image_grid']:
try:
- c.callback(component, **kwargs)
+ c.callback(params)
except Exception:
- report_exception(c, 'image_grid_loop')
+ report_exception(c, 'image_grid')
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
@@ -269,9 +273,9 @@ def on_after_component(callback):
add_callback(callback_map['callbacks_after_component'], callback)
-def on_image_grid_loop(callback):
- """register a function to be called inside the image grid loop.
+def on_image_grid(callback):
+ """register a function to be called before making an image grid.
The callback is called with one argument:
- - params: ImageGridLoopParams - parameters to be used inside the image grid loop.
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
- add_callback(callback_map['callbacks_image_grid_loop'], callback)
+ add_callback(callback_map['callbacks_image_grid'], callback)
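Since image_grid now reads imgs, cols, and rows back from the params object after the callback runs, an extension can reshape the grid. A hypothetical extension script using the reworked API:

from modules import script_callbacks

def single_row_grid(params):
    # params is an ImageGridLoopParams; the docstring above says it can be
    # modified, so force every grid into a single row.
    params.cols = len(params.imgs)
    params.rows = 1

script_callbacks.on_image_grid(single_row_grid)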
From a005fccddd5a37c57f1afe5234660b59b9a41508 Mon Sep 17 00:00:00 2001
From: me <25877290+Kryptortio@users.noreply.github.com>
Date: Sun, 1 Jan 2023 14:51:12 +0100
Subject: Add a lot more elem_id/HTML id, modified some that were duplicates
for seed section
---
modules/generation_parameters_copypaste.py | 2 +-
modules/ui.py | 254 ++++++++++++++---------------
style.css | 12 +-
3 files changed, 134 insertions(+), 134 deletions(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 54b3372d..8e7f0df0 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -93,7 +93,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
- buttons[tab] = gr.Button(f"Send to {tab}")
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons
diff --git a/modules/ui.py b/modules/ui.py
index 27da2c2c..7070ea15 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -272,17 +272,17 @@ def interrogate_deepbooru(image):
return gr_show(True) if prompt is None else prompt
-def create_seed_inputs():
+def create_seed_inputs(target_interface):
with gr.Row():
with gr.Box():
- with gr.Row(elem_id='seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
+ with gr.Row(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id='random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
- with gr.Box(elem_id='subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
+ with gr.Box(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
@@ -290,17 +290,17 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
with gr.Box():
- with gr.Row(elem_id='subseed_row'):
- subseed = gr.Number(label='Variation seed', value=-1)
+ with gr.Row(elem_id=target_interface + '_subseed_row'):
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0)
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -678,28 +678,28 @@ def create_ui():
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
- enable_hr = gr.Checkbox(label='Highres. fix', value=False)
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr")
with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0)
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width")
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with gr.Row(equal_height=True):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
- with gr.Group():
+ with gr.Group(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
@@ -821,10 +821,10 @@ def create_ui():
with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img'):
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
- with gr.TabItem('Inpaint', id='inpaint'):
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None)
@@ -843,24 +843,24 @@ def create_ui():
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
with gr.Row():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch)
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
with gr.Row():
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
+ inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
with gr.Row():
- inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+ inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, elem_id="img2img_inpaint_full_res")
+ inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
- with gr.TabItem('Batch img2img', id='batch'):
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
")
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
@@ -872,20 +872,20 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
with gr.Row():
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
with gr.Group():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
- with gr.Group():
+ with gr.Group(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
@@ -1032,45 +1032,45 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image'):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+ with gr.TabItem('Single Image', elem_id="extras_single_tab"):
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Batch Process'):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
- with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
- show_extras_results = gr.Checkbox(label='Show result images', value=True)
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by'):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
- with gr.TabItem('Scale to'):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
with gr.Group():
with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
with gr.Group():
- upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
+ upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
@@ -1117,7 +1117,7 @@ def create_ui():
with gr.Column(variant='panel'):
html = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
html2 = gr.HTML()
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
@@ -1144,13 +1144,13 @@ def create_ui():
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
- custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
+ custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
- save_as_half = gr.Checkbox(value=False, label="Save as float16")
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
@@ -1165,58 +1165,58 @@ def create_ui():
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
- new_embedding_name = gr.Textbox(label="Name")
- initialization_text = gr.Textbox(label="Initialization text", value="*")
- nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
- overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create embedding", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"):
- new_hypernetwork_name = gr.Textbox(label="Name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
- new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
- new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
- overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"):
- process_src = gr.Textbox(label='Source directory')
- process_dst = gr.Textbox(label='Destination directory')
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
+ process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
+ process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
- process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images')
- process_focal_crop = gr.Checkbox(label='Auto focal point crop')
- process_caption = gr.Checkbox(label='Use BLIP for caption')
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True)
+ process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
+ process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
with gr.Row(visible=False) as process_split_extra_row:
- process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
- process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
with gr.Row(visible=False) as process_focal_crop_row:
- process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_debug = gr.Checkbox(label='Create debug image')
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Row():
with gr.Column(scale=3):
@@ -1224,8 +1224,8 @@ def create_ui():
with gr.Column():
with gr.Row():
- interrupt_preprocessing = gr.Button("Interrupt")
- run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
+ run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
process_split.change(
fn=lambda show: gr_show(show),
@@ -1248,31 +1248,31 @@ def create_ui():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
with gr.Row():
- embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
- hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
-
- batch_size = gr.Number(label='Batch size', value=1, precision=0)
- gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
- dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
- log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
- steps = gr.Number(label='Max steps', value=100000, precision=0)
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
- save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
- preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
with gr.Row():
- shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
- tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
with gr.Row():
- latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
- train_embedding = gr.Button(value="Train Embedding", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
@@ -1490,7 +1490,7 @@ def create_ui():
return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
- settings_submit = gr.Button(value="Apply settings", variant='primary')
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
result = gr.HTML()
settings_cols = 3
@@ -1541,8 +1541,8 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary', elem_id="settings_restart_gradio")
request_notifications.click(
fn=lambda: None,
diff --git a/style.css b/style.css
index f168571e..924d4ae7 100644
--- a/style.css
+++ b/style.css
@@ -73,7 +73,7 @@
margin-right: auto;
}
-#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
+[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
@@ -84,27 +84,27 @@
display: none;
}
-#seed_row, #subseed_row{
+[id$=_seed_row], [id$=_subseed_row]{
gap: 0.5rem;
}
-#subseed_show_box{
+[id$=_subseed_show_box]{
min-width: auto;
flex-grow: 0;
}
-#subseed_show_box > div{
+[id$=_subseed_show_box] > div{
border: 0;
height: 100%;
}
-#subseed_show{
+[id$=_subseed_show]{
min-width: auto;
flex-grow: 0;
padding: 0;
}
-#subseed_show label{
+[id$=_subseed_show] label{
height: 100%;
}
--
cgit v1.2.3
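The commit above replaces fixed ids like #seed_row with suffix selectors so that seed controls duplicated across tabs keep their styling. A minimal sketch (not part of the patch; the helper name is illustrative) of how a per-tab elem_id becomes the DOM id that the new [id$=_seed_row] selectors match:

# Illustrative sketch: elem_id is rendered as the HTML id of the component's
# container, so `[id$=_seed_row]` in style.css matches both txt2img_seed_row
# and img2img_seed_row without listing each tab explicitly.
import gradio as gr

def build_seed_row(tabname):
    with gr.Row(elem_id=f"{tabname}_seed_row"):
        seed = gr.Number(label="Seed", value=-1, elem_id=f"{tabname}_seed")
    return seed

with gr.Blocks() as demo:
    for tab in ("txt2img", "img2img"):
        build_seed_row(tab)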
From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 00:38:09 +0300
Subject: fix the issue with training on SD2.0
---
modules/sd_models.py | 2 ++
modules/textual_inversion/textual_inversion.py | 3 +--
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ebd4dff7..bff8d6c9 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ model.logvar = model.logvar.to(devices.device) # fix for training
+
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 66f40367..1e5722e7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- # dataset loading may take a while, so input validations and early returns should be done before this
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
-
last_saved_file = ""
last_saved_image = ""
forced_filename = ""
--
cgit v1.2.3
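The logvar line above fixes a device mismatch: the DDPM model appears to keep logvar as a plain tensor attribute that stays on the CPU after checkpoint load, while the training tensors live on the GPU. A hedged sketch of the failure mode, assuming a CUDA device is available:

# Hedged sketch of the bug the one-line fix addresses: mixing a CPU tensor
# into a CUDA computation raises a device-mismatch RuntimeError.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logvar = torch.zeros(1000)             # stays on CPU after checkpoint load
loss = torch.randn(4, device=device)   # training tensors live on `device`

try:
    _ = loss / torch.exp(logvar[:4]) + logvar[:4]   # mixed devices -> error on CUDA
except RuntimeError as e:
    print("without the fix:", e)

logvar = logvar.to(device)             # the fix from the commit
_ = loss / torch.exp(logvar[:4]) + logvar[:4]       # now fine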
From b5819d9bf1794071139c640b5f1e72c84a0e051a Mon Sep 17 00:00:00 2001
From: Philpax
Date: Mon, 2 Jan 2023 10:17:33 +1100
Subject: feat(api): add /sdapi/v1/embeddings
---
modules/api/api.py | 8 ++++++++
modules/api/models.py | 3 +++
2 files changed, 11 insertions(+)
diff --git a/modules/api/api.py b/modules/api/api.py
index 11daff0d..30bf3dac 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -100,6 +100,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -327,6 +328,13 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+ return {
+ "loaded": sorted(db.word_embeddings.keys()),
+ "skipped": sorted(db.skipped_embeddings),
+ }
+
def refresh_checkpoints(self):
shared.refresh_checkpoints()
diff --git a/modules/api/models.py b/modules/api/models.py
index c446ce7a..a8472dc9 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,3 +249,6 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingsResponse(BaseModel):
+ loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
--
cgit v1.2.3
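A usage sketch for the new route (the host and port are assumptions for a default local install, not part of the patch); at this point the response is two flat lists of embedding names:

# Query the new endpoint and print what was loaded/skipped for this model.
import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings", timeout=10)
resp.raise_for_status()
data = resp.json()
print("loaded: ", ", ".join(data["loaded"]))
print("skipped:", ", ".join(data["skipped"]))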
From c65909ad16a1962129114c6251de092f49479b06 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Mon, 2 Jan 2023 12:21:22 +1100
Subject: feat(api): return more data for embeddings
---
modules/api/api.py | 17 +++++++++++++++--
modules/api/models.py | 11 +++++++++--
modules/textual_inversion/textual_inversion.py | 8 ++++----
3 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 30bf3dac..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -330,9 +330,22 @@ class Api:
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
return {
- "loaded": sorted(db.word_embeddings.keys()),
- "skipped": sorted(db.skipped_embeddings),
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
diff --git a/modules/api/models.py b/modules/api/models.py
index a8472dc9..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
class EmbeddingsResponse(BaseModel):
- loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1e5722e7..fd253477 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
- self.skipped_embeddings = []
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
- self.skipped_embeddings = []
+ self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
- self.skipped_embeddings.append(name)
+ self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
--
cgit v1.2.3
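With this commit the same endpoint returns name-to-metadata mappings instead of flat lists. A follow-up sketch parsing the richer shape (same assumed local URL as above):

# Each entry is now an EmbeddingItem; step and the checkpoint fields may be
# null for embeddings that did not record training metadata.
import requests

data = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings", timeout=10).json()
for name, item in data["loaded"].items():
    print(f"{name}: {item['vectors']} vector(s) of length {item['shape']}, "
          f"trained {item['step'] or '?'} steps on {item['sd_checkpoint_name'] or 'unknown'}")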
From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 19:42:10 +0300
Subject: Hires fix rework
---
modules/generation_parameters_copypaste.py | 32 ++++++++++++++
modules/images.py | 24 +++++++++--
modules/processing.py | 68 ++++++++++++------------------
modules/shared.py | 7 ++-
modules/txt2img.py | 6 +--
modules/ui.py | 15 +++----
scripts/xy_grid.py | 4 +-
7 files changed, 96 insertions(+), 60 deletions(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 8e7f0df0..d6fa822b 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,5 +1,6 @@
import base64
import io
+import math
import os
import re
from pathlib import Path
@@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
return None
+def restore_old_hires_fix_params(res):
+ """for infotexts that specify old First pass size parameter, convert it into
+ width, height, and hr scale"""
+
+ firstpass_width = res.get('First pass size-1', None)
+ firstpass_height = res.get('First pass size-2', None)
+
+ if firstpass_width is None or firstpass_height is None:
+ return
+
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+ width = int(res.get("Size-1", 512))
+ height = int(res.get("Size-2", 512))
+
+ if firstpass_width == 0 or firstpass_height == 0:
+ # old algorithm for auto-calculating first pass size
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ firstpass_width = math.ceil(scale * width / 64) * 64
+ firstpass_height = math.ceil(scale * height / 64) * 64
+
+ hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
+
+ res['Size-1'] = firstpass_width
+ res['Size-2'] = firstpass_height
+ res['Hires upscale'] = hr_scale
+
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ restore_old_hires_fix_params(res)
+
return res
diff --git a/modules/images.py b/modules/images.py
index f84fd485..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+ """
+ Resizes an image with the specified resize_mode, width, and height.
+
+ Args:
+ resize_mode: The mode to use when resizing the image.
+ 0: Resize the image to the specified width and height.
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+ im: The image to resize.
+ width: The width to resize the image to.
+ height: The height to resize the image to.
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+ """
+
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
def resize(im, w, h):
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
- assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
diff --git a/modules/processing.py b/modules/processing.py
index 42dc19ea..4654570c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
- self.firstphase_width = firstphase_width
- self.firstphase_height = firstphase_height
- self.truncate_x = 0
- self.truncate_y = 0
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+
+ if firstphase_width != 0 or firstphase_height != 0:
+ print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+ self.hr_scale = self.width / firstphase_width
+ self.width = firstphase_width
+ self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
-
- if self.firstphase_width == 0 or self.firstphase_height == 0:
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- firstphase_width_truncated = int(scale * self.width)
- firstphase_height_truncated = int(scale * self.height)
-
- else:
-
- width_ratio = self.width / self.firstphase_width
- height_ratio = self.height / self.firstphase_height
-
- if width_ratio > height_ratio:
- firstphase_width_truncated = self.firstphase_width
- firstphase_height_truncated = self.firstphase_width * self.height / self.width
- else:
- firstphase_width_truncated = self.firstphase_height * self.width / self.height
- firstphase_height_truncated = self.firstphase_height
-
- self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
- self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ if self.hr_upscaler is not None:
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
+ if self.enable_hr and latent_scale_mode is None:
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
-
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ target_width = int(self.width * self.hr_scale)
+ target_height = int(self.height * self.hr_scale)
- """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index):
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
@@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
- if opts.use_scale_latent_for_hires_fix:
+ if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i)
- image = images.resize_image(0, image, self.width, self.height)
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
@@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
diff --git a/modules/shared.py b/modules/shared.py
index 7f430b93..b65559ee 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
- "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -545,6 +544,12 @@ opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+latent_upscale_default_mode = "Latent"
+latent_upscale_modes = {
+ "Latent": "bilinear",
+ "Latent (nearest)": "nearest",
+}
+
sd_upscalers = []
sd_model = None
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 7f61e19a..e189a899 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling,
enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None,
- firstphase_width=firstphase_width if enable_hr else None,
- firstphase_height=firstphase_height if enable_hr else None,
+ hr_scale=hr_scale,
+ hr_upscaler=hr_upscaler,
)
p.scripts = modules.scripts.scripts_txt2img
diff --git a/modules/ui.py b/modules/ui.py
index 7070ea15..27cd9ddd 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -684,11 +684,11 @@ def create_ui():
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width")
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height")
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with gr.Row(equal_height=True):
@@ -729,8 +729,8 @@ def create_ui():
width,
enable_hr,
denoising_strength,
- firstphase_width,
- firstphase_height,
+ hr_scale,
+ hr_upscaler,
] + custom_inputs,
outputs=[
@@ -762,7 +762,6 @@ def create_ui():
outputs=[hr_options],
)
-
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -781,8 +780,8 @@ def create_ui():
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (firstphase_width, "First pass size-1"),
- (firstphase_height, "First pass size-2"),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 3e0b2805..f92f9776 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -202,7 +202,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
- AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None),
+ AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
@@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
- self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
@@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object):
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
- opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
--
cgit v1.2.3
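The restore_old_hires_fix_params logic above can be checked by hand. For an old infotext with "First pass size: 0x0" and "Size: 1024x768", the legacy auto-size algorithm it reproduces gives:

# Worked example mirroring restore_old_hires_fix_params for the 0x0 case.
import math

width, height = 1024, 768                    # final size from the infotext
desired, actual = 512 * 512, width * height  # old auto-size targeted ~512x512
scale = math.sqrt(desired / actual)
firstpass_width = math.ceil(scale * width / 64) * 64    # -> 640
firstpass_height = math.ceil(scale * height / 64) * 64  # -> 448
hr_scale = width / firstpass_width                      # -> 1.6
print(firstpass_width, firstpass_height, hr_scale)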
From 4dbde228ff48dbb105241b1ed25c21ce3f87d182 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 20:01:16 +0300
Subject: make it possible to use fractional values for SD upscale.
---
modules/upscaler.py | 6 +++---
scripts/sd_upscale.py | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index c4e6e6bd..231680cb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,10 +53,10 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
- def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = img.width * scale
- dest_h = img.height * scale
+ dest_w = int(img.width * scale)
+ dest_h = int(img.height * scale)
for i in range(3):
shape = (img.width, img.height)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index e8c80a6c..9739545c 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
info = gr.HTML("
Will upscale the image by the selected scale factor; use width and height sliders to set tile size
")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
- scale_factor = gr.Slider(minimum=1, maximum=4, step=1, label='Scale Factor', value=2)
+ scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index, scale_factor]
--
cgit v1.2.3
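The int() casts are what make fractional factors safe here: PIL sizes must be integers, so a non-integral product is truncated. A two-line illustration:

# With the slider now stepping by 0.05, a factor like 1.25 yields exact
# integer destination sizes via truncation.
scale = 1.25
width, height = 512, 768
dest_w, dest_h = int(width * scale), int(height * scale)   # 640, 960
print(dest_w, dest_h)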
From 84dd7e8e2495c4fc2997e97f8267aa831eb90d11 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 20:30:02 +0300
Subject: error out with a readable message in checkpoint merger for
incompatible tensor shapes (i.e. when trying to merge SD1.5 with SD2.0)
---
modules/extras.py | 2 ++
modules/ui.py | 2 +-
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/modules/extras.py b/modules/extras.py
index 68939dea..5e270250 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
+ assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
+
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
diff --git a/modules/ui.py b/modules/ui.py
index 27cd9ddd..67a51888 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1663,7 +1663,7 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+ return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
return results
modelmerger_merge.click(
--
cgit v1.2.3
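The new assert turns a cryptic broadcasting error deep inside theta_func2 into a readable per-layer message when shapes disagree between architectures. An illustrative sketch (the key and shapes are made up for the example, not taken from real checkpoints):

# Two checkpoints sharing a key but not a shape cannot be interpolated.
import torch

key = "model.diffusion_model.input_blocks.1.1.proj_in.weight"
a = torch.zeros(320, 320, 1, 1)   # illustrative conv weight from model A
b = torch.zeros(320, 320)         # same key, different shape in model B

try:
    assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
except AssertionError as e:
    print(e)   # surfaced to the UI instead of a broadcasting traceback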
From 8d12a729b8b036cb765cf2d87576d5ae256135c8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 20:46:51 +0300
Subject: fix possible error with accessing nonexistent setting
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index 67a51888..9350a80f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -491,7 +491,7 @@ def apply_setting(key, value):
return
valtype = type(opts.data_labels[key].default)
- oldval = opts.data[key]
+ oldval = opts.data.get(key, None)
opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
--
cgit v1.2.3
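opts.data only contains settings that were present when the config file was saved, so indexing a newly added key raised KeyError; .get() returns None instead. A minimal illustration:

# Settings added after config.json was written are absent from the dict.
data = {"sd_model_checkpoint": "v1-5-pruned.ckpt"}

key = "upscaler_for_img2img"   # known to data_labels, missing from data
oldval = data.get(key, None)   # fixed: None instead of KeyError
print(oldval)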
From 251ecee6949c36e9df1d99a950b3e1af2b5fa2b6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 22:44:46 +0300
Subject: make "send to" buttons send actual dimension of the sent image rather
than fields
---
javascript/ui.js | 4 +--
modules/generation_parameters_copypaste.py | 58 ++++++++++++++++++++----------
2 files changed, 42 insertions(+), 20 deletions(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index 587dd782..d0c054d9 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -19,7 +19,7 @@ function selected_gallery_index(){
function extract_image_from_gallery(gallery){
if(gallery.length == 1){
- return gallery[0]
+ return [gallery[0]]
}
index = selected_gallery_index()
@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
return [null]
}
- return gallery[index];
+ return [gallery[index]];
}
function args_to_array(args){
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d6fa822b..ec60319a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -103,35 +103,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
+def send_image_and_dimensions(x):
+ if isinstance(x, Image.Image):
+ img = x
+ else:
+ img = image_from_url_text(x)
+
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+
+ return img, w, h
+
+
def run_bind():
- for buttons, send_image, send_generate_info in bind_list:
+ for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
- if send_image and paste_fields[tab]["init_img"]:
- if type(send_image) == gr.Gallery:
- button.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery",
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ destination_image_component = paste_fields[tab]["init_img"]
+ fields = paste_fields[tab]["fields"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if source_image_component and destination_image_component:
+ if isinstance(source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
else:
- button.click(
- fn=lambda x: x,
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
- if send_generate_info and paste_fields[tab]["fields"] is not None:
+ if send_generate_info and fields is not None:
if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
)
else:
- connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+ connect_paste(button, fields, send_generate_info)
button.click(
fn=None,
--
cgit v1.2.3
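A hedged sketch of the send_image_and_dimensions contract added above (shared.opts.send_size is replaced by a plain parameter here to keep the sketch self-contained): when size-sending is on, the destination width/height fields receive the image's real dimensions; otherwise gr.update() no-ops leave the sliders untouched:

# Self-contained approximation of the new helper's behavior.
import gradio as gr
from PIL import Image

def send_image_and_dimensions(img, send_size=True):
    if send_size and isinstance(img, Image.Image):
        return img, img.width, img.height
    return img, gr.update(), gr.update()

img = Image.new("RGB", (768, 512))
_, w, h = send_image_and_dimensions(img)
print(w, h)   # 768 512 regardless of what the sliders currently show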
From 1d7a31def8b5f4c348e2dd07536ac56cb4350614 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 06:21:53 +0300
Subject: make edit fields for sliders not get hidden by the slider's label
when there's not enough space
---
style.css | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/style.css b/style.css
index 924d4ae7..77551dd7 100644
--- a/style.css
+++ b/style.css
@@ -509,7 +509,7 @@ canvas[key="mask"] {
position: absolute;
right: 0.5em;
top: -0.6em;
- z-index: 200;
+ z-index: 400;
width: 8em;
}
#quicksettings .gr-box > div > div > input.gr-text-input {
--
cgit v1.2.3
From 269f6e867651cadef40d2c939a79d13291280bcd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 07:20:20 +0300
Subject: change settings UI to use vertical tabs
---
modules/ui.py | 45 +++++++++++++++++----------------------------
style.css | 27 +++++++++++++++++++++++++++
2 files changed, 44 insertions(+), 28 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index 9350a80f..f8c973ba 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1489,41 +1489,34 @@ def create_ui():
return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- result = gr.HTML()
+ with gr.Row():
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ restart_gradio = gr.Button(value='Restart UI', variant='primary', elem_id="settings_restart_gradio")
- settings_cols = 3
- items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ result = gr.HTML(elem_id="settings_result")
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
quicksettings_list = []
- cols_displayed = 0
- items_displayed = 0
previous_section = None
- column = None
- with gr.Row(elem_id="settings").style(equal_height=False):
+ current_tab = None
+ with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section and not section_must_be_skipped:
- if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
- if column is not None:
- column.__exit__()
+ elem_id, text = item.section
- column = gr.Column(variant='panel')
- column.__enter__()
+ if current_tab is not None:
+ current_tab.__exit__()
- items_displayed = 0
- cols_displayed += 1
+ current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab.__enter__()
previous_section = item.section
- elem_id, text = item.section
-                    gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
-
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
@@ -1533,15 +1526,14 @@ def create_ui():
component = create_setting_component(k)
component_dict[k] = component
components.append(component)
- items_displayed += 1
- with gr.Row():
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ if current_tab is not None:
+ current_tab.__exit__()
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary', elem_id="settings_restart_gradio")
+ with gr.TabItem("Actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
request_notifications.click(
fn=lambda: None,
@@ -1578,9 +1570,6 @@ def create_ui():
outputs=[],
)
- if column is not None:
- column.__exit__()
-
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
diff --git a/style.css b/style.css
index 77551dd7..7df4d960 100644
--- a/style.css
+++ b/style.css
@@ -241,6 +241,33 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
z-index: 200;
}
+#settings{
+ display: block;
+}
+
+#settings > div{
+ border: none;
+ margin-left: 10em;
+}
+
+#settings > div.flex-wrap{
+ float: left;
+ display: block;
+ margin-left: 0;
+ width: 10em;
+}
+
+#settings > div.flex-wrap button{
+ display: block;
+ border: none;
+ text-align: left;
+}
+
+#settings_result{
+ height: 1.4em;
+ margin: 0 1.2em;
+}
+
input[type="range"]{
margin: 0.5em 0 -0.3em 0;
}
--
cgit v1.2.3
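
The commit above drives gradio's layout context managers by hand, calling __enter__/__exit__ on gr.TabItem so a tab can be opened and closed from inside a loop over the settings sections. A minimal sketch of the same pattern, assuming gradio 3.x (the section data is illustrative):

    import gradio as gr

    sections = {"Saving": ["Always save grids"], "UI": ["Show progressbar"]}

    with gr.Blocks() as demo:
        current_tab = None
        with gr.Tabs():
            for name, labels in sections.items():
                if current_tab is not None:
                    current_tab.__exit__()      # close the previous tab
                current_tab = gr.TabItem(label=name)
                current_tab.__enter__()         # components below land in this tab
                for label in labels:
                    gr.Checkbox(label=label)
            if current_tab is not None:
                current_tab.__exit__()          # close the last tab
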
From 18c03cdeac6272734b0c09afd3fbe47d1372dd07 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 09:04:29 +0300
Subject: styling rework to make things more compact
---
modules/ui.py | 121 ++++++++++++++++++++++++-----------------------
modules/ui_components.py | 7 +++
style.css | 35 ++++++++------
3 files changed, 89 insertions(+), 74 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index f8c973ba..f787b518 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -19,7 +19,8 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, ui_components
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules.ui_components import FormRow, FormGroup, ToolButton
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -273,31 +274,27 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface):
- with gr.Row():
- with gr.Box():
- with gr.Row(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Box(elem_id=target_interface + '_subseed_show_box'):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
- with gr.Row(visible=False) as seed_extra_row_1:
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
- with gr.Box():
- with gr.Row(elem_id=target_interface + '_subseed_row'):
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
- with gr.Row(visible=False) as seed_extra_row_2:
+ with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
@@ -523,7 +520,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return gr.update(**(args or {}))
- refresh_button = ui_components.ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
@@ -636,11 +633,11 @@ Requested path was: {f}
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
- with gr.Row(elem_id=f"sampler_selection_{tabname}"):
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
else:
- with gr.Group(elem_id=f"sampler_selection_{tabname}"):
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
@@ -677,29 +674,29 @@ def create_ui():
with gr.Column(variant='panel', elem_id="txt2img_settings"):
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
- with gr.Row():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ with FormRow(elem_id="txt2img_checkboxes"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- with gr.Row(visible=False) as hr_options:
+ with FormRow(visible=False) as hr_options:
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with gr.Row(equal_height=True):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- with gr.Group(elem_id="txt2img_script_container"):
+ with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
@@ -816,7 +813,7 @@ def create_ui():
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
setup_progressbar(progressbar, img2img_preview, 'img2img')
- with gr.Row().style(equal_height=False):
+ with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
@@ -841,19 +838,23 @@ def create_ui():
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
- with gr.Row():
+ with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
- with gr.Row():
- mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
- with gr.Row():
- inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, elem_id="img2img_inpaint_full_res")
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
@@ -861,30 +862,30 @@ def create_ui():
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- with gr.Row():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
-
- with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
- with gr.Row():
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
-
- with gr.Group():
+ with FormGroup():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
- with gr.Group(elem_id="img2img_script_container"):
+ with FormRow(elem_id="img2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+
+ with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
@@ -1444,7 +1445,7 @@ def create_ui():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- with ui_components.FormRow():
+ with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
diff --git a/modules/ui_components.py b/modules/ui_components.py
index d0519d2d..91eb0e3d 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -16,3 +16,10 @@ class FormRow(gr.Row, gr.components.FormComponent):
def get_block_name(self):
return "row"
+
+
+class FormGroup(gr.Group, gr.components.FormComponent):
+    """Same as gr.Group but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "group"
diff --git a/style.css b/style.css
index 7df4d960..86a265f6 100644
--- a/style.css
+++ b/style.css
@@ -74,7 +74,8 @@
}
[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
- min-width: auto;
+ min-width: 2.3em;
+ height: 2.5em;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
@@ -86,6 +87,7 @@
[id$=_seed_row], [id$=_subseed_row]{
gap: 0.5rem;
+ padding: 0.6em;
}
[id$=_subseed_show_box]{
@@ -206,24 +208,24 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute;
- top: -0.6em;
+ top: -0.5em;
line-height: 1.2em;
- padding: 0 0.5em;
- margin: 0;
+ padding: 0;
+ margin: 0 0.5em;
background-color: white;
- border-top: 1px solid #eee;
- border-left: 1px solid #eee;
- border-right: 1px solid #eee;
+ box-shadow: 0 0 5px 5px white;
z-index: 300;
}
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55);
- border-top: 1px solid rgb(55 65 81);
- border-left: 1px solid rgb(55 65 81);
- border-right: 1px solid rgb(55 65 81);
+ box-shadow: 0 0 5px 5px rgb(31, 41, 55);
+}
+
+#txt2img_column_batch, #img2img_column_batch{
+ min-width: min(13.5em, 100%) !important;
}
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
@@ -232,10 +234,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin-right: 8em;
}
-.gr-panel div.flex-col div.justify-between label span{
- margin: 0;
-}
-
#settings .gr-panel div.flex-col div.justify-between div{
position: relative;
z-index: 200;
@@ -609,6 +607,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
}
+#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
+ padding-top: 0.9em;
+}
+
+#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
+ border: none;
+ padding-bottom: 0.5em;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
--
cgit v1.2.3
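
For context on the FormRow/FormGroup classes the rework leans on: gradio picks a block's rendering from get_block_name(), and mixing in gr.components.FormComponent makes the layout element count as part of the surrounding form, so its children are packed together instead of being split into separate boxes. A sketch of the usage, assuming gradio 3.x (FormRow as defined in modules/ui_components.py above):

    import gradio as gr

    class FormRow(gr.Row, gr.components.FormComponent):
        """Same as gr.Row but fits inside gradio forms"""
        def get_block_name(self):
            return "row"

    with gr.Blocks() as demo:
        with FormRow():  # rendered as a row, but packed into the parent form
            gr.Checkbox(label="Restore faces")
            gr.Checkbox(label="Tiling")
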
From 2bc86712ec16cada01a2353f1d978c1aabc84dbb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 09:13:35 +0300
Subject: make quicksettings UI elements appear in same order as they are
listed in the setting
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index f787b518..d7b911da 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1497,7 +1497,7 @@ def create_ui():
result = gr.HTML(elem_id="settings_result")
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
- quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = []
@@ -1604,7 +1604,7 @@ def create_ui():
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
- for i, k, item in quicksettings_list:
+ for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
--
cgit v1.2.3
From 9d4eff097deff6153c4023f158bd9fbd4f3e88b3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 10:01:06 +0300
Subject: add a button to show all setting pages
---
javascript/ui.js | 11 +++++++++++
modules/ui.py | 2 ++
2 files changed, 13 insertions(+)
diff --git a/javascript/ui.js b/javascript/ui.js
index d0c054d9..34406f3f 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -188,6 +188,17 @@ onUiUpdate(function(){
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
}
+
+ show_all_pages = gradioApp().getElementById('settings_show_all_pages')
+ settings_tabs = gradioApp().querySelector('#settings div')
+ if(show_all_pages && settings_tabs){
+ settings_tabs.appendChild(show_all_pages)
+ show_all_pages.onclick = function(){
+ gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
+ elem.style.display = "block";
+ })
+ }
+ }
})
let txt2img_textarea, img2img_textarea = undefined;
diff --git a/modules/ui.py b/modules/ui.py
index d7b911da..2c92c422 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1536,6 +1536,8 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
request_notifications.click(
fn=lambda: None,
inputs=[],
--
cgit v1.2.3
From a1cf55a9d1c82f8e56c00d549bca5c8fa069f412 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 10:39:21 +0300
Subject: add option to reorder items in main UI
---
modules/shared.py | 13 ++++++
modules/ui.py | 130 +++++++++++++++++++++++++++++++++++-------------------
2 files changed, 97 insertions(+), 46 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index b65559ee..23657a93 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -109,6 +109,17 @@ restricted_opts = {
"outdir_save",
}
+ui_reorder_categories = [
+ "sampler",
+ "dimensions",
+ "cfg",
+ "seed",
+ "checkboxes",
+ "hires_fix",
+ "batch",
+ "scripts",
+]
+
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@@ -410,7 +421,9 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
+    "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+    'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
diff --git a/modules/ui.py b/modules/ui.py
index 2c92c422..f2e7c0d6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -644,6 +644,13 @@ def create_sampler_and_steps_selection(choices, tabname):
return steps, sampler_index
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -672,32 +679,48 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel', elem_id="txt2img_settings"):
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- with FormRow(elem_id="txt2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
- with FormRow(visible=False) as hr_options:
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+
+ elif category == "hires_fix":
+ with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -865,28 +888,43 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
- steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
-
- with FormRow():
- with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
- with gr.Column(elem_id="img2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
-
- with FormGroup():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
-
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
- with FormRow(elem_id="img2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
-
- with FormGroup(elem_id="img2img_script_container"):
- custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "cfg":
+ with FormGroup():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="img2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="img2img_script_container"):
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
--
cgit v1.2.3
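
Both this commit and the quicksettings one use the same ordering trick: build a rank dict from the user's comma-separated setting, then sort the built-in sequence using the built-in index (offset so it sorts last) as the fallback rank. A standalone sketch, with illustrative values:

    categories = ["sampler", "dimensions", "cfg", "seed"]
    user_setting = "seed, sampler"  # e.g. shared.opts.ui_reorder

    user_order = {x.strip(): i for i, x in enumerate(user_setting.split(","))}

    ordered = [c for _, c in sorted(enumerate(categories),
                                    key=lambda x: user_order.get(x[1], x[0] + 1000))]
    print(ordered)  # ['seed', 'sampler', 'dimensions', 'cfg']
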
From fda1ed184381fdf8aa81be4f64e77787f3fac1b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 12:01:32 +0300
Subject: some minor improvements for dark mode UI
---
style.css | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/style.css b/style.css
index 86a265f6..7296ce91 100644
--- a/style.css
+++ b/style.css
@@ -208,20 +208,20 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute;
- top: -0.5em;
+ top: -0.7em;
line-height: 1.2em;
padding: 0;
margin: 0 0.5em;
background-color: white;
- box-shadow: 0 0 5px 5px white;
+ box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
z-index: 300;
}
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55);
- box-shadow: 0 0 5px 5px rgb(31, 41, 55);
+ box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
}
#txt2img_column_batch, #img2img_column_batch{
--
cgit v1.2.3
From 9a3b0ee960b0c61c4f60e3081ae6f2098533d393 Mon Sep 17 00:00:00 2001
From: hithereai <121192995+hithereai@users.noreply.github.com>
Date: Tue, 3 Jan 2023 11:22:06 +0200
Subject: update req.txt
The old 'opencv-python' package is very limiting in terms of optical flow, so I propose switching to 'opencv-contrib-python', which provides more cv2.optflow methods.
These are needed for optical-flow trickery in auto1111 and its extensions, and they cannot be installed by an extension, since only a single opencv package may be installed for optical flow to work properly. Changing the main package is inevitable.
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index e2c3876b..4f09385f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,7 @@ gradio==3.15.0
invisible-watermark
numpy
omegaconf
-opencv-python
+opencv-contrib-python
requests
piexif
Pillow
--
cgit v1.2.3
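
A quick way to confirm that the contrib build is the one installed is to call into the cv2.optflow module the commit message mentions; it is absent from plain opencv-python. A sketch, assuming opencv-contrib-python is installed (DeepFlow is one of several optflow algorithms; the images below are synthetic):

    import numpy as np
    import cv2

    prev = np.zeros((64, 64), dtype=np.uint8)
    curr = np.roll(prev, 2, axis=1)

    # cv2.optflow only exists in the contrib build
    flow = cv2.optflow.createOptFlow_DeepFlow().calc(prev, curr, None)
    print(flow.shape)  # (64, 64, 2), a dense flow field
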
From c0ee1488702d5a6ae35fbf7e0422f9f685394920 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 14:18:48 +0300
Subject: add support for running with gradio 3.9 installed
---
modules/generation_parameters_copypaste.py | 4 ++--
modules/ui_tempdir.py | 23 +++++++++++++++++++++--
2 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ec60319a..d94f11a3 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -7,7 +7,7 @@ from pathlib import Path
import gradio as gr
from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
import tempfile
from PIL import Image
@@ -39,7 +39,7 @@ def quote(text):
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
- is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets])
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 363d449d..21945235 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -1,6 +1,7 @@
import os
import tempfile
from collections import namedtuple
+from pathlib import Path
import gradio as gr
@@ -12,10 +13,28 @@ from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
+def register_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'):
+ return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+ if hasattr(gradio, 'temp_dirs'):
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+ return False
+
+
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
- shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)}
+ register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as)
return file_obj
@@ -45,7 +64,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True)
- shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)}
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr():
--
cgit v1.2.3
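
check_tmp_file() above handles gradio 3.9 by testing directory containment with pathlib: a file is allowed if any registered temp dir appears among the resolved parents of its path. A standalone sketch of that check (paths illustrative):

    from pathlib import Path

    temp_dirs = {"/tmp/gradio"}
    filename = "/tmp/gradio/abc/image.png"

    # resolve() normalizes symlinks and "..", so "/tmp/gradio/../x" is rejected
    allowed = any(Path(d).resolve() in Path(filename).resolve().parents
                  for d in temp_dirs)
    print(allowed)  # True
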
From bddebe09edeb6a18f2c06986d5658a7be3a563ea Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Tue, 3 Jan 2023 10:26:37 +0100
Subject: Save Optimizer next to TI embedding
Also add a check to load only .PT and .BIN files as embeddings, since we add .optim files in the same directory.
---
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 40 ++++++++++++++++++++------
2 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..c541d18c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..16176e90 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
--
cgit v1.2.3
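
The round-trip added above boils down to: store the optimizer state next to the embedding as '<name>.pt.optim' together with the embedding's checksum, and on resume load it only while the hash still matches. A minimal sketch of that logic outside the webui (names illustrative):

    import os
    import torch

    def save_optim(optimizer, checksum, filename):
        torch.save({'hash': checksum,
                    'optimizer_state_dict': optimizer.state_dict()},
                   filename + '.optim')

    def try_load_optim(optimizer, checksum, filename):
        if not os.path.exists(filename + '.optim'):
            return False
        saved = torch.load(filename + '.optim', map_location='cpu')
        if saved.get('hash') != checksum:
            return False  # the embedding changed since the state was saved
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        return True
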
From e9fb9bb0c25f59109a816fc53c385bed58965c24 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 17:40:20 +0300
Subject: fix hires fix not working in API when user does not specify upscaler
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index 4654570c..a172af0b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -685,7 +685,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
--
cgit v1.2.3
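
The bug: with no upscaler given, the old expression returned shared.latent_upscale_default_mode itself, i.e. the mode's *name*, where the caller expects the mode's parameter dict. The fix looks the default name up in the table. A toy illustration (the table values here are illustrative):

    latent_upscale_modes = {"Latent": {"mode": "bilinear"},
                            "Latent (nearest)": {"mode": "nearest"}}
    latent_upscale_default_mode = "Latent"

    hr_upscaler = None  # user did not specify an upscaler via the API
    # old: latent_scale_mode = latent_upscale_default_mode  -> the string "Latent"
    latent_scale_mode = (latent_upscale_modes.get(hr_upscaler, None)
                         if hr_upscaler is not None
                         else latent_upscale_modes.get(latent_upscale_default_mode, "nearest"))
    print(latent_scale_mode)  # {'mode': 'bilinear'}
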
From aaa4c2aacbb6523077334093c81bd475d757f7a1 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 09:45:16 -0500
Subject: add api logging
---
modules/api/api.py | 24 +++++++++++++++++++++++-
modules/shared.py | 1 +
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 9c670f00..53135470 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,11 +1,12 @@
import base64
import io
import time
+import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
@@ -67,6 +68,26 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
+def init_api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ if shared.cmd_opts.api_log:
+ print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ p = req.scope.get('path', 'err'),
+ duration = duration,
+ ))
+ return res
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -78,6 +99,7 @@ class Api:
self.router = APIRouter()
self.app = app
+ init_api_middleware(self.app)
self.queue_lock = queue_lock
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..2a03d716 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -82,6 +82,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
--
cgit v1.2.3
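
Since the middleware stamps every response, the timing can be verified from any client even without --api-log. The header name comes from the patch; the endpoint and port below are the webui defaults and assumed here:

    import requests

    res = requests.get("http://127.0.0.1:7860/sdapi/v1/samplers")
    print(res.headers.get("X-Process-Time"))  # e.g. "0.0042" (seconds)
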
From 1d9dc48efda2e8da6d13fc62e65500198a9b041c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 10:21:51 -0500
Subject: init job and add info to model merge
---
modules/extras.py | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 5e270250..7e222313 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -242,6 +242,9 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_info:
+ shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+ shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key]
b = theta_1[key]
+ shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latent space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
- assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
-
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
@@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
print("Checkpoint saved.")
+ shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ shared.state.end()
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
--
cgit v1.2.3
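
The merger now follows the shared.state job lifecycle: begin(), set job/textinfo while working, end() on every exit path. A minimal sketch of that contract (assumes the webui's modules.shared is importable; the try/finally is an illustration, the patch calls end() explicitly on each path):

    from modules import shared

    def run_job():
        shared.state.begin()
        shared.state.job = 'my-job'
        try:
            shared.state.textinfo = "Working..."
            # ... do the actual work, updating textinfo as it progresses ...
            shared.state.textinfo = "Done."
        finally:
            shared.state.end()
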
From 192ddc04d6de0d780f73aa5fbaa8c66cd4642e1c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 10:34:51 -0500
Subject: add job info to modules
---
modules/extras.py | 17 +++++++++++++----
modules/hypernetworks/hypernetwork.py | 1 +
modules/textual_inversion/preprocess.py | 1 +
modules/textual_inversion/textual_inversion.py | 1 +
4 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 7e222313..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -58,6 +58,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +97,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +108,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +119,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +186,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -193,6 +202,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
basename = ''
+ if opts.enable_pnginfo: # append info before save
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +216,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
- if opts.enable_pnginfo:
- image.info = existing_pnginfo
- image.info["extras"] = info
-
if extras_mode != 2 or show_extras_results:
outputs.append(image)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 109e8078..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -417,6 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..2c1251d6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -245,6 +245,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
--
cgit v1.2.3
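The commit above consistently stamps shared.state.job with a short identifier at each long-running entry point (extras, gfpgan, codeformer, upscale, preprocessing, training) so the UI and progress endpoints can report what is running. A hedged sketch of the call pattern; the try/finally is an addition of this sketch, not something the patch itself uses:

    def run_extras_like_job(state):
        state.begin()
        state.job = 'extras'              # coarse label for the whole call
        try:
            state.job = 'extras-upscale'  # refine the label per sub-step
            state.textinfo = 'Processing image example.png'
            # ... do the actual work ...
        finally:
            state.end()                   # guarantees the marker is cleared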
From 2d5a5076bb2a0c05cc27d75a1bcadab7f32a46d0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 18:38:21 +0300
Subject: Make it so that upscalers are not repeated when restarting UI.
---
modules/modelloader.py | 20 ++++++++++++++++++++
webui.py | 14 +++++++-------
2 files changed, 27 insertions(+), 7 deletions(-)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
diff --git a/webui.py b/webui.py
index 3aee8792..c7d55a97 100644
--- a/webui.py
+++ b/webui.py
@@ -1,4 +1,5 @@
import os
+import sys
import threading
import time
import importlib
@@ -55,8 +56,8 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+ modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
-
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
@@ -169,23 +170,22 @@ def webui():
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
+ print('Restarting UI...')
sd_samplers.set_samplers()
- print('Reloading extensions')
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
- print('Reloading custom scripts')
+ modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modelloader.load_upscalers()
- print('Reloading modules: modules.ui')
- importlib.reload(modules.ui)
- print('Refreshing Model List')
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+
modules.sd_models.list_models()
- print('Restarting Gradio')
if __name__ == "__main__":
--
cgit v1.2.3
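The duplicate-upscaler fix leans on a Python detail: Upscaler.__subclasses__() returns every live subclass ever defined, so re-executing extension scripts on a UI restart registers the same upscalers again. Recording the builtin set once and skipping later arrivals avoids that. A self-contained sketch of the mechanism:

    class Upscaler:
        pass

    class Lanczos(Upscaler):        # defined once at startup: a builtin
        pass

    builtin_classes = list(Upscaler.__subclasses__())

    class ExtUpscaler(Upscaler):    # defined again when scripts are re-executed
        pass

    forbidden = {c for c in Upscaler.__subclasses__() if c not in builtin_classes}
    active = [c for c in Upscaler.__subclasses__() if c not in forbidden]
    print([c.__name__ for c in active])  # ['Lanczos']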
From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 18:39:14 +0300
Subject: call script callbacks for reloaded model after loading embeddings
---
modules/sd_models.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index bff8d6c9..b98b05fc 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -324,12 +324,12 @@ def load_model(checkpoint_info=None):
sd_model.eval()
shared.sd_model = sd_model
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+
script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.")
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model
-
return sd_model
--
cgit v1.2.3
From cec209981ee988536c2521297baf9bc1b256005f Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 10:58:52 -0500
Subject: log only sdapi
---
modules/api/api.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 53135470..78751c57 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -68,22 +68,23 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
-def init_api_middleware(app: FastAPI):
+def api_middleware(app: FastAPI):
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
- if shared.cmd_opts.api_log:
- print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format(
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code = res.status_code,
ver = req.scope.get('http_version', '0.0'),
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
prot = req.scope.get('scheme', 'err'),
method = req.scope.get('method', 'err'),
- p = req.scope.get('path', 'err'),
+ endpoint = endpoint,
duration = duration,
))
return res
--
cgit v1.2.3
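The logging change above filters on the request path inside a FastAPI HTTP middleware so that only /sdapi calls are logged, not gradio's internal traffic. A reduced, runnable sketch of the same middleware shape, with the log fields simplified:

    import time
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res = await call_next(req)
        res.headers["X-Process-Time"] = str(round(time.time() - ts, 4))
        endpoint = req.scope.get('path', 'err')
        if endpoint.startswith('/sdapi'):          # skip non-API traffic
            print(f"API {res.status_code} {req.method} {endpoint}")
        return res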
From d8d206c1685d1e7027d4af82ed18d106f41d1cc4 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Tue, 3 Jan 2023 11:01:04 -0500
Subject: add state to interrogate
---
modules/interrogate.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 6f761c5a..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -136,7 +136,8 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
-
+ shared.state.begin()
+ shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
res += ""
self.unload()
+ shared.state.end()
return res
--
cgit v1.2.3
From 82cfc227d735c140447d5b8dca29a71ee9bde127 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 20:23:17 +0300
Subject: added licenses screen to settings, added footer, removed unused
inpainting code
---
README.md | 2 +
html/footer.html | 9 +
html/licenses.html | 392 ++++++++++++++++++++++++++++++++++++++++
modules/sd_hijack_inpainting.py | 232 ------------------------
modules/ui.py | 15 +-
style.css | 11 ++
6 files changed, 427 insertions(+), 234 deletions(-)
create mode 100644 html/footer.html
create mode 100644 html/licenses.html
diff --git a/README.md b/README.md
index 556000fb..88250a6b 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
+
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
diff --git a/html/footer.html b/html/footer.html
new file mode 100644
index 00000000..a8f2adf7
--- /dev/null
+++ b/html/footer.html
@@ -0,0 +1,9 @@
+
diff --git a/html/licenses.html b/html/licenses.html
new file mode 100644
--- /dev/null
+++ b/html/licenses.html
@@ -0,0 +1,392 @@
+Parts of CodeFormer code had to be copied to be compatible with GFPGAN.
+
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
+
+Code for architecture and reading models copied.
+
+MIT License
+
+Copyright (c) 2021 victorca25
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Some code for compatibility with OSX is taken from lstein's repository.
+
+MIT License
+
+Copyright (c) 2022 InvokeAI Team
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Code added by contributors, most likely copied from this repository.
+
+MIT License
+
+Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Some small amounts of code borrowed and reworked.
+
+MIT License
+
+Copyright (c) 2022 pharmapsychotic
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Code added by contributors, most likely copied from this repository.
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2021] [SwinIR Authors]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 06b75772..3c214a35 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -12,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
-# =================================================================================================
-# Monkey patch DDIMSampler methods from RunwayML repo directly.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
-# =================================================================================================
-@torch.no_grad()
-def sample_ddim(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
- samples, intermediates = self.ddim_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
-@torch.no_grad()
-def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
-
-# =================================================================================================
-# Monkey patch PLMSSampler methods.
-# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
-# Adapted from:
-# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
-# =================================================================================================
-@torch.no_grad()
-def sample_plms(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- # print(f'Data shape for PLMS sampling is {size}') # remove unnecessary message
-
- samples, intermediates = self.plms_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -280,44 +95,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t
-# =================================================================================================
-# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
-# =================================================================================================
-
-@torch.no_grad()
-def get_unconditional_conditioning(self, batch_size, null_label=None):
- if null_label is not None:
- xc = null_label
- if isinstance(xc, ListConfig):
- xc = list(xc)
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- if hasattr(xc, "to"):
- xc = xc.to(self.device)
- c = self.get_learned_conditioning(xc)
- else:
- # todo: get null label from cond_stage_model
- raise NotImplementedError()
- c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
- return c
-
-
-class LatentInpaintDiffusion(LatentDiffusion):
- def __init__(
- self,
- concat_keys=("mask", "masked_image"),
- masked_image_key="masked_image",
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- self.masked_image_key = masked_image_key
- assert self.masked_image_key in concat_keys
- self.concat_keys = concat_keys
-
def should_hijack_inpainting(checkpoint_info):
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
@@ -326,15 +103,6 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack():
- # most of this stuff seems to no longer be needed because it is already included into SD2.0
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
- # this file should be cleaned up later if everything turns out to work fine
-
- # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
- # ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
-
- # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
- # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
diff --git a/modules/ui.py b/modules/ui.py
index f2e7c0d6..d941cb5f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1529,8 +1529,10 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- restart_gradio = gr.Button(value='Restart UI', variant='primary', elem_id="settings_restart_gradio")
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
result = gr.HTML(elem_id="settings_result")
@@ -1574,6 +1576,11 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ if os.path.exists("html/licenses.html"):
+ with open("html/licenses.html", encoding="utf8") as file:
+ with gr.TabItem("Licenses"):
+ gr.HTML(file.read(), elem_id="licenses")
+
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click(
@@ -1659,6 +1666,10 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ if os.path.exists("html/footer.html"):
+ with open("html/footer.html", encoding="utf8") as file:
+ gr.HTML(file.read(), elem_id="footer")
+
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
diff --git a/style.css b/style.css
index 7296ce91..2116ec3c 100644
--- a/style.css
+++ b/style.css
@@ -616,6 +616,17 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
padding-bottom: 0.5em;
}
+footer {
+ display: none !important;
+}
+
+#footer{
+ text-align: center;
+}
+
+#footer div{
+ display: inline-block;
+}
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
--
cgit v1.2.3
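The licenses screen itself is simple: a static HTML file rendered into a settings tab, with the same read-and-render pattern reused for the footer. A cut-down sketch of the gradio usage, assuming a 3.x-era Blocks app:

    import os
    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Tabs():
            if os.path.exists("html/licenses.html"):
                with gr.TabItem("Licenses"):
                    with open("html/licenses.html", encoding="utf8") as file:
                        gr.HTML(file.read(), elem_id="licenses")

    # demo.launch()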
From 7c89f3718f9f078113833a88a86f02d3205855b4 Mon Sep 17 00:00:00 2001
From: MMaker
Date: Tue, 3 Jan 2023 12:46:48 -0500
Subject: Add image paste fallback
Fixes image pasting in Firefox
(and possibly other browsers)
---
javascript/dragdrop.js | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index 3ed1cb3c..fe008924 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
return;
}
+ const tmpFile = files[0];
+
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]');
if ( fileInput ) {
- fileInput.files = files;
+ if ( files.length === 0 ) {
+ files = new DataTransfer();
+ files.items.add(tmpFile);
+ fileInput.files = files.files;
+ } else {
+ fileInput.files = files;
+ }
fileInput.dispatchEvent(new Event('change'));
}
};
--
cgit v1.2.3
From 3e22e294135ed0327ce9d9738655ff03c53df3c0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 3 Jan 2023 21:49:24 +0300
Subject: fix broken send to extras button
---
modules/generation_parameters_copypaste.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d94f11a3..4baf4d9a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -37,7 +37,10 @@ def quote(text):
def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"]
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
--
cgit v1.2.3
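The fix works because a gradio gallery can hand back either a single file dict or a list of them; normalizing the payload before the is_file check covers both. The normalization step in isolation (the helper name is illustrative):

    def normalize_filedata(filedata):
        # galleries may return [{"name": ..., "is_file": True}, ...]
        if isinstance(filedata, list) and len(filedata) > 0 \
                and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
            filedata = filedata[0]  # use the first (selected) image
        return filedata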
From 917b5bd8d0cd47c9dc241c1852ccd440a8c61668 Mon Sep 17 00:00:00 2001
From: Max Weber
Date: Tue, 3 Jan 2023 18:19:56 -0700
Subject: ui: save dropdown sampling method to the ui-config
---
modules/ui.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/ui.py b/modules/ui.py
index d941cb5f..bfc93634 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -635,6 +635,7 @@ def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_index.save_to_config = True
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
--
cgit v1.2.3
From 4fc81542077af73610279ad7b6b26e38718a0f81 Mon Sep 17 00:00:00 2001
From: Gerschel
Date: Tue, 3 Jan 2023 23:25:34 -0800
Subject: better targeting; class "tabs" is auto-assigned
I moved a preset manager into quicksettings, and this function was
targeting my component instead of the tabs. The class "tabs" is
auto-assigned while the element id #tabs is not, so selecting by id
allows a tabbed component to live in the quicksettings.
---
script.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/script.js b/script.js
index 9748ec90..0e117d06 100644
--- a/script.js
+++ b/script.js
@@ -4,7 +4,7 @@ function gradioApp() {
}
function get_uiCurrentTab() {
- return gradioApp().querySelector('.tabs button:not(.border-transparent)')
+ return gradioApp().querySelector('#tabs button:not(.border-transparent)')
}
function get_uiCurrentTabContent() {
--
cgit v1.2.3
From e5b7ee910e7bb88f08e8876b5732cb034c6fe529 Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 4 Jan 2023 04:22:01 -0500
Subject: fix: Save full res of intermediate step
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index a172af0b..93e75ba6 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -705,7 +705,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
--
cgit v1.2.3
From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 12:35:07 +0300
Subject: helpful error message when trying to load 2.0 without config;
failing to load model weights from settings won't break generation for the
currently loaded model anymore
---
modules/errors.py | 25 +++++++++++++++++++++++--
modules/sd_models.py | 26 ++++++++++++++++++--------
modules/shared.py | 9 +++++++--
webui.py | 12 ++++++++++--
4 files changed, 58 insertions(+), 14 deletions(-)
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is that you are trying to load a Stable Diffusion 2.0 model without specifying its config file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+    display(e, task)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..6846b74a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -278,6 +278,7 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
@@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
print("Weights loaded.")
+
return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..7588c47b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
@@ -494,7 +494,12 @@ class Options:
return False
if self.data_labels[key].onchange is not None:
- self.data_labels[key].onchange()
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
return True
diff --git a/webui.py b/webui.py
index c7d55a97..13375e71 100644
--- a/webui.py
+++ b/webui.py
@@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from modules import import_hook
+from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
@@ -61,7 +61,15 @@ def initialize():
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
- modules.sd_models.load_model()
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
--
cgit v1.2.3
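Two defensive patterns land in the commit above: a reusable errors.display(e, task) helper, and settings callbacks that roll the option back to its previous value when onchange raises, so a bad checkpoint choice no longer leaves the UI in a broken state. A condensed sketch of the rollback; the data dict and onchange_handlers mapping are stand-ins assumed by this sketch:

    import sys
    import traceback

    def display(e: Exception, task):
        print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def set_option(data, onchange_handlers, key, value):
        oldval = data.get(key)
        data[key] = value
        handler = onchange_handlers.get(key)
        if handler is not None:
            try:
                handler()
            except Exception as e:
                display(e, f"changing setting {key} to {value}")
                data[key] = oldval          # restore the previous value
                return False
        return True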
From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 12:47:42 +0300
Subject: find configs for models at runtime rather than when starting
---
modules/sd_hijack_inpainting.py | 5 ++++-
modules/sd_models.py | 31 ++++++++++++++++++-------------
2 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 3c214a35..31d2c898 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def should_hijack_inpainting(checkpoint_info):
+ from modules import sd_models
+
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(checkpoint_info.config).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6846b74a..6dca4ddf 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
@@ -48,6 +48,14 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def find_checkpoint_config(info):
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return shared.cmd_opts.config
+
+
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- basename, _ = os.path.splitext(filename)
- config = basename + ".yaml"
- if not os.path.exists(config):
- config = shared.cmd_opts.config
-
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString):
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ checkpoint_config = find_checkpoint_config(checkpoint_info)
- if checkpoint_info.config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_info.config}")
+ if checkpoint_config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_info.config)
+ sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.finetune_keys = None
# Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
sd_model = shared.sd_model
current_checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_config = find_checkpoint_config(current_checkpoint_info)
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
--
cgit v1.2.3
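The runtime lookup replaces the config field that used to be baked into CheckpointInfo at scan time: a .yaml sitting next to the checkpoint wins, otherwise the global --config default applies. The rule in isolation; DEFAULT_CONFIG stands in for shared.cmd_opts.config:

    import os

    DEFAULT_CONFIG = "configs/v1-inference.yaml"  # stand-in for shared.cmd_opts.config

    def find_checkpoint_config(checkpoint_filename):
        config = os.path.splitext(checkpoint_filename)[0] + ".yaml"
        if os.path.exists(config):
            return config                 # per-model config beside the weights
        return DEFAULT_CONFIG             # fall back to the global default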
From 96cf15bedecbed97ef9b70b8413d543a9aee5adf Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 4 Jan 2023 05:12:06 -0500
Subject: Add new latent upscale modes
---
modules/shared.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 7588c47b..a10f69a9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -564,8 +564,11 @@ if os.path.exists(config_filename):
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
- "Latent": "bilinear",
- "Latent (nearest)": "nearest",
+ "Latent": {"mode": "bilinear", "antialias": False},
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+ "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
}
sd_upscalers = []
--
cgit v1.2.3
From 15fd0b8bc4734ea85bca1acfb12b51465ab9817d Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 4 Jan 2023 05:12:54 -0500
Subject: Update processing.py
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index a172af0b..7c72b56a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -713,7 +713,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
--
cgit v1.2.3
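Together, these two patches make each latent upscale option carry both an interpolation mode and an antialias flag, and forward them to torch.nn.functional.interpolate (which accepts antialias for bilinear and bicubic modes since PyTorch 1.11). A small sketch of the lookup, with an illustrative latent shape:

    import torch
    import torch.nn.functional as F

    latent_upscale_modes = {
        "Latent": {"mode": "bilinear", "antialias": False},
        "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
    }

    samples = torch.randn(1, 4, 64, 64)   # latent for a 512x512 image
    opt = latent_upscale_modes["Latent (antialiased)"]
    upscaled = F.interpolate(samples, size=(128, 128),
                             mode=opt["mode"], antialias=opt["antialias"])
    print(upscaled.shape)                 # torch.Size([1, 4, 128, 128])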
From 4ec6470a1a2d9430b91266426f995e48f59564e1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 13:26:23 +0300
Subject: fix checkpoint list API
---
modules/api/api.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 9c670f00..2b1f180c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -303,7 +303,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
--
cgit v1.2.3
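With the config no longer stored on CheckpointInfo, the models endpoint computes it per request. Assuming the route for get_sd_models is registered as /sdapi/v1/sd-models (the registration is not shown in this patch), and with purely illustrative values, a response entry looks roughly like:

    >>> api.get_sd_models()
    [{'title': 'v1-5-pruned-emaonly.ckpt [81761151]',
      'model_name': 'v1-5-pruned-emaonly',
      'hash': '81761151',
      'filename': 'models/Stable-diffusion/v1-5-pruned-emaonly.ckpt',
      'config': 'configs/v1-inference.yaml'}]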
From b2151b934fe0a3613570c6abd7615d3788fd1c8f Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 4 Jan 2023 05:36:18 -0500
Subject: Rename bicubic antialiased option
A comma was causing the value in PNG info to be quoted, which caused the upscaler dropdown option to be blank when sending to UI
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/shared.py b/modules/shared.py
index a10f69a9..c1b20081 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -567,7 +567,7 @@ latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
- "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
"Latent (nearest)": {"mode": "nearest", "antialias": False},
}
--
cgit v1.2.3
From 3bd737767b071878ea980e94b8705f603bcf545e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 14:20:32 +0300
Subject: disable broken API logging
---
modules/api/api.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index a6c1d6ed..6267afdc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -100,7 +100,6 @@ class Api:
self.router = APIRouter()
self.app = app
- init_api_middleware(self.app)
self.queue_lock = queue_lock
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
--
cgit v1.2.3
From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 14:29:13 +0300
Subject: fix broken inpainting model
---
modules/sd_models.py | 3 ---
1 file changed, 3 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6dca4ddf..a568823d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -305,9 +305,6 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
- # Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
-
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
--
cgit v1.2.3
From 11b8160a086c434d5baf4971edda46e6d2126800 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 4 Jan 2023 06:36:57 -0500
Subject: fix typo
---
modules/api/api.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/api/api.py b/modules/api/api.py
index 6267afdc..48a70a44 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -101,6 +101,7 @@ class Api:
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
+ api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
--
cgit v1.2.3
From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 15:09:53 +0300
Subject: use commandline-supplied cuda device name instead of cuda:0 for
safetensors PR that doesn't fix anything
---
modules/sd_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ee918f24..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location
if device is None:
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
--
cgit v1.2.3
From 21ee77db314ede7ccbb18787962347c09a4df0c7 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 4 Jan 2023 08:04:38 -0500
Subject: add cross-attention info
---
modules/sd_hijack.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index edcbaf52..fa2cd4bb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -35,26 +35,35 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
+ optimization_method = None
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ optimization_method = 'xformers'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'V1'
else:
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+ optimization_method = 'Doggettx'
+
+ return optimization_method
def undo_optimizations():
@@ -75,6 +84,7 @@ class StableDiffusionModelHijack:
layers = None
circular_enabled = False
clip = None
+ optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
@@ -94,7 +104,7 @@ class StableDiffusionModelHijack:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
- apply_optimizations()
+ self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
--
cgit v1.2.3
From 1cfd8aec4ae5a6ca1afd67b44cb4ef6dd14d8c34 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 16:05:42 +0300
Subject: make it possible to work with opts.show_progress_every_n_steps = -1
with medvram
---
modules/shared.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 4fcc6edd..54a6ba23 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -214,12 +214,13 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image()
def do_set_current_image(self):
- if not parallel_processing_allowed:
- return
if self.current_latent is None:
return
@@ -231,6 +232,7 @@ class State:
self.current_image_sampling_step = self.sampling_step
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
--
cgit v1.2.3
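The effect of moving the guard up one level: the periodic preview path stays disabled when parallel processing is not allowed (as under --medvram), while direct calls to do_set_current_image() - which appears to be how the "show after completion of batch" setting (-1) produces previews - are no longer blocked. A toy mirror of the new guard placement:

    class State:
        parallel_processing_allowed = False   # e.g. running with --medvram

        def set_current_image(self):
            # periodic path: still gated
            if not self.parallel_processing_allowed:
                return
            self.do_set_current_image()

        def do_set_current_image(self):
            # direct path: no longer gated
            print("preview updated")

    State().do_set_current_image()   # works even when the flag is False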
From 79c682ad4f2d982b26fa1a15044582d1005134f9 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 4 Jan 2023 08:20:42 -0500
Subject: fix jpeg
---
modules/extras.py | 2 --
modules/images.py | 2 ++
requirements_versions.txt | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index d665440a..7407bfe3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -19,8 +19,6 @@ from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
-import piexif
-import piexif.helper
import gradio as gr
import safetensors.torch
diff --git a/modules/images.py b/modules/images.py
index c3a5fc8b..a73be3fa 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -22,6 +22,8 @@ from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+Image.init() # initialize once all known file format handlers
+
def image_grid(imgs, batch_size=1, rows=None):
if rows is None:
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 975102d9..7ae118cb 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -5,7 +5,7 @@ basicsr==1.4.2
gfpgan==1.3.8
gradio==3.15.0
numpy==1.23.3
-Pillow==9.2.0
+Pillow==9.3.0
realesrgan==0.3.0
torch
omegaconf==2.2.3
--
cgit v1.2.3
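Pillow registers its format plugins lazily; an explicit Image.init() eagerly loads every known plugin so that lookups keyed by file extension are fully populated before any save happens. Standard Pillow behavior, for illustration:

    from PIL import Image

    Image.init()                     # load all known format plugins
    print(Image.EXTENSION[".jpg"])   # 'JPEG'
    print(len(Image.SAVE) > 0)       # True - save handlers now registered

Note that a later patch in this series removes this call again, relying on a Pillow upgrade instead.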
From 4d66bf2c0d27702cc83b9cc57ebb1f359d18d938 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:24:46 +0300
Subject: add infotext to "-before-highres-fix" images
---
modules/processing.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index fd7c7015..c03e77e7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -136,6 +136,7 @@ class StableDiffusionProcessing():
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
+ self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
@@ -544,6 +545,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
@@ -707,7 +710,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index, approximation=0)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
--
cgit v1.2.3
From 184e670126f5fc50ba56fa0fedcf0cf60e45ed7e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:45:01 +0300
Subject: fix the merge
---
modules/textual_inversion/textual_inversion.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5421a758..8731ea5d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -251,6 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+
def create_dummy_mask(x, width=None, height=None):
if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
@@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
break
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
-
-
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
-
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3
From 590c5ae016ae494f4873ca20079b30684ea3060c Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 4 Jan 2023 09:48:54 -0500
Subject: update pillow
---
modules/images.py | 2 --
requirements_versions.txt | 2 +-
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/modules/images.py b/modules/images.py
index a73be3fa..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -22,8 +22,6 @@ from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
-Image.init() # initialize once all known file format handlers
-
def image_grid(imgs, batch_size=1, rows=None):
if rows is None:
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7ae118cb..d2899292 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -5,7 +5,7 @@ basicsr==1.4.2
gfpgan==1.3.8
gradio==3.15.0
numpy==1.23.3
-Pillow==9.3.0
+Pillow==9.4.0
realesrgan==0.3.0
torch
omegaconf==2.2.3
--
cgit v1.2.3
From 525cea924562afd676f55470095268a0f6fca59e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 17:58:07 +0300
Subject: use shared function from processing for creating dummy mask when
training inpainting model
---
modules/processing.py | 39 +++++++++++++-------------
modules/textual_inversion/textual_inversion.py | 33 ++++++----------------
2 files changed, 29 insertions(+), 43 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index c03e77e7..c7264aff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+    # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -139,26 +157,9 @@ class StableDiffusionProcessing():
self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8731ea5d..2250e41b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def create_dummy_mask(x, width=None, height=None):
- if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
-
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
print("No saved optimizer exists in checkpoint")
-
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = ""
embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- if img_c is None:
- img_c = create_dummy_mask(c, training_width, training_height)
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
--
cgit v1.2.3
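The shape arithmetic in the shared txt2img_image_conditioning() is the crux: for non-inpainting models it returns a (B, 5, 1, 1) zero tensor that only contributes its batch size, while for inpainting models it encodes an all-zero "masked image" into latent space and prepends a full-ones mask channel. A worked check of the padding step, assuming the usual 8x latent downscaling so 512x512 encodes to 64x64:

    import torch
    import torch.nn.functional as F

    latent = torch.zeros(2, 4, 64, 64)  # stand-in for the encoded all-zero masked image
    # pad tuple is (W_left, W_right, H_top, H_bottom, C_front, C_back):
    cond = F.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
    print(cond.shape)         # torch.Size([2, 5, 64, 64])
    print(cond[:, 0].min())   # tensor(1.) - the prepended mask channel is all ones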
From a8eb9e3bf814f72293e474c11e9ff0098859a942 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 18:20:38 +0300
Subject: Revert "Merge pull request #3791 from shirayu/fix/filename"
This reverts commit eed58279e7cb0e873ebd88a29609f9bab0f1f3af, reversing
changes made to 4ae960b01c6711c66985479f14809dc7fa549fc2.
---
modules/images.py | 16 ++++------------
1 file changed, 4 insertions(+), 12 deletions(-)
diff --git a/modules/images.py b/modules/images.py
index 2967fa9a..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -447,14 +447,6 @@ def get_next_sequence_number(path, basename):
return result + 1
-def truncate_fullpath(full_path, encoding='utf-8'):
- dir_name, full_name = os.path.split(full_path)
- file_name, file_ext = os.path.splitext(full_name)
- max_length = os.statvfs(dir_name).f_namemax
- file_name_truncated = file_name.encode(encoding)[:max_length - len(file_ext)].decode(encoding, 'ignore')
- return os.path.join(dir_name , file_name_truncated + file_ext)
-
-
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
"""Save an image.
@@ -495,7 +487,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if save_to_dirs:
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
- path = truncate_fullpath(os.path.join(path, dirname))
+ path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
@@ -519,13 +511,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = truncate_fullpath(os.path.join(path, f"{fn}{file_decoration}.{extension}"))
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
if not os.path.exists(fullfn):
break
else:
- fullfn = truncate_fullpath(os.path.join(path, f"{file_decoration}.{extension}"))
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
else:
- fullfn = truncate_fullpath(os.path.join(path, f"{forced_filename}.{extension}"))
+ fullfn = os.path.join(path, f"{forced_filename}.{extension}")
pnginfo = existing_info or {}
if info is not None:
--
cgit v1.2.3
From 3dae545a03f5102ba5d9c3f27bb6241824c5a916 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 18:42:51 +0300
Subject: rename weirdly named variables from #3176
---
modules/ui.py | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index e4859020..184af7ad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -162,16 +162,14 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
-def calc_time_left(progress, threshold, label, force_display, showTime):
+def calc_time_left(progress, threshold, label, force_display, show_eta):
if progress == 0:
return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and showTime) or force_display:
+ if (eta_relative > threshold and show_eta) or force_display:
if eta_relative > 3600:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
elif eta_relative > 60:
@@ -194,9 +192,9 @@ def check_progress_call(id_part):
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
# Show progress percentage and time left at the same moment, and base it also on steps done
- showPBText = progress >= 0.01 or shared.state.sampling_step >= 10
+ show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
- time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display, showPBText )
+ time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
if time_left != "":
shared.state.time_left_force_display = True
@@ -204,7 +202,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
-            progressbar = f"""
--
cgit v1.2.3
+The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.
+
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b416e9ac..cdc63ed7 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 95924d24..fea7aaac 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -1,7 +1,7 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
-# unspecified
+# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
--
cgit v1.2.3
From 82c1f10b144f733460feead0bdc37a861489dc57 Mon Sep 17 00:00:00 2001
From: Dean Hopkins
Date: Fri, 6 Jan 2023 22:00:12 +0000
Subject: increase upscale api validation limit
---
modules/api/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/api/models.py b/modules/api/models.py
index f77951fc..22b88c59 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
- upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
--
cgit v1.2.3
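The ge/le bounds on a pydantic Field are enforced when the request body is validated, so raising le from 4 to 8 is all it takes for 8x requests to stop failing validation. A minimal illustration of the mechanism (pydantic v1 style, as FastAPI used at the time):

    from pydantic import BaseModel, Field, ValidationError

    class ExtrasRequest(BaseModel):
        upscaling_resize: float = Field(default=2, ge=1, le=8)

    print(ExtrasRequest(upscaling_resize=8).upscaling_resize)  # 8.0 - accepted now

    try:
        ExtrasRequest(upscaling_resize=16)
    except ValidationError:
        print("rejected: still above the le=8 bound")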
From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 7 Jan 2023 01:45:28 +0300
Subject: CLIP hijack rework
---
modules/sd_hijack.py | 6 +-
modules/sd_hijack_clip.py | 348 ++++++++++++-------------
modules/sd_hijack_clip_old.py | 81 ++++++
modules/textual_inversion/textual_inversion.py | 1 -
modules/ui.py | 2 +-
5 files changed, 256 insertions(+), 182 deletions(-)
create mode 100644 modules/sd_hijack_clip_old.py
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa2cd4bb..71cc145a 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
- def tokenize(self, text):
- _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+ def get_prompt_lengths(self, text):
+ _, token_count = self.clip.process_texts([text])
- return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index ca92b142..ac3020d7 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,12 +1,28 @@
import math
+from collections import namedtuple
import torch
from modules import prompt_parser, devices
from modules.shared import opts
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+ """
+    This object contains token ids, weights (the multiplier from emphasis syntax such as (word:1.4)) and textual inversion embedding info for a chunk of prompt.
+    If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
+    Each PromptChunk contains exactly 77 tokens, which includes one start token and one end token,
+    so just 75 tokens from the prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
@@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
raise NotImplementedError
def encode_with_transformers(self, tokens):
+ """
+        Converts a batch of token ids (in python lists) into a single tensor with the numeric representation of those tokens;
+        all python lists with tokens are assumed to have the same length, usually 77.
+        If the input is a list with B elements and each element has T tokens, the expected output shape is (B, T, C), where C depends on
+        the model - it can be 768 or 1024.
+ """
+
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+        transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
+
raise NotImplementedError
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ def tokenize_line(self, line):
+ """
+        This transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
- fixes = []
- remade_tokens = []
- multipliers = []
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
last_comma = -1
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
+ last_comma = len(chunk.tokens)
+
+            # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+            # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [self.id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [self.id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+
token_count = 0
cache = {}
- batch_multipliers = []
+ batch_chunks = []
for line in texts:
if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
+ chunks = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
- cache[line] = (remade_tokens, fixes, multipliers)
+ cache[line] = chunks
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
+ batch_chunks.append(chunks)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+ return batch_chunks, token_count
- def process_text_old(self, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
+ def forward(self, texts):
+ """
+        Accepts an array of texts; passes the texts through the transformers network to create a tensor with the numerical representation of those texts.
+        Returns a tensor with shape of (B, T, C), where B is the length of the array; T is the length, in tokens, of the texts (including padding) - T will
+        be a multiple of 77; and C is the dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        An example shape returned by this function can be: (2, 77, 768).
+        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.id_end] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
+ batch_chunks, token_count = self.process_texts(texts)
- return z
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+        Sends a single prompt chunk to be encoded by the transformers neural network.
+        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
+        corresponds to one token.
+ """
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
new file mode 100644
index 00000000..6d9fbbe6
--- /dev/null
+++ b/modules/sd_hijack_clip_old.py
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f9f5e8cd..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -79,7 +79,6 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
diff --git a/modules/ui.py b/modules/ui.py
index b79d24ee..5d2f5bad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -368,7 +368,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"{token_count}/{max_length}"
--
cgit v1.2.3
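The chunking arithmetic above drives the token counter: counts are padded up to the next multiple of chunk_length (75), and each chunk is encoded as 77 tokens (start token + 75 + end token) before the per-chunk tensors are concatenated with torch.hstack. A self-contained check of that arithmetic, mirroring get_target_prompt_token_count from the patch:

    import math

    def get_target_prompt_token_count(token_count, chunk_length=75):
        # same formula as FrozenCLIPEmbedderWithCustomWordsBase.get_target_prompt_token_count
        return math.ceil(max(token_count, 1) / chunk_length) * chunk_length

    for count in (1, 75, 76, 150, 151):
        print(count, "->", get_target_prompt_token_count(count))
    # 1 -> 75, 75 -> 75, 76 -> 150, 150 -> 150, 151 -> 225

So a counter reading like "80/150" means the prompt has spilled into a second chunk.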
From f94cfc563bbedd923d5e95563a5e8d93c8516ac3 Mon Sep 17 00:00:00 2001
From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com>
Date: Sat, 7 Jan 2023 01:15:22 +0100
Subject: Changed HTML to textbox instead
Using HTML caused an issue where the row would expand for a frame when changing the sliders because of the loading animation. This solution also doesn't use any additional HTML padding
---
modules/ui.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index 6fc8b7d7..6ea1b5d7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale):
scaled_x = int(x * scale // 8) * 8
scaled_y = int(y * scale // 8) * 8
- return "
"""
-
- image = gr_show(False)
- preview_visibility = gr_show(False)
-
- if opts.show_progress_every_n_steps != 0:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr_show(True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr_show(False)
-
- return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
-
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
-
-
-def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
- if not enable:
- return ""
-
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
-
- with devices.autocast():
- p.init([""], [0], [0])
-
- return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-
-
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
-
-
-def interrogate(image):
- prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
- return gr_show(True) if prompt is None else prompt
-
-
-def interrogate_deepbooru(image):
- prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
-
-
-def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
- # Components to show/hide based on the 'Extra' checkbox
- seed_extras = []
-
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
- seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
- with FormRow(visible=False) as seed_extra_row_2:
- seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
- def change_visibility(show):
- return {comp: gr_show(show) for comp in seed_extras}
-
- seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
- return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
- """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
- button.click(
- _js="clear_prompt",
- fn=None,
- inputs=[],
- outputs=[],
- )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index):
- res = -1
-
- try:
- gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
-
- return [res, gr_show(False)]
-
- reuse_seed.click(
- fn=copy_seed,
- _js="(x, y) => [x, selected_gallery_index()]",
- show_progress=False,
- inputs=[generation_info, dummy_component],
- outputs=[seed, dummy_component]
- )
-
-
-def update_token_counter(text, steps):
- try:
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
- except Exception:
- # a parsing error can happen here during typing, and we don't want to bother the user with
- # messages related to it in console
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"{token_count}/{max_length}"
-
-
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
-
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
-
- if shared.cmd_opts.freeze_settings:
- return gr.update()
-
- # dont allow model to be swapped when model hash exists in prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
-
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
-
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
-
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return
-
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- opts.save(shared.config_filename)
- return value
-
-
-def update_generation_info(args):
- generation_info, html_info, img_index = args
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
-def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel'):
- with gr.Group():
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
-
-
-def create_sampler_and_steps_selection(choices, tabname):
- if opts.samplers_in_dropdown:
- with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- else:
- with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-
- return steps, sampler_index
-
-
-def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
-
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
- yield category
-
-
-def create_ui():
- import modules.img2img
- import modules.txt2img
-
- reload_javascript()
-
- parameters_copypaste.reset()
-
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
-
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
-
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "cfg":
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
-
- elif category == "hires_fix":
- with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormRow(elem_id="txt2img_hires_fix_row2"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "scripts":
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
-
- hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False,
- )
- input.change(
- None,
- _js="onCalcResolutionHires",
- inputs=hr_resolution_preview_inputs,
- outputs=[],
- show_progress=False,
- )
-
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
- _js="submit",
- inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
-
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ]
- )
-
- enable_hr.change(
- fn=lambda x: gr_show(x),
- inputs=[enable_hr],
- outputs=[hr_options],
- show_progress = False,
- )
-
- txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- *modules.scripts.scripts_txt2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
-
- txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
- steps,
- sampler_index,
- cfg_scale,
- seed,
- width,
- height,
- ]
-
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
-
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
-
- with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
-
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
-
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
-
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
-
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
- hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
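Since the patch body above is unrecoverable, the following is a minimal, hypothetical Gradio 3.x sketch of the technique the commit message describes: compute the upscaled resolution from slider values and display it in a non-interactive gr.Textbox rather than a gr.HTML component, passing show_progress=False to the change events so the loading animation cannot momentarily expand the row. The calc_resolution helper and the demo layout are illustrative assumptions, not lines from the commit.

# Hypothetical, self-contained sketch -- not part of the patch above.
import gradio as gr

def calc_resolution(width, height, scale):
    # Same rounding as the hunk's context lines: snap both axes to multiples of 8.
    scaled_x = int(width * scale // 8) * 8
    scaled_y = int(height * scale // 8) * 8
    # Plain-text result; a Textbox needs no HTML markup or padding.
    return f"{scaled_x}x{scaled_y}"

with gr.Blocks() as demo:
    width = gr.Slider(minimum=64, maximum=2048, step=8, value=512, label="Width")
    height = gr.Slider(minimum=64, maximum=2048, step=8, value=512, label="Height")
    scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, value=2.0, label="Upscale by")

    # interactive=False makes this a read-only display; unlike gr.HTML, its row
    # keeps a fixed height while the value updates.
    final_res = gr.Textbox(value=calc_resolution(512, 512, 2.0), label="Upscaled resolution", interactive=False)

    for comp in (width, height, scale):
        # show_progress=False suppresses the loading animation that made the
        # HTML row expand for a frame on every slider change.
        comp.change(fn=calc_resolution, inputs=[width, height, scale], outputs=[final_res], show_progress=False)

if __name__ == "__main__":
    demo.launch()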
"""
+
+ image = gr_show(False)
+ preview_visibility = gr_show(False)
+
+ if opts.show_progress_every_n_steps != 0:
+ shared.state.set_current_image()
+ image = shared.state.current_image
+
+ if image is None:
+ image = gr.update(value=None)
+ else:
+ preview_visibility = gr_show(True)
+
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
+
+
+def check_progress_call_initial(id_part):
+ shared.state.job_count = -1
+ shared.state.current_latent = None
+ shared.state.current_image = None
+ shared.state.textinfo = None
+ shared.state.time_start = time.time()
+ shared.state.time_left_force_display = False
+
+ return check_progress_call(id_part)
+
+
+def visit(x, func, path=""):
+ if hasattr(x, 'children'):
+ for c in x.children:
+ visit(c, func, path)
+ elif x.label is not None:
+ func(path + "/" + str(x.label), x)
+
+
+def add_style(name: str, prompt: str, negative_prompt: str):
+ if name is None:
+ return [gr_show() for x in range(4)]
+
+ style = modules.styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
+ # reserialize all styles every time we save them
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+
+
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
+
+
+def apply_styles(prompt, prompt_neg, style1_name, style2_name):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+
+
+def interrogate(image):
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
+
+ return gr_show(True) if prompt is None else prompt
+
+
+def interrogate_deepbooru(image):
+ prompt = deepbooru.model.tag(image)
+ return gr_show(True) if prompt is None else prompt
+
+
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+
+ # Components to show/hide based on the 'Extra' checkbox
+ seed_extras = []
+
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+ seed_extras.append(seed_extra_row_1)
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
+ seed_extras.append(seed_extra_row_2)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+ random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+ random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+ def change_visibility(show):
+ return {comp: gr_show(show) for comp in seed_extras}
+
+ seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+ return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+ """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+ (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+ def copy_seed(gen_info_string: str, index):
+ res = -1
+
+ try:
+ gen_info = json.loads(gen_info_string)
+ index -= gen_info.get('index_of_first_image', 0)
+
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+ all_subseeds = gen_info.get('all_subseeds', [-1])
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+ else:
+ all_seeds = gen_info.get('all_seeds', [-1])
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+ except json.decoder.JSONDecodeError as e:
+ if gen_info_string != '':
+ print("Error parsing JSON generation info:", file=sys.stderr)
+ print(gen_info_string, file=sys.stderr)
+
+ return [res, gr_show(False)]
+
+ reuse_seed.click(
+ fn=copy_seed,
+ _js="(x, y) => [x, selected_gallery_index()]",
+ show_progress=False,
+ inputs=[generation_info, dummy_component],
+ outputs=[seed, dummy_component]
+ )
+
+
+def update_token_counter(text, steps):
+ try:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+ style_class = ' class="red"' if (token_count > max_length) else ""
+ return f"{token_count}/{max_length}"
+
+
+def create_toprow(is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+
+ with gr.Row(elem_id="toprow"):
+ with gr.Column(scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Column(scale=1, elem_id="roll_col"):
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+ with gr.Column(scale=1):
+ with gr.Row():
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+
+
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
+ check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
+ check_progress.click(
+ fn=lambda: check_progress_call(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+ check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
+ check_progress_initial.click(
+ fn=lambda: check_progress_call_initial(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
+ # dont allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data.get(key, None)
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
+
+
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
+
+
+def create_ui():
+ import modules.img2img
+ import modules.txt2img
+
+ reload_javascript()
+
+ parameters_copypaste.reset()
+
+ modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
+ dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Row(elem_id='txt2img_progress_row'):
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="txt2img_progressbar")
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
+ setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
+
+ elif category == "hires_fix":
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ for input in hr_resolution_preview_inputs:
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
+
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ txt2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+ _js="submit",
+ inputs=[
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_style,
+ txt2img_prompt_style2,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ denoising_strength,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
+ ] + custom_inputs,
+
+ outputs=[
+ txt2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ txt2img_prompt.submit(**txt2img_args)
+ submit.click(**txt2img_args)
+
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
+ enable_hr.change(
+ fn=lambda x: gr_show(x),
+ inputs=[enable_hr],
+ outputs=[hr_options],
+ show_progress = False,
+ )
+
+ txt2img_paste_fields = [
+ (txt2img_prompt, "Prompt"),
+ (txt2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (enable_hr, lambda d: "Denoising strength" in d),
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
+ token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+
+ with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="img2img_progressbar")
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
+ setup_progressbar(progressbar, img2img_preview, 'img2img')
+
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="img2img_settings"):
+
+ with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+
+ use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ if use_color_sketch:
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
+
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
+ hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
"""
+
+ image = gr_show(False)
+ preview_visibility = gr_show(False)
+
+ if opts.show_progress_every_n_steps != 0:
+ shared.state.set_current_image()
+ image = shared.state.current_image
+
+ if image is None:
+ image = gr.update(value=None)
+ else:
+ preview_visibility = gr_show(True)
+
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
+
+
+def check_progress_call_initial(id_part):
+ shared.state.job_count = -1
+ shared.state.current_latent = None
+ shared.state.current_image = None
+ shared.state.textinfo = None
+ shared.state.time_start = time.time()
+ shared.state.time_left_force_display = False
+
+ return check_progress_call(id_part)
+
+
+def visit(x, func, path=""):
+ if hasattr(x, 'children'):
+ for c in x.children:
+ visit(c, func, path)
+ elif x.label is not None:
+ func(path + "/" + str(x.label), x)
+
+
+def add_style(name: str, prompt: str, negative_prompt: str):
+ if name is None:
+ return [gr_show() for x in range(4)]
+
+ style = modules.styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
+ # reserialize all styles every time we save them
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+
+
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
+
+
+def apply_styles(prompt, prompt_neg, style1_name, style2_name):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+
+
+def interrogate(image):
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
+
+ return gr_show(True) if prompt is None else prompt
+
+
+def interrogate_deepbooru(image):
+ prompt = deepbooru.model.tag(image)
+ return gr_show(True) if prompt is None else prompt
+
+
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+
+ # Components to show/hide based on the 'Extra' checkbox
+ seed_extras = []
+
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+ seed_extras.append(seed_extra_row_1)
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
+ seed_extras.append(seed_extra_row_2)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+ random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+ random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+ def change_visibility(show):
+ return {comp: gr_show(show) for comp in seed_extras}
+
+ seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+ return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
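
`change_visibility` returns a dict keyed by component, which Gradio matches against the listener's `outputs` list. The same pattern in isolation, with strings standing in for the two FormRow containers and `gr_show` sketched as the plain update dict it builds elsewhere in this file:

def gr_show(visible=True):
    # the shape gr.update(visible=visible) produces in Gradio 3.x
    return {"visible": visible, "__type__": "update"}

seed_extras = ["seed_extra_row_1", "seed_extra_row_2"]  # stand-ins for the FormRows

def change_visibility(show):
    return {comp: gr_show(show) for comp in seed_extras}

print(change_visibility(False))
# {'seed_extra_row_1': {'visible': False, '__type__': 'update'},
#  'seed_extra_row_2': {'visible': False, '__type__': 'update'}}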
+
+
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+ """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+ (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+ def copy_seed(gen_info_string: str, index):
+ res = -1
+
+ try:
+ gen_info = json.loads(gen_info_string)
+ index -= gen_info.get('index_of_first_image', 0)
+
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+ all_subseeds = gen_info.get('all_subseeds', [-1])
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+ else:
+ all_seeds = gen_info.get('all_seeds', [-1])
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+ except json.decoder.JSONDecodeError as e:
+ if gen_info_string != '':
+ print("Error parsing JSON generation info:", file=sys.stderr)
+ print(gen_info_string, file=sys.stderr)
+
+ return [res, gr_show(False)]
+
+ reuse_seed.click(
+ fn=copy_seed,
+ _js="(x, y) => [x, selected_gallery_index()]",
+ show_progress=False,
+ inputs=[generation_info, dummy_component],
+ outputs=[seed, dummy_component]
+ )
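
The seed lookup is easiest to verify on a concrete payload. A standalone restatement of `copy_seed` with hypothetical values; the subseed pool is consulted only when `subseed_strength` was above zero:

import json

def pick_seed(gen_info_string, index, is_subseed):
    gen_info = json.loads(gen_info_string)
    index -= gen_info.get('index_of_first_image', 0)
    if is_subseed and gen_info.get('subseed_strength', 0) > 0:
        pool = gen_info.get('all_subseeds', [-1])
    else:
        pool = gen_info.get('all_seeds', [-1])
    # out-of-range gallery indices fall back to the first entry
    return pool[index if 0 <= index < len(pool) else 0]

info = json.dumps({"all_seeds": [111, 222], "all_subseeds": [333, 444],
                   "subseed_strength": 0.5, "index_of_first_image": 0})
print(pick_seed(info, 1, is_subseed=False))  # 222
print(pick_seed(info, 1, is_subseed=True))   # 444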
+
+
+def update_token_counter(text, steps):
+ try:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+ style_class = ' class="red"' if (token_count > max_length) else ""
+ return f"{token_count}/{max_length}"
+
+
+def create_toprow(is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+
+ with gr.Row(elem_id="toprow"):
+ with gr.Column(scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Column(scale=1, elem_id="roll_col"):
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+ with gr.Column(scale=1):
+ with gr.Row():
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+
+
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
+ check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
+ check_progress.click(
+ fn=lambda: check_progress_call(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+ check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
+ check_progress_initial.click(
+ fn=lambda: check_progress_call_initial(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
+ # don't allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data.get(key, None)
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
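
The coercion step converts the incoming value (often a string pasted from infotext) to the type of the option's default before storing it. A minimal sketch, assuming an option whose default is an int:

default = 20            # e.g. a steps-like option with an int default
valtype = type(default)
value = "30"            # pasted values typically arrive as strings
coerced = valtype(value) if valtype != type(None) else value
print(coerced, type(coerced))  # 30 <class 'int'>

One known pitfall of this style of coercion: for a bool default, `bool("False")` is `True`, so it relies on callers not feeding raw strings to boolean options.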
+
+
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
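
The lookup itself is small: generation info is a JSON blob whose `infotexts` list is indexed by the selected gallery image, and the old HTML is returned on any failure. The same selection logic restated without the `plaintext_to_html` wrapping:

import json

def pick_infotext(generation_info, html_info, img_index):
    try:
        info = json.loads(generation_info)
        if 0 <= img_index < len(info["infotexts"]):
            return info["infotexts"][img_index]
    except Exception:
        pass
    return html_info  # fall back to the existing HTML on bad input

payload = json.dumps({"infotexts": ["Prompt: cat, Steps: 20",
                                    "Prompt: dog, Steps: 20"]})
print(pick_infotext(payload, "<old html>", 1))     # Prompt: dog, Steps: 20
print(pick_infotext("not json", "<old html>", 0))  # <old html>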
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
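
The refresh closure both mutates the live component's attributes and returns a `gr.update` built from the same kwargs, so the server-side object and the browser stay in sync. The attribute-plus-update dance in isolation, with a plain object standing in for the component:

class FakeComponent:  # hypothetical stand-in for a Gradio component
    choices = ["a"]

def refresh(component, refreshed_args):
    args = refreshed_args() if callable(refreshed_args) else refreshed_args
    for k, v in (args or {}).items():
        setattr(component, k, v)  # keep the server-side object current
    return {**(args or {}), "__type__": "update"}  # roughly what gr.update(**args) returns

comp = FakeComponent()
print(refresh(comp, lambda: {"choices": ["a", "b"]}))
# {'choices': ['a', 'b'], '__type__': 'update'}
print(comp.choices)  # ['a', 'b']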
+
+
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
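
The ordering trick: categories named in `ui_reorder` take their position from that string, while everything else keeps its original relative order pushed past 1000. A pure-Python run with hypothetical category names:

ui_reorder = "seed, cfg"  # user preference string
categories = ["sampler", "dimensions", "cfg", "seed", "batch"]

user_order = {x.strip(): i for i, x in enumerate(ui_reorder.split(","))}
ordered = [c for i, c in sorted(enumerate(categories),
                                key=lambda x: user_order.get(x[1], x[0] + 1000))]
print(ordered)  # ['seed', 'cfg', 'sampler', 'dimensions', 'batch']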
+
+
+def create_ui():
+ import modules.img2img
+ import modules.txt2img
+
+ reload_javascript()
+
+ parameters_copypaste.reset()
+
+ modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
+ dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Row(elem_id='txt2img_progress_row'):
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="txt2img_progressbar")
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
+ setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
+
+ elif category == "hires_fix":
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ for input in hr_resolution_preview_inputs:
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
+
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ txt2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+ _js="submit",
+ inputs=[
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_style,
+ txt2img_prompt_style2,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ denoising_strength,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
+ ] + custom_inputs,
+
+ outputs=[
+ txt2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ txt2img_prompt.submit(**txt2img_args)
+ submit.click(**txt2img_args)
+
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
+ enable_hr.change(
+ fn=lambda x: gr_show(x),
+ inputs=[enable_hr],
+ outputs=[hr_options],
+ show_progress = False,
+ )
+
+ txt2img_paste_fields = [
+ (txt2img_prompt, "Prompt"),
+ (txt2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (enable_hr, lambda d: "Denoising strength" in d),
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
+ token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+
+ with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="img2img_progressbar")
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
+ setup_progressbar(progressbar, img2img_preview, 'img2img')
+
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="img2img_settings"):
+
+ with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+
+ use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ if use_color_sketch:
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
+
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
-
- image = gr_show(False)
- preview_visibility = gr_show(False)
-
- if opts.show_progress_every_n_steps != 0:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr_show(True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr_show(False)
-
- return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
-
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
-
-
-def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
- if not enable:
- return ""
-
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
-
- with devices.autocast():
- p.init([""], [0], [0])
-
- return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-
-
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
-
-
-def interrogate(image):
- prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
- return gr_show(True) if prompt is None else prompt
-
-
-def interrogate_deepbooru(image):
- prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
-
-
-def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
- # Components to show/hide based on the 'Extra' checkbox
- seed_extras = []
-
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
- seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
- with FormRow(visible=False) as seed_extra_row_2:
- seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
- def change_visibility(show):
- return {comp: gr_show(show) for comp in seed_extras}
-
- seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
- return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
- """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
- button.click(
- _js="clear_prompt",
- fn=None,
- inputs=[],
- outputs=[],
- )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index):
- res = -1
-
- try:
- gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
-
- return [res, gr_show(False)]
-
- reuse_seed.click(
- fn=copy_seed,
- _js="(x, y) => [x, selected_gallery_index()]",
- show_progress=False,
- inputs=[generation_info, dummy_component],
- outputs=[seed, dummy_component]
- )
-
-
-def update_token_counter(text, steps):
- try:
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
- except Exception:
- # a parsing error can happen here during typing, and we don't want to bother the user with
- # messages related to it in console
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"{token_count}/{max_length}"
-
-
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
-
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
-
- if shared.cmd_opts.freeze_settings:
- return gr.update()
-
- # don't allow model to be swapped when model hash exists in prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
-
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
-
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
-
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return
-
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- opts.save(shared.config_filename)
- return value
-
-
-def update_generation_info(args):
- generation_info, html_info, img_index = args
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
-def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel'):
- with gr.Group():
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
-
-
-def create_sampler_and_steps_selection(choices, tabname):
- if opts.samplers_in_dropdown:
- with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- else:
- with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-
- return steps, sampler_index
-
-
-def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
-
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
- yield category
-
-
-def create_ui():
- import modules.img2img
- import modules.txt2img
-
- reload_javascript()
-
- parameters_copypaste.reset()
-
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
-
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
-
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "cfg":
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
-
- elif category == "hires_fix":
- with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormRow(elem_id="txt2img_hires_fix_row2"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "scripts":
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
-
- hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False,
- )
- input.change(
- None,
- _js="onCalcResolutionHires",
- inputs=hr_resolution_preview_inputs,
- outputs=[],
- show_progress=False,
- )
-
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
- _js="submit",
- inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
-
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ]
- )
-
- enable_hr.change(
- fn=lambda x: gr_show(x),
- inputs=[enable_hr],
- outputs=[hr_options],
- show_progress = False,
- )
-
- txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- *modules.scripts.scripts_txt2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
-
- txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
- steps,
- sampler_index,
- cfg_scale,
- seed,
- width,
- height,
- ]
-
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
-
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
-
- with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
-
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
-
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
-
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
-
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
- hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
-
- image = gr_show(False)
- preview_visibility = gr_show(False)
-
- if opts.show_progress_every_n_steps != 0:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr_show(True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr_show(False)
-
- return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
-
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
-
-
-def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
- if not enable:
- return ""
-
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
-
- with devices.autocast():
- p.init([""], [0], [0])
-
- return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-
-
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
-
-
-def interrogate(image):
- prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
- return gr_show(True) if prompt is None else prompt
-
-
-def interrogate_deepbooru(image):
- prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
-
-
-def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
- # Components to show/hide based on the 'Extra' checkbox
- seed_extras = []
-
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
- seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
- with FormRow(visible=False) as seed_extra_row_2:
- seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
- def change_visibility(show):
- return {comp: gr_show(show) for comp in seed_extras}
-
- seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
- return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
- """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
- button.click(
- _js="clear_prompt",
- fn=None,
- inputs=[],
- outputs=[],
- )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index):
- res = -1
-
- try:
- gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
-
- return [res, gr_show(False)]
-
- reuse_seed.click(
- fn=copy_seed,
- _js="(x, y) => [x, selected_gallery_index()]",
- show_progress=False,
- inputs=[generation_info, dummy_component],
- outputs=[seed, dummy_component]
- )
-
-
-def update_token_counter(text, steps):
- try:
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
- except Exception:
- # a parsing error can happen here during typing, and we don't want to bother the user with
- # messages related to it in console
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"{token_count}/{max_length}"
-
-
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
-
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
-
- if shared.cmd_opts.freeze_settings:
- return gr.update()
-
- # don't allow model to be swapped when model hash exists in prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
-
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
-
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
-
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return
-
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- opts.save(shared.config_filename)
- return value
-
-
-def update_generation_info(args):
- generation_info, html_info, img_index = args
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
-def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel'):
- with gr.Group():
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
-
-
-def create_sampler_and_steps_selection(choices, tabname):
- if opts.samplers_in_dropdown:
- with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- else:
- with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-
- return steps, sampler_index
-
-
-def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
-
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
- yield category
-
-
-def create_ui():
- import modules.img2img
- import modules.txt2img
-
- reload_javascript()
-
- parameters_copypaste.reset()
-
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
-
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
-
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "cfg":
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
-
- elif category == "hires_fix":
- with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormRow(elem_id="txt2img_hires_fix_row2"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "scripts":
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
-
- hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False,
- )
- input.change(
- None,
- _js="onCalcResolutionHires",
- inputs=hr_resolution_preview_inputs,
- outputs=[],
- show_progress=False,
- )
-
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
- _js="submit",
- inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
-
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ]
- )
-
- enable_hr.change(
- fn=lambda x: gr_show(x),
- inputs=[enable_hr],
- outputs=[hr_options],
- show_progress=False,
- )
-
- txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- *modules.scripts.scripts_txt2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
-
- txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
- steps,
- sampler_index,
- cfg_scale,
- seed,
- width,
- height,
- ]
-
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
-
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
-
- with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
-
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
-
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
-
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
-
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
- hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
"""
+
+ image = gr_show(False)
+ preview_visibility = gr_show(False)
+
+ if opts.show_progress_every_n_steps != 0:
+ shared.state.set_current_image()
+ image = shared.state.current_image
+
+ if image is None:
+ image = gr.update(value=None)
+ else:
+ preview_visibility = gr_show(True)
+
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"{time.time()}
{progressbar}
", preview_visibility, image, textinfo_result
+
+
+def check_progress_call_initial(id_part):
+ shared.state.job_count = -1
+ shared.state.current_latent = None
+ shared.state.current_image = None
+ shared.state.textinfo = None
+ shared.state.time_start = time.time()
+ shared.state.time_left_force_display = False
+
+ return check_progress_call(id_part)
+
+
+def visit(x, func, path=""):
+ if hasattr(x, 'children'):
+ for c in x.children:
+ visit(c, func, path)
+ elif x.label is not None:
+ func(path + "/" + str(x.label), x)
+
+
+def add_style(name: str, prompt: str, negative_prompt: str):
+ if name is None:
+ return [gr_show() for x in range(4)]
+
+ style = modules.styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
+ # reserialize all styles every time we save them
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+
+
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
+
+
+def apply_styles(prompt, prompt_neg, style1_name, style2_name):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+
+
+def interrogate(image):
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
+
+ return gr_show(True) if prompt is None else prompt
+
+
+def interrogate_deepbooru(image):
+ prompt = deepbooru.model.tag(image)
+ return gr_show(True) if prompt is None else prompt
+
+
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+
+ # Components to show/hide based on the 'Extra' checkbox
+ seed_extras = []
+
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+ seed_extras.append(seed_extra_row_1)
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
+ seed_extras.append(seed_extra_row_2)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+ random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+ random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+ def change_visibility(show):
+ return {comp: gr_show(show) for comp in seed_extras}
+
+ seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+ return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+ """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+ (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
+ was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+ def copy_seed(gen_info_string: str, index):
+ res = -1
+
+ try:
+ gen_info = json.loads(gen_info_string)
+ index -= gen_info.get('index_of_first_image', 0)
+
+ if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+ all_subseeds = gen_info.get('all_subseeds', [-1])
+ res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+ else:
+ all_seeds = gen_info.get('all_seeds', [-1])
+ res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+ except json.decoder.JSONDecodeError as e:
+ if gen_info_string != '':
+ print("Error parsing JSON generation info:", file=sys.stderr)
+ print(gen_info_string, file=sys.stderr)
+
+ return [res, gr_show(False)]
+
+ reuse_seed.click(
+ fn=copy_seed,
+ _js="(x, y) => [x, selected_gallery_index()]",
+ show_progress=False,
+ inputs=[generation_info, dummy_component],
+ outputs=[seed, dummy_component]
+ )
+
+
+def update_token_counter(text, steps):
+ try:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+ style_class = ' class="red"' if (token_count > max_length) else ""
+ return f"{token_count}/{max_length}"
+
+
+def create_toprow(is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+
+ with gr.Row(elem_id="toprow"):
+ with gr.Column(scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Column(scale=1, elem_id="roll_col"):
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+ with gr.Column(scale=1):
+ with gr.Row():
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+
+
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
+ check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
+ check_progress.click(
+ fn=lambda: check_progress_call(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+ check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
+ check_progress_initial.click(
+ fn=lambda: check_progress_call_initial(id_part),
+ show_progress=False,
+ inputs=[],
+ outputs=[progressbar, preview, preview, textinfo],
+ )
+
+
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
+ # don't allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data.get(key, None)
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
+
+
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
+
+
+def create_ui():
+ import modules.img2img
+ import modules.txt2img
+
+ reload_javascript()
+
+ parameters_copypaste.reset()
+
+ modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+ modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
+ dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Row(elem_id='txt2img_progress_row'):
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="txt2img_progressbar")
+ txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
+ setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
+
+ elif category == "hires_fix":
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ for input in hr_resolution_preview_inputs:
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
+
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ txt2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+ _js="submit",
+ inputs=[
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_style,
+ txt2img_prompt_style2,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ denoising_strength,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
+ ] + custom_inputs,
+
+ outputs=[
+ txt2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ txt2img_prompt.submit(**txt2img_args)
+ submit.click(**txt2img_args)
+
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
+ enable_hr.change(
+ fn=lambda x: gr_show(x),
+ inputs=[enable_hr],
+ outputs=[hr_options],
+ show_progress=False,
+ )
+
+ txt2img_paste_fields = [
+ (txt2img_prompt, "Prompt"),
+ (txt2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (enable_hr, lambda d: "Denoising strength" in d),
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
+ token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+
+ modules.scripts.scripts_current = modules.scripts.scripts_img2img
+ modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
+
+ with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+ with gr.Column(scale=1):
+ pass
+
+ with gr.Column(scale=1):
+ progressbar = gr.HTML(elem_id="img2img_progressbar")
+ img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
+ setup_progressbar(progressbar, img2img_preview, 'img2img')
+
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="img2img_settings"):
+
+ with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+
+ use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ if use_color_sketch:
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
+
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
+ hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
", preview_visibility, image, textinfo_result
@@ -235,227 +80,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part)
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
-
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
-
-
-def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
- if not enable:
- return ""
-
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
-
- with devices.autocast():
- p.init([""], [0], [0])
-
- return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-
-
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
-
-
-def interrogate(image):
- prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
- return gr_show(True) if prompt is None else prompt
-
-
-def interrogate_deepbooru(image):
- prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
-
-
-def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
- # Components to show/hide based on the 'Extra' checkbox
- seed_extras = []
-
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
- seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
- with FormRow(visible=False) as seed_extra_row_2:
- seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
- def change_visibility(show):
- return {comp: gr_show(show) for comp in seed_extras}
-
- seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
- return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
- """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
- button.click(
- _js="clear_prompt",
- fn=None,
- inputs=[],
- outputs=[],
- )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index):
- res = -1
-
- try:
- gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
-
- return [res, gr_show(False)]
-
- reuse_seed.click(
- fn=copy_seed,
- _js="(x, y) => [x, selected_gallery_index()]",
- show_progress=False,
- inputs=[generation_info, dummy_component],
- outputs=[seed, dummy_component]
- )
-
-
-def update_token_counter(text, steps):
- try:
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
- except Exception:
- # a parsing error can happen here during typing, and we don't want to bother the user with
- # messages related to it in console
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"{token_count}/{max_length}"
-
-
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
-
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
-
-
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
if textinfo is None:
textinfo = gr.HTML(visible=False)
@@ -475,1454 +99,3 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None):
inputs=[],
outputs=[progressbar, preview, preview, textinfo],
)
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
-
- if shared.cmd_opts.freeze_settings:
- return gr.update()
-
-    # don't allow the model to be swapped when a model hash exists in the prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
-
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
-
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
-
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return
-
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- opts.save(shared.config_filename)
- return value
-
-
-def update_generation_info(args):
- generation_info, html_info, img_index = args
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
-def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel'):
- with gr.Group():
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
-
-
-def create_sampler_and_steps_selection(choices, tabname):
- if opts.samplers_in_dropdown:
- with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- else:
- with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-
- return steps, sampler_index
-
-
-def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
-
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
- yield category
-
-
-def create_ui():
- import modules.img2img
- import modules.txt2img
-
- reload_javascript()
-
- parameters_copypaste.reset()
-
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
-
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
-
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "cfg":
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
-
- elif category == "hires_fix":
- with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormRow(elem_id="txt2img_hires_fix_row2"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "scripts":
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
-
- hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False,
- )
- input.change(
- None,
- _js="onCalcResolutionHires",
- inputs=hr_resolution_preview_inputs,
- outputs=[],
- show_progress=False,
- )
-
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
- _js="submit",
- inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
-
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ]
- )
-
- enable_hr.change(
- fn=lambda x: gr_show(x),
- inputs=[enable_hr],
- outputs=[hr_options],
- show_progress = False,
- )
-
- txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- *modules.scripts.scripts_txt2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
-
- txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
- steps,
- sampler_index,
- cfg_scale,
- seed,
- width,
- height,
- ]
-
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
-
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
-
- with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
-
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
-
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
-
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
-
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
- hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
")
-
- with gr.Row().style(equal_height=False):
- with gr.Tabs(elem_id="train_tabs"):
-
- with gr.Tab(label="Create embedding"):
- new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
- initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
- nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
- overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
-
- with gr.Tab(label="Create hypernetwork"):
- new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
- new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
- new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
- new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
- new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
- overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
-
- with gr.Tab(label="Preprocess images"):
- process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
- process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
-
- with gr.Row():
- process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
- process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
- process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
- process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
-
- with gr.Row(visible=False) as process_split_extra_row:
- process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
- process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
-
- with gr.Row(visible=False) as process_focal_crop_row:
- process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
- process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
- process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
- run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
-
- process_split.change(
- fn=lambda show: gr_show(show),
- inputs=[process_split],
- outputs=[process_split_extra_row],
- )
-
- process_focal_crop.change(
- fn=lambda show: gr_show(show),
- inputs=[process_focal_crop],
- outputs=[process_focal_crop_row],
- )
-
- def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
-
- with gr.Tab(label="Train"):
- gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
- with FormRow():
- train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
-
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-
- with FormRow():
- embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
- hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
-
- with FormRow():
- clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
- clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
-
- with FormRow():
- batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
- gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
-
- dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
- log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
-
- with FormRow():
- template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
- create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
-
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
- varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
- steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
-
- with FormRow():
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
-
- save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
- preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
-
- shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
- tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
-
- latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
-
- with gr.Row():
- train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
- interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
- train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
-
- params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
-
- script_callbacks.ui_train_tabs_callback(params)
-
- with gr.Column():
- progressbar = gr.HTML(elem_id="ti_progressbar")
- ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_preview = gr.Image(elem_id='ti_preview', visible=False)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
- ti_outcome = gr.HTML(elem_id="ti_error", value="")
- setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
-
- create_embedding.click(
- fn=modules.textual_inversion.ui.create_embedding,
- inputs=[
- new_embedding_name,
- initialization_text,
- nvpt,
- overwrite_old_embedding,
- ],
- outputs=[
- train_embedding_name,
- ti_output,
- ti_outcome,
- ]
- )
-
- create_hypernetwork.click(
- fn=modules.hypernetworks.ui.create_hypernetwork,
- inputs=[
- new_hypernetwork_name,
- new_hypernetwork_sizes,
- overwrite_old_hypernetwork,
- new_hypernetwork_layer_structure,
- new_hypernetwork_activation_func,
- new_hypernetwork_initialization_option,
- new_hypernetwork_add_layer_norm,
- new_hypernetwork_use_dropout,
- new_hypernetwork_dropout_structure
- ],
- outputs=[
- train_hypernetwork_name,
- ti_output,
- ti_outcome,
- ]
- )
-
- run_preprocess.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- process_src,
- process_dst,
- process_width,
- process_height,
- preprocess_txt_action,
- process_flip,
- process_split,
- process_caption,
- process_caption_deepbooru,
- process_split_threshold,
- process_overlap_ratio,
- process_focal_crop,
- process_focal_crop_face_weight,
- process_focal_crop_entropy_weight,
- process_focal_crop_edges_weight,
- process_focal_crop_debug,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ],
- )
-
- train_embedding.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- train_embedding_name,
- embedding_learn_rate,
- batch_size,
- gradient_step,
- dataset_directory,
- log_directory,
- training_width,
- training_height,
- varsize,
- steps,
- clip_grad_mode,
- clip_grad_value,
- shuffle_tags,
- tag_drop_out,
- latent_sampling_method,
- create_image_every,
- save_embedding_every,
- template_file,
- save_image_with_stored_embedding,
- preview_from_txt2img,
- *txt2img_preview_params,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ]
- )
-
- train_hypernetwork.click(
- fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- train_hypernetwork_name,
- hypernetwork_learn_rate,
- batch_size,
- gradient_step,
- dataset_directory,
- log_directory,
- training_width,
- training_height,
- varsize,
- steps,
- clip_grad_mode,
- clip_grad_value,
- shuffle_tags,
- tag_drop_out,
- latent_sampling_method,
- create_image_every,
- save_embedding_every,
- template_file,
- preview_from_txt2img,
- *txt2img_preview_params,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ]
- )
-
- interrupt_training.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- interrupt_preprocessing.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- def create_setting_component(key, is_quicksettings=False):
- def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
- info = opts.data_labels[key]
- t = type(info.default)
-
- args = info.component_args() if callable(info.component_args) else info.component_args
-
- if info.component is not None:
- comp = info.component
- elif t == str:
- comp = gr.Textbox
- elif t == int:
- comp = gr.Number
- elif t == bool:
- comp = gr.Checkbox
- else:
- raise Exception(f'bad options item type: {str(t)} for key {key}')
-
- elem_id = "setting_"+key
-
- if info.refresh is not None:
- if is_quicksettings:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
- else:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
- else:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-
- return res
-
- components = []
- component_dict = {}
-
- script_callbacks.ui_settings_callback()
- opts.reorder()
-
- def run_settings(*args):
- changed = []
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp == dummy_component:
- continue
-
- if opts.set(key, value):
- changed.append(key)
-
- try:
- opts.save(shared.config_filename)
- except RuntimeError:
- return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
-
- def run_settings_single(value, key):
- if not opts.same_type(value, opts.data_labels[key].default):
- return gr.update(visible=True), opts.dumpjson()
-
- if not opts.set(key, value):
- return gr.update(value=getattr(opts, key)), opts.dumpjson()
-
- opts.save(shared.config_filename)
-
- return gr.update(value=value), opts.dumpjson()
-
- with gr.Blocks(analytics_enabled=False) as settings_interface:
- with gr.Row():
- with gr.Column(scale=6):
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- with gr.Column():
- restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
-
- result = gr.HTML(elem_id="settings_result")
-
- quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
- quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
-
- quicksettings_list = []
-
- previous_section = None
- current_tab = None
- with gr.Tabs(elem_id="settings"):
- for i, (k, item) in enumerate(opts.data_labels.items()):
- section_must_be_skipped = item.section[0] is None
-
- if previous_section != item.section and not section_must_be_skipped:
- elem_id, text = item.section
-
- if current_tab is not None:
- current_tab.__exit__()
-
- current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
- current_tab.__enter__()
-
- previous_section = item.section
-
- if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
- quicksettings_list.append((i, k, item))
- components.append(dummy_component)
- elif section_must_be_skipped:
- components.append(dummy_component)
- else:
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
-
- if current_tab is not None:
- current_tab.__exit__()
-
- with gr.TabItem("Actions"):
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
-
- if os.path.exists("html/licenses.html"):
- with open("html/licenses.html", encoding="utf8") as file:
- with gr.TabItem("Licenses"):
- gr.HTML(file.read(), elem_id="licenses")
-
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
- request_notifications.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='function(){}'
- )
-
- download_localization.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='download_localization'
- )
-
- def reload_scripts():
- modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
-
- reload_script_bodies.click(
- fn=reload_scripts,
- inputs=[],
- outputs=[]
- )
-
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
- restart_gradio.click(
- fn=request_restart,
- _js='restart_reload',
- inputs=[],
- outputs=[],
- )
-
- interfaces = [
- (txt2img_interface, "txt2img", "txt2img"),
- (img2img_interface, "img2img", "img2img"),
- (extras_interface, "Extras", "extras"),
- (pnginfo_interface, "PNG Info", "pnginfo"),
- (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
- ]
-
- css = ""
-
- for cssfile in modules.scripts.list_files_with_name("style.css"):
- if not os.path.isfile(cssfile):
- continue
-
- with open(cssfile, "r", encoding="utf8") as file:
- css += file.read() + "\n"
-
- if os.path.exists(os.path.join(script_path, "user.css")):
- with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- css += file.read() + "\n"
-
- if not cmd_opts.no_progressbar_hiding:
- css += css_hide_progressbar
-
- interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
-
- extensions_interface = ui_extensions.create_ui()
- interfaces += [(extensions_interface, "Extensions", "extensions")]
-
- with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
- component = create_setting_component(k, is_quicksettings=True)
- component_dict[k] = component
-
- parameters_copypaste.integrate_settings_paste_fields(component_dict)
- parameters_copypaste.run_bind()
-
- with gr.Tabs(elem_id="tabs") as tabs:
- for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
- interface.render()
-
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
-
- if os.path.exists("html/footer.html"):
- with open("html/footer.html", encoding="utf8") as file:
- footer = file.read()
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
-
- text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
- settings_submit.click(
- fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
- inputs=components,
- outputs=[text_settings, result],
- )
-
- for i, k, item in quicksettings_list:
- component = component_dict[k]
-
- component.change(
- fn=lambda value, k=k: run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, text_settings],
- )
-
- component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
- def get_settings_values():
- return [getattr(opts, key) for key in component_keys]
-
- demo.load(
- fn=get_settings_values,
- inputs=[],
- outputs=[component_dict[k] for k in component_keys],
- )
-
- def modelmerger(*args):
- try:
- results = modules.extras.run_modelmerger(*args)
- except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
- return results
-
- modelmerger_merge.click(
- fn=modelmerger,
- inputs=[
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- interp_method,
- interp_amount,
- save_as_half,
- custom_name,
- checkpoint_format,
- ],
- outputs=[
- submit_result,
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- component_dict['sd_model_checkpoint'],
- ]
- )
-
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = path + "/" + field
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = 'customscript/' + obj.custom_script_source + '/' + key
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
-
- return demo
-
-
-def reload_javascript():
- with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
-        javascript = f'<script>{jsfile.read()}</script>'
-
- scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
- for basedir, filename, path in scripts_list:
- with open(path, "r", encoding="utf8") as jsfile:
-            javascript += f"\n<script>{jsfile.read()}</script>"
-
- if cmd_opts.theme is not None:
-        javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
-
-    javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
-
- def template_response(*args, **kwargs):
- res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(
-            b'</head>', f'{javascript}</head>'.encode("utf8"))
- res.init_headers()
- return res
-
- gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
- shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
-def versions_html():
- import torch
- import launch
-
- python_version = ".".join([str(x) for x in sys.version_info[0:3]])
- commit = launch.commit_hash()
- short_commit = commit[0:8]
-
- if shared.xformers_available:
- import xformers
- xformers_version = xformers.__version__
- else:
- xformers_version = "N/A"
-
- return f"""
-python: {python_version}
- •
-torch: {torch.__version__}
- •
-xformers: {xformers_version}
- •
-gradio: {gr.__version__}
- •
-commit: {short_commit}
-"""
--
cgit v1.2.3
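A side note on the `create_refresh_button` helper removed above: it wraps a recurring Gradio idiom in which a tool button's click handler re-runs some listing code, pushes the refreshed kwargs onto the component, and returns a `gr.update(...)` so the browser picks up the new choices. Below is a minimal, self-contained sketch of that idiom; the `items` list and `list_items` helper are illustrative stand-ins, not code from this repository.

```python
import gradio as gr

items = ["alpha", "beta"]  # stand-in for e.g. embeddings discovered on disk

def list_items():
    # stand-in for re-scanning a directory and rebuilding the choice list
    return {"choices": sorted(items)}

def make_refresh_button(component, refresh_method, refreshed_args):
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        for k, v in args.items():
            setattr(component, k, v)       # keep the server-side component in sync
        return gr.update(**(args or {}))   # push the new kwargs to the browser

    button = gr.Button(value="\U0001F504")  # refresh symbol
    button.click(fn=refresh, inputs=[], outputs=[component])
    return button

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(label="Items", choices=list_items()["choices"])
    make_refresh_button(dropdown, lambda: items.append("gamma"), list_items)

# demo.launch()  # uncomment to try it locally
```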
From 76a21b9626b7556638db188c157e3e8036803326 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Tue, 10 Jan 2023 12:46:35 +0300
Subject: clear envvar, add assertion
---
launch.py | 1 +
test/basic_features/txt2img_test.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/launch.py b/launch.py
index 49b91b1f..bcbb792c 100644
--- a/launch.py
+++ b/launch.py
@@ -282,6 +282,7 @@ def tests(test_dir):
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
+ os.environ['COMMANDLINE_ARGS'] = ""
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
index 5b27a7ec..5aa43a44 100644
--- a/test/basic_features/txt2img_test.py
+++ b/test/basic_features/txt2img_test.py
@@ -43,6 +43,7 @@ class TestTxt2ImgWorking(unittest.TestCase):
def test_txt2img_with_complex_prompt_performed(self):
self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
+ self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_not_square_image_performed(self):
self.simple_txt2img["height"] = 128
--
cgit v1.2.3
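The reason for clearing `COMMANDLINE_ARGS` in the commit above: the test harness re-launches the web UI with `subprocess.Popen([sys.executable, *sys.argv], ...)`, and a spawned child inherits the parent's environment, so any arguments left in that variable would be picked up a second time by the child. A standalone sketch of the inheritance behaviour (not webui code):

```python
import os
import subprocess
import sys

os.environ["COMMANDLINE_ARGS"] = ""  # neutralize before spawning the child

# the child process sees the parent's (now cleared) environment
child = subprocess.run(
    [sys.executable, "-c",
     "import os; print(repr(os.environ.get('COMMANDLINE_ARGS')))"],
    capture_output=True, text=True,
)
print(child.stdout.strip())  # prints: ''
```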
From 0c3feb202c5714abd50d879c1db2cd9a71ce93e3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 14:08:29 +0300
Subject: disable torch weight initialization and CLIP downloading/reading
 checkpoint to speed up creating sd model from config
---
modules/sd_disable_initialization.py | 44 ++++++++++++++++++++++++++++++++++++
modules/sd_models.py | 5 ++--
2 files changed, 47 insertions(+), 2 deletions(-)
create mode 100644 modules/sd_disable_initialization.py
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
new file mode 100644
index 00000000..c9a3b5e4
--- /dev/null
+++ b/modules/sd_disable_initialization.py
@@ -0,0 +1,44 @@
+import ldm.modules.encoders.modules
+import open_clip
+import torch
+
+
+class DisableInitialization:
+ """
+ When an object of this class enters a `with` block, it starts preventing torch's layer initialization
+ functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves,
+ reverts everything to how it was.
+
+ Use like this:
+ ```
+ with DisableInitialization():
+ do_things()
+ ```
+ """
+
+ def __enter__(self):
+ def do_nothing(*args, **kwargs):
+ pass
+
+ def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
+ return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
+
+ def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+ return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+
+ self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
+ self.init_no_grad_normal = torch.nn.init._no_grad_normal_
+ self.create_model_and_transforms = open_clip.create_model_and_transforms
+ self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
+
+ torch.nn.init.kaiming_uniform_ = do_nothing
+ torch.nn.init._no_grad_normal_ = do_nothing
+ open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
+ ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
+ torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
+ open_clip.create_model_and_transforms = self.create_model_and_transforms
+ ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
+
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 0a6d55ca..ee241032 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -13,7 +13,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -319,7 +319,8 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
- sd_model = instantiate_from_config(sd_config.model)
+ with sd_disable_initialization.DisableInitialization():
+ sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
--
cgit v1.2.3
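`DisableInitialization` above is an instance of a general save/patch/restore pattern: `__enter__` stashes the original module-level callables and swaps in stubs, and `__exit__` puts the originals back even if the block raises. A reduced sketch of the same pattern, patching just one of the torch initializers touched above (the class name here is illustrative, not from the codebase):

```python
import torch


class SkipWeightInit:
    """Temporarily turn torch's kaiming init into a no-op, then restore it."""

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        self.original = torch.nn.init.kaiming_uniform_
        torch.nn.init.kaiming_uniform_ = do_nothing  # layers now skip this init step
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.nn.init.kaiming_uniform_ = self.original  # always restored


with SkipWeightInit():
    # nn.Linear normally kaiming-initializes its weight; with the stub in
    # place the weight tensor is left as allocated, which is acceptable when
    # a checkpoint will overwrite it immediately afterwards.
    layer = torch.nn.Linear(4, 4)
```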
From ce3f639ec8758ce2bc90483336361d2dc25acd3a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 16:51:04 +0300
Subject: add more stuff to ignore when creating model from config; prevent
 .vae.safetensors files from being listed as stable diffusion models
---
modules/modelloader.py | 4 +++-
modules/sd_disable_initialization.py | 29 +++++++++++++++++++++++++----
modules/sd_models.py | 32 ++++++++++++++++++++++++++++----
3 files changed, 56 insertions(+), 9 deletions(-)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 6a1a7ac8..e9aa514e 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
+def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and-done loader to try finding the desired models in specified directories.
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path = file
if os.path.isdir(full_path):
continue
+ if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+ continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c9a3b5e4..9942bd7e 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -1,15 +1,19 @@
import ldm.modules.encoders.modules
import open_clip
import torch
+import transformers.utils.hub
class DisableInitialization:
"""
- When an object of this class enters a `with` block, it starts preventing torch's layer initialization
- functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves,
- reverts everything to how it was.
+ When an object of this class enters a `with` block, it starts:
+ - preventing torch's layer initialization functions from working
+    - changing CLIP and OpenCLIP so they do not download model weights
+    - changing CLIP so it does not make network requests to check whether there is a newer version of a file you already have
- Use like this:
+ When it leaves the block, it reverts everything to how it was before.
+
+ Use it like this:
```
with DisableInitialization():
do_things()
@@ -26,19 +30,36 @@ class DisableInitialization:
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
+
+            # this file is always a 404; skip the request entirely
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
+ raise transformers.utils.hub.EntryNotFoundError
+
+ try:
+ return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs)
+ except Exception as e:
+ return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs)
+
self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
self.init_no_grad_normal = torch.nn.init._no_grad_normal_
+ self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
self.create_model_and_transforms = open_clip.create_model_and_transforms
self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
+ self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache
torch.nn.init.kaiming_uniform_ = do_nothing
torch.nn.init._no_grad_normal_ = do_nothing
+ torch.nn.init._no_grad_uniform_ = do_nothing
open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
+ transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
+ torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
open_clip.create_model_and_transforms = self.create_model_and_transforms
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
+ transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ee241032..1bb9088b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,6 +2,7 @@ import collections
import os.path
import sys
import gc
+import time
from collections import namedtuple
import torch
import re
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
+ model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+class Timer:
+ def __init__(self):
+ self.start = time.time()
+
+ def elapsed(self):
+ end = time.time()
+ res = end - self.start
+ self.start = end
+ return res
+
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -319,11 +331,17 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
+ timer = Timer()
+
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
+ elapsed_create = timer.elapsed()
+
load_model_weights(sd_model, checkpoint_info)
+ elapsed_load_weights = timer.elapsed()
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
@@ -338,7 +356,9 @@ def load_model(checkpoint_info=None):
script_callbacks.model_loaded_callback(sd_model)
- print("Model loaded.")
+ elapsed_the_rest = timer.elapsed()
+
+ print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
return sd_model
@@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model:
sd_model = shared.sd_model
- if sd_model is None: # previous model load failed
+ if sd_model is None: # previous model load failed
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
@@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
+ timer = Timer()
+
try:
load_model_weights(sd_model, checkpoint_info)
except Exception as e:
@@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print("Weights loaded.")
+ elapsed = timer.elapsed()
+
+ print(f"Weights loaded in {elapsed:.1f}s.")
return sd_model
--
cgit v1.2.3
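The Timer class added in this commit is a resettable stopwatch: every call to elapsed() returns the time since the previous call and restarts the clock, so load_model() can attribute wall time to each phase with a single instance. A minimal sketch of the same pattern in isolation (the sleeps stand in for the real phases):
```
import time

class Timer:
    def __init__(self):
        self.start = time.time()

    def elapsed(self):
        # seconds since the previous call (or construction); resets the clock
        end = time.time()
        res = end - self.start
        self.start = end
        return res

timer = Timer()
time.sleep(0.2)  # stand-in for model creation
elapsed_create = timer.elapsed()
time.sleep(0.1)  # stand-in for loading weights
elapsed_load_weights = timer.elapsed()
print(f"Model loaded in {elapsed_create + elapsed_load_weights:.1f}s "
      f"({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
```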
From 0f8603a55988d22616b17140e6c4a7e9d0736af5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 17:46:59 +0300
Subject: add support for transformers==4.25.1; add fallback for when quick
model creation fails
---
modules/sd_disable_initialization.py | 42 ++++++++++++++++++++++++++++++------
modules/sd_models.py | 8 +++++--
2 files changed, 42 insertions(+), 8 deletions(-)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9942bd7e..088ac24b 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -30,30 +30,53 @@ class DisableInitialization:
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
- def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
+ def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
+ args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
+ return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
+
+ def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
# this file is always 404, prevent making request
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
raise transformers.utils.hub.EntryNotFoundError
try:
- return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs)
+ return original(url, *args, local_files_only=True, **kwargs)
except Exception as e:
- return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs)
+ return original(url, *args, local_files_only=False, **kwargs)
+
+ def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
+ return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
+
+ def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
+ return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
+
+ def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
+ return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
self.init_no_grad_normal = torch.nn.init._no_grad_normal_
self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
self.create_model_and_transforms = open_clip.create_model_and_transforms
self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
- self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache
+ self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
+ self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
+ self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
+ self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)
torch.nn.init.kaiming_uniform_ = do_nothing
torch.nn.init._no_grad_normal_ = do_nothing
torch.nn.init._no_grad_uniform_ = do_nothing
open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
- transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
+ if self.transformers_modeling_utils_load_pretrained_model is not None:
+ transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
+ if self.transformers_tokenization_utils_base_cached_file is not None:
+ transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
+ if self.transformers_configuration_utils_cached_file is not None:
+ transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
+ if self.transformers_utils_hub_get_from_cache is not None:
+ transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
@@ -61,5 +84,12 @@ class DisableInitialization:
torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
open_clip.create_model_and_transforms = self.create_model_and_transforms
ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
- transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
+ if self.transformers_modeling_utils_load_pretrained_model is not None:
+ transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
+ if self.transformers_tokenization_utils_base_cached_file is not None:
+            transformers.tokenization_utils_base.cached_file = self.transformers_tokenization_utils_base_cached_file
+        if self.transformers_configuration_utils_cached_file is not None:
+            transformers.configuration_utils.cached_file = self.transformers_configuration_utils_cached_file
+ if self.transformers_utils_hub_get_from_cache is not None:
+ transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1bb9088b..b5bc12f0 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -333,7 +333,11 @@ def load_model(checkpoint_info=None):
timer = Timer()
- with sd_disable_initialization.DisableInitialization():
+ try:
+ with sd_disable_initialization.DisableInitialization():
+ sd_model = instantiate_from_config(sd_config.model)
+ except Exception as e:
+ print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
elapsed_create = timer.elapsed()
--
cgit v1.2.3
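The transformers==4.25.1 support hinges on one idiom: every hook target is looked up with getattr(..., None), so attributes that do not exist in the installed transformers version are skipped both when patching and when restoring. A standalone sketch of that version-tolerant patching idiom, using a stand-in namespace rather than a real library module:
```
import types

def patch_if_present(module, name, replacement):
    """Swap module.name for replacement if it exists; return the original (or None)."""
    original = getattr(module, name, None)
    if original is not None:
        setattr(module, name, replacement)
    return original

# stand-in for a library module whose API differs across versions
hub = types.SimpleNamespace(get_from_cache=lambda url: f"network:{url}")

saved = patch_if_present(hub, "get_from_cache", lambda url: f"local:{url}")
print(hub.get_from_cache("model.bin"))  # local:model.bin
if saved is not None:
    hub.get_from_cache = saved          # restore only what was actually patched
print(hub.get_from_cache("model.bin"))  # network:model.bin
```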
From e2c8584f753b6fe8116f3032360a3b02e8398349 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Tue, 10 Jan 2023 22:26:49 +0300
Subject: make VENV envvar accept absolute path instead of relative
---
webui.bat | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/webui.bat b/webui.bat
index d4d626e2..3a3e310a 100644
--- a/webui.bat
+++ b/webui.bat
@@ -1,7 +1,7 @@
@echo off
if not defined PYTHON (set PYTHON=python)
-if not defined VENV_DIR (set VENV_DIR=venv)
+if not defined VENV_DIR (set VENV_DIR=%~dp0\venv)
set ERROR_REPORTING=FALSE
@@ -26,7 +26,7 @@ echo Unable to create venv in directory %VENV_DIR%
goto :show_stdout_stderr
:activate_venv
-set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
+set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch
@@ -35,7 +35,7 @@ goto :launch
:accelerate
echo "Checking for accelerate"
-set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
+set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
:launch
--
cgit v1.2.3
From 29fb5327640465fc83111e2170c5d8aa2b15266c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 23:47:02 +0300
Subject: change color selector in settings to be part of form
---
modules/shared.py | 4 ++--
modules/ui_components.py | 6 ++++++
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index aa37c8ce..264264a6 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors
+from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
from modules.paths import models_path, script_path, sd_path
@@ -387,7 +387,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
- "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
diff --git a/modules/ui_components.py b/modules/ui_components.py
index cac001dc..97acff06 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -31,3 +31,9 @@ class FormHTML(gr.HTML, gr.components.FormComponent):
def get_block_name(self):
return "html"
+
+class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
+ """Same as gr.ColorPicker but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "colorpicker"
--
cgit v1.2.3
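FormColorPicker follows the same recipe as the FormHTML class visible in the surrounding context: inherit from the concrete Gradio component plus gr.components.FormComponent so the component is accepted inside form blocks, and override get_block_name() so the stock front-end block is still used. The recipe generalizes to other components; a hypothetical example under the gradio version this codebase pins (FormDropdown is not part of this commit):
```
import gradio as gr

class FormDropdown(gr.Dropdown, gr.components.FormComponent):
    """Same as gr.Dropdown but fits inside gradio forms"""

    def get_block_name(self):
        return "dropdown"
```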
From 6be644fa04ce1542f3a01804310cbbc0a4a91620 Mon Sep 17 00:00:00 2001
From: dan
Date: Wed, 11 Jan 2023 05:31:58 +0800
Subject: Enable batch_size>1 for mixed-sized training
---
modules/textual_inversion/dataset.py | 36 ++++++++++++++++++++++++++++++++----
1 file changed, 32 insertions(+), 4 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index fa48708e..b47414f3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,8 +3,10 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
import random
import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
-
+ groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
+ self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+class GroupedBatchSampler(Sampler):
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ n = len(data_source)
+ self.groups = data_source.groups
+ self.len = n_batch = n // batch_size
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+ self.base = [int(e) // batch_size for e in expected]
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
+        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
+ self.batch_size = batch_size
+ def __len__(self):
+ return self.len
+ def __iter__(self):
+ b = self.batch_size
+ for g in self.groups:
+ shuffle(g)
+ batches = []
+ for g in self.groups:
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+ for _ in range(self.n_rand_batches):
+ rand_group = choices(self.groups, self.probs)[0]
+ batches.append(choices(rand_group, k=b))
+ shuffle(batches)
+ yield from batches
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
--
cgit v1.2.3
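GroupedBatchSampler removes the old batch_size == 1 restriction for variable-sized images by bucketing: indices are grouped by image size, each batch is drawn from a single group, and every group contributes its whole number of batches deterministically plus a share of the leftover batches picked by weighted random choice, so the sampled distribution stays proportional to group sizes. A stripped-down sketch of just the bucketing step, with made-up sizes:
```
from collections import defaultdict
from random import shuffle

sizes = [(512, 512)] * 5 + [(640, 448)] * 4  # one entry per dataset image
batch_size = 2

groups = defaultdict(list)
for idx, size in enumerate(sizes):
    groups[size].append(idx)

batches = []
for g in groups.values():
    shuffle(g)
    # whole batches from each size bucket; remainders are resampled in the real sampler
    batches.extend(g[i * batch_size:(i + 1) * batch_size] for i in range(len(g) // batch_size))
shuffle(batches)
print(batches)  # e.g. [[7, 5], [2, 0], [4, 1], [8, 6]] -- every batch has a single image size
```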
From 9cfd10cdefc7b2966b8e42fbb0e05735967cf87b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 01:27:00 +0300
Subject: repair #6612
---
webui.bat | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/webui.bat b/webui.bat
index 3a3e310a..e6a7a429 100644
--- a/webui.bat
+++ b/webui.bat
@@ -1,7 +1,7 @@
@echo off
if not defined PYTHON (set PYTHON=python)
-if not defined VENV_DIR (set VENV_DIR=%~dp0\venv)
+if not defined VENV_DIR (set VENV_DIR=%~dp0%venv)
set ERROR_REPORTING=FALSE
@@ -13,16 +13,16 @@ echo Couldn't launch python
goto :show_stdout_stderr
:start_venv
-if [%VENV_DIR%] == [-] goto :skip_venv
+if ["%VENV_DIR%"] == ["-"] goto :skip_venv
-dir %VENV_DIR%\Scripts\Python.exe >tmp/stdout.txt 2>tmp/stderr.txt
+dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
-%PYTHON_FULLNAME% -m venv %VENV_DIR% >tmp/stdout.txt 2>tmp/stderr.txt
+%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
-echo Unable to create venv in directory %VENV_DIR%
+echo Unable to create venv in directory "%VENV_DIR%"
goto :show_stdout_stderr
:activate_venv
--
cgit v1.2.3
From f9706acf431f77e0ce9e4270e5be7299922ee963 Mon Sep 17 00:00:00 2001
From: Lee Bousfield
Date: Tue, 10 Jan 2023 18:40:34 -0700
Subject: Support loading textual inversion embeddings from safetensors files
---
modules/textual_inversion/textual_inversion.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5420903f..3866c154 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -9,6 +9,7 @@ import tqdm
import html
import datetime
import csv
+import safetensors.torch
from PIL import Image, PngImagePlugin
@@ -150,6 +151,8 @@ class EmbeddingDatabase:
name = data.get('name', name)
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
else:
return
--
cgit v1.2.3
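safetensors.torch.load_file returns a plain dict mapping names to torch tensors, analogous to what torch.load returns for these embedding files, which is why the new branch slots in with no further changes. A small round-trip sketch (the file name and key are illustrative, not the embedding format itself):
```
import torch
import safetensors.torch

data = {"emb_params": torch.zeros(2, 768)}
safetensors.torch.save_file(data, "example.safetensors")

loaded = safetensors.torch.load_file("example.safetensors", device="cpu")
print(loaded["emb_params"].shape)  # torch.Size([2, 768])
```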
From 5830095b73515fc49b3fd567048470005191ec34 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 10 Jan 2023 21:43:24 -0500
Subject: Add old prompt parser compat option
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/shared.py b/modules/shared.py
index 264264a6..b61bbd3f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -400,6 +400,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
+ "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. Useful to reproduce old seeds."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
--
cgit v1.2.3
From 7e45fba55b24166501033a221e6268545fa47fbe Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 10 Jan 2023 21:47:03 -0500
Subject: Fix prompt parser default step transformer w/ test
---
modules/prompt_parser.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f70872c4..b69f1425 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -3,6 +3,11 @@ from collections import namedtuple
from typing import List
import lark
+try:
+ from modules.shared import opts
+except:
+ pass
+
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -49,6 +54,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[5, 'a c'], [10, 'a {b|d{ c']]
>>> g("((a][:b:c [d:3]")
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ >>> g("[a|(b:1.1)]")
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
"""
def collect_steps(steps, tree):
@@ -84,7 +91,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
- yield from child
+ try:
+ if opts.use_old_prompt_parser_default_step_transformer:
+ yield from child
+ else:
+ yield child
+ except:
+ yield child
return AtStep().transform(tree)
def get_schedule(prompt):
--
cgit v1.2.3
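The substantive change is the new `yield child` branch: children of a transformed lark node can be plain strings, and `yield from` iterates a string character by character while `yield` passes it through intact, which is part of why alternating words that contained emphasis came out mangled. A standalone illustration of the difference:
```
def flatten(children):
    for child in children:
        yield from child  # iterates each child; a string gets split into characters

def passthrough(children):
    for child in children:
        yield child       # each child comes through unchanged

print(list(flatten(["(b:1.1)"])))      # ['(', 'b', ':', '1', '.', '1', ')']
print(list(passthrough(["(b:1.1)"])))  # ['(b:1.1)']
```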
From 37a230112198adcb3f24d59b399cff342a6d479e Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Tue, 10 Jan 2023 20:30:09 -0800
Subject: Expose the compiled class module of scripts to extensions
---
modules/scripts.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/modules/scripts.py b/modules/scripts.py
index 35164093..4ffc369b 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -152,7 +152,7 @@ def basedir():
scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
-ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension):
@@ -206,7 +206,7 @@ def load_scripts():
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -241,7 +241,7 @@ class ScriptRunner:
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- for script_class, path, basedir in scripts_data:
+ for script_class, path, basedir, script_module in scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
--
cgit v1.2.3
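Extending the namedtuple means every consumer that unpacks scripts_data entries must accept the new module field, as the ScriptRunner loop in this diff now does; in exchange, extensions get a handle on the compiled module object of each script rather than just its class. A self-contained sketch of the shape of the data (the helper function is hypothetical):
```
import types
from collections import namedtuple

ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])

mod = types.ModuleType("example_script")
mod.example_helper = lambda: "helper defined next to the Script class"

entry = ScriptClassData(script_class=object, path="scripts/example.py", basedir=".", module=mod)

script_class, path, basedir, script_module = entry  # four-field unpacking
print(script_module.example_helper())
```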
From 954091697fce7a1b7997d5f3d73551f793f6bebc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 09:10:07 +0300
Subject: add an option to copy config from one of models in checkpoint merger
---
modules/extras.py | 30 +++++++++++++++++++++++++++++-
modules/ui.py | 9 ++++++---
2 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 7407bfe3..a03d558e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -3,6 +3,7 @@ import math
import os
import sys
import traceback
+import shutil
import numpy as np
from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ return sd_models.find_checkpoint_config(x) if x else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+
+ if cfg is None:
+ return
+
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+
+ print("Copying config:")
+ print(" from:", cfg)
+ print(" to:", checkpoint_filename)
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
diff --git a/modules/ui.py b/modules/ui.py
index 3c458ce8..82f5dd7c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1129,7 +1129,7 @@ def create_ui():
with gr.Column(variant='panel'):
            gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
- with gr.Row():
+ with FormRow():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1143,11 +1143,13 @@ def create_ui():
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
- with gr.Row():
+ with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
@@ -1703,6 +1705,7 @@ def create_ui():
save_as_half,
custom_name,
checkpoint_format,
+ config_source,
],
outputs=[
submit_result,
--
cgit v1.2.3
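Because the new radio uses type="index", config_source arrives as 0 for "A, B or C", 1 for "B", 2 for "C" and 3 for "Don't"; for 0, create_config takes the first merged model that has a config via a short-circuiting or-chain. A sketch of just that selection logic, with a stand-in lookup table:
```
def config(x):
    # stand-in for sd_models.find_checkpoint_config; None means no config found
    table = {"A": None, "B": "configs/b.yaml", "C": "configs/c.yaml"}
    return table.get(x) if x else None

def pick_config(config_source, a, b, c):
    if config_source == 0:    # "A, B or C": first model with a config wins
        return config(a) or config(b) or config(c)
    elif config_source == 1:  # "B"
        return config(b)
    elif config_source == 2:  # "C"
        return config(c)
    return None               # "Don't"

print(pick_config(0, "A", "B", "C"))  # configs/b.yaml -- A has none, so B's config is copied
```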
From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 10:24:56 +0300
Subject: possible fix for fallback for fast model creation from config
---
modules/sd_models.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b5bc12f0..a0a8a909 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -337,6 +337,9 @@ def load_model(checkpoint_info=None):
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
+ pass
+
+ if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
--
cgit v1.2.3
From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 10:34:36 +0300
Subject: possible fix for fallback for fast model creation from config,
attempt 2
---
modules/sd_models.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a0a8a909..084ba7fa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -333,6 +333,7 @@ def load_model(checkpoint_info=None):
timer = Timer()
+ sd_model = None
try:
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
--
cgit v1.2.3
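Taken together, these two commits leave the fallback reading as: set sd_model to None, let the fast path fail silently, and re-run plain instantiation only when the sentinel was never overwritten. The shape of the pattern in isolation (both creator functions are stand-ins):
```
def fast_create():
    raise RuntimeError("quick model creation failed")

def slow_create():
    return "model"

sd_model = None
try:
    sd_model = fast_create()
except Exception:
    pass

if sd_model is None:
    print("Failed to create model quickly; will retry using slow method.")
    sd_model = slow_create()

print(sd_model)  # model
```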
From b202714b65aa2145ff965ed4f197ac1516093f34 Mon Sep 17 00:00:00 2001
From: Alexey Shirokov <40300551+demiurge-ash@users.noreply.github.com>
Date: Wed, 11 Jan 2023 11:41:50 +0300
Subject: Fix keyboard navigation in modal image viewer
---
javascript/imageviewer.js | 1 +
1 file changed, 1 insertion(+)
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index b7bc2fe1..1f29ad7b 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -151,6 +151,7 @@ function showGalleryImage() {
e.addEventListener('mousedown', function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
+ evt.preventDefault()
showModal(evt)
}, true);
}
--
cgit v1.2.3
From ab388d6f8bf51338de1950b3907c324b0ff6a872 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Wed, 11 Jan 2023 08:59:47 -0500
Subject: Remove compat option check for prompt parser
---
modules/prompt_parser.py | 13 +------------
1 file changed, 1 insertion(+), 12 deletions(-)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index b69f1425..870218db 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -3,11 +3,6 @@ from collections import namedtuple
from typing import List
import lark
-try:
- from modules.shared import opts
-except:
- pass
-
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -91,13 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
- try:
- if opts.use_old_prompt_parser_default_step_transformer:
- yield from child
- else:
- yield child
- except:
- yield child
+ yield child
return AtStep().transform(tree)
def get_schedule(prompt):
--
cgit v1.2.3
From 0b38b72d31ead82c7d0998a29e50da90073831f7 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Wed, 11 Jan 2023 09:01:37 -0500
Subject: Remove compat option for prompt parser
---
modules/shared.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modules/shared.py b/modules/shared.py
index b61bbd3f..264264a6 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -400,7 +400,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
- "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. Useful to reproduce old seeds."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
--
cgit v1.2.3
From 39ea251945d70efcf9b59d44eb0e71269d754aa4 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 11 Jan 2023 10:23:51 -0500
Subject: add textinfo to progress response
---
modules/api/api.py | 4 ++--
modules/api/models.py | 1 +
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 6c564ad8..5767ba90 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -286,7 +286,7 @@ class Api:
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
# avoid dividing zero
progress = 0.01
@@ -308,7 +308,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
diff --git a/modules/api/models.py b/modules/api/models.py
index 034b4aa0..c78095ca 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -168,6 +168,7 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
--
cgit v1.2.3
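With textinfo in the response model, an API client can mirror the status line the web UI shows. A hedged sketch of such a client, assuming a locally running instance and the progress endpoint mounted at /sdapi/v1/progress:
```
import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/progress").json()
print(f"progress: {resp['progress'] * 100:.0f}%")
print(f"textinfo: {resp.get('textinfo')}")  # e.g. "Training textual inversion [Epoch 2: 41/100]loss: 0.1234567"
```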
From 3f43d8a966ba8462ba019a5ad573f94508cd45f8 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 11 Jan 2023 10:28:55 -0500
Subject: set descriptions
---
modules/hypernetworks/hypernetwork.py | 4 +++-
modules/textual_inversion/preprocess.py | 7 ++++++-
modules/textual_inversion/textual_inversion.py | 4 +++-
3 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 300d3975..194679e8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index feb876c6..3c1042ad 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.process_caption_deepbooru = process_caption_deepbooru
params.preprocess_txt_action = preprocess_txt_action
- for index, imagefile in enumerate(tqdm.tqdm(files)):
+ pbar = tqdm.tqdm(files)
+ for index, imagefile in enumerate(pbar):
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
@@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
except Exception:
continue
+ description = f"Preprocessing [Image {index}/{len(files)}]"
+ pbar.set_description(description)
+ shared.state.textinfo = description
+
params.src = filename
existing_caption = None
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3866c154..b915b091 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -476,7 +476,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
--
cgit v1.2.3
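All three loops follow the same pattern: build the status string once, then push it both to the tqdm bar and to shared.state.textinfo, so the console and the web UI (and, after the previous commit, the API) report the same thing. A minimal sketch with a local stand-in for the shared state object:
```
import types
import tqdm

state = types.SimpleNamespace(textinfo=None)  # stand-in for modules.shared.state

files = ["a.png", "b.png", "c.png"]
pbar = tqdm.tqdm(files)
for index, imagefile in enumerate(pbar):
    description = f"Preprocessing [Image {index}/{len(files)}]"
    pbar.set_description(description)  # console progress bar
    state.textinfo = description       # read by the web UI / progress API
```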
From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 18:54:04 +0300
Subject: fix for an error caused by skipping initialization, for realsies this
time: TypeError: expected str, bytes or os.PathLike object, not NoneType
---
modules/sd_disable_initialization.py | 71 ++++++++++++++++--------------------
modules/sd_models.py | 1 +
2 files changed, 33 insertions(+), 39 deletions(-)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 088ac24b..c72d8efc 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -20,6 +20,19 @@ class DisableInitialization:
```
"""
+ def __init__(self):
+ self.replaced = []
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
def __enter__(self):
def do_nothing(*args, **kwargs):
pass
@@ -37,11 +50,14 @@ class DisableInitialization:
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
# this file is always 404, prevent making request
- if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
- raise transformers.utils.hub.EntryNotFoundError
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
+ return None
try:
- return original(url, *args, local_files_only=True, **kwargs)
+ res = original(url, *args, local_files_only=True, **kwargs)
+ if res is None:
+ res = original(url, *args, local_files_only=False, **kwargs)
+ return res
except Exception as e:
return original(url, *args, local_files_only=False, **kwargs)
@@ -54,42 +70,19 @@ class DisableInitialization:
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
- self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
- self.init_no_grad_normal = torch.nn.init._no_grad_normal_
- self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
- self.create_model_and_transforms = open_clip.create_model_and_transforms
- self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
- self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
- self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
- self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
- self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)
-
- torch.nn.init.kaiming_uniform_ = do_nothing
- torch.nn.init._no_grad_normal_ = do_nothing
- torch.nn.init._no_grad_uniform_ = do_nothing
- open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
- ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
- if self.transformers_modeling_utils_load_pretrained_model is not None:
- transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
- if self.transformers_tokenization_utils_base_cached_file is not None:
- transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
- if self.transformers_configuration_utils_cached_file is not None:
- transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
- if self.transformers_utils_hub_get_from_cache is not None:
- transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
+ self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
- torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
- torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
- torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
- open_clip.create_model_and_transforms = self.create_model_and_transforms
- ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
- if self.transformers_modeling_utils_load_pretrained_model is not None:
- transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
- if self.transformers_tokenization_utils_base_cached_file is not None:
-            transformers.tokenization_utils_base.cached_file = self.transformers_tokenization_utils_base_cached_file
-        if self.transformers_configuration_utils_cached_file is not None:
-            transformers.configuration_utils.cached_file = self.transformers_configuration_utils_cached_file
- if self.transformers_utils_hub_get_from_cache is not None:
- transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 084ba7fa..c466f273 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -334,6 +334,7 @@ def load_model(checkpoint_info=None):
timer = Timer()
sd_model = None
+
try:
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
--
cgit v1.2.3
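The refactor collapses nine hand-written save/restore pairs into one bookkeeping list: replace() records an (object, field, original) tuple for every patch it applies, and __exit__ simply puts each original back, which makes it structurally impossible to restore an attribute onto the wrong module. The same pattern as a standalone context manager (patching math.sqrt purely for illustration):
```
import math

class PatchSet:
    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)
        return original

    def __enter__(self):
        self.replace(math, "sqrt", lambda x: 0.0)  # illustrative patch
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)
        self.replaced.clear()

with PatchSet():
    print(math.sqrt(4))  # 0.0 while patched
print(math.sqrt(4))      # 2.0 once restored
```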
From 0b8911d883118daa54f7735c5b753b5575d9f943 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 20:33:24 +0300
Subject: img2img UI rework: obsolete --gradio-img2img-tool and
 --gradio-inpaint-tool; always show all tools, each in its own tab
---
modules/img2img.py | 58 ++++++++++++++----------------
modules/shared.py | 4 +--
modules/ui.py | 103 +++++++++++++++++++++++++++--------------------------
style.css | 4 ++-
4 files changed, 84 insertions(+), 85 deletions(-)
diff --git a/modules/img2img.py b/modules/img2img.py
index ca58b5d8..f62783c6 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
- is_inpaint = mode == 1
- is_batch = mode == 2
-
- if is_inpaint:
- # Drawn mask
- if mask_mode == 0:
- is_mask_sketch = isinstance(init_img_with_mask, dict)
- is_mask_paint = not is_mask_sketch
- if is_mask_sketch:
- # Sketch: mask iff. not transparent
- image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- else:
- # Color-sketch: mask iff. painted over
- image = init_img_with_mask
- orig = init_img_with_mask_orig or init_img_with_mask
- pred = np.any(np.array(image) != np.array(orig), axis=-1)
- mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
- mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
- blur = ImageFilter.GaussianBlur(mask_blur)
- image = Image.composite(image.filter(blur), orig, mask.filter(blur))
-
- image = image.convert("RGB")
- # Uploaded mask
- else:
- image = init_img_inpaint
- mask = init_mask_inpaint
- # No mask
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+ is_batch = mode == 5
+
+ if mode == 0: # img2img
+ image = init_img.convert("RGB")
+ mask = None
+ elif mode == 1: # img2img sketch
+ image = sketch.convert("RGB")
+ mask = None
+ elif mode == 2: # inpaint
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ image = image.convert("RGB")
+ elif mode == 3: # inpaint sketch
+ image = inpaint_color_sketch
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+ image = image.convert("RGB")
+ elif mode == 4: # inpaint upload mask
+ image = init_img_inpaint
+ mask = init_mask_inpaint
else:
- image = init_img
+ image = None
mask = None
# Use the EXIF orientation of photos taken by smartphones.
diff --git a/modules/shared.py b/modules/shared.py
index 264264a6..1c964237 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
-parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
+parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
+parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
diff --git a/modules/ui.py b/modules/ui.py
index 82f5dd7c..e86a624b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -795,53 +795,67 @@ def create_ui():
with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
+ with gr.Tabs(elem_id="mode_img2img"):
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ inpaint_color_sketch_orig = gr.State(None)
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
                        gr.HTML(f"Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
@@ -900,20 +914,6 @@ def create_ui():
]
)
- mask_mode.change(
- lambda mode, img: {
- init_img_with_mask: gr_show(mode == 0),
- init_img_inpaint: gr_show(mode == 1),
- init_mask_inpaint: gr_show(mode == 1),
- },
- inputs=[mask_mode, init_img_with_mask],
- outputs=[
- init_img_with_mask,
- init_img_inpaint,
- init_mask_inpaint,
- ],
- )
-
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
@@ -924,11 +924,12 @@ def create_ui():
img2img_prompt_style,
img2img_prompt_style2,
init_img,
+ sketch,
init_img_with_mask,
- init_img_with_mask_orig,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
init_img_inpaint,
init_mask_inpaint,
- mask_mode,
steps,
sampler_index,
mask_blur,
diff --git a/style.css b/style.css
index ec5e4182..ffd6307f 100644
--- a/style.css
+++ b/style.css
@@ -557,7 +557,9 @@ canvas[key="mask"] {
}
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
-img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
+#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img,
+#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img,
+#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img
{
height: 480px !important;
max-height: 480px !important;
--
cgit v1.2.3
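Editor's note: the tab wiring in the hunk above (`fn=lambda tab=i: select_img2img_tab(tab)`) relies on Python's default-argument binding to freeze the loop index per handler. A minimal standalone sketch of that pattern, with illustrative names:

```python
# A plain closure captures the loop variable by reference, so every
# handler would report the final index once the loop ends.
handlers_broken = [lambda: i for i in range(3)]
print([h() for h in handlers_broken])  # [2, 2, 2]

# Binding i as a default argument evaluates it at definition time,
# which is why the diff writes `lambda tab=i: select_img2img_tab(tab)`.
handlers_fixed = [lambda tab=i: tab for i in range(3)]
print([h() for h in handlers_fixed])   # [0, 1, 2]
```

Without the `tab=i` default, every img2img tab would toggle the controls as if the last tab in the list had been selected.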
From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 09:22:29 +0100
Subject: Allow creation of zero vectors for TI
---
modules/textual_inversion/textual_inversion.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b915b091..853246a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
--
cgit v1.2.3
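Editor's note: a standalone sketch of the initialization logic in the hunk above; the tensor shapes are illustrative. With an init text, the encoded token embeddings are spread evenly across the requested vectors; with an empty init text, the vectors now stay zero:

```python
import torch

def init_vectors(embedded, num_vectors_per_token, has_init_text):
    # embedded: (num_tokens, dim), as returned by encode_embedding_init_text
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
    if has_init_text:
        for i in range(num_vectors_per_token):
            # pick evenly spaced source rows when token and vector counts differ
            vec[i] = embedded[i * embedded.shape[0] // num_vectors_per_token]
    return vec

embedded = torch.randn(2, 768)  # e.g. two tokens of a 768-dim text encoder
print(init_vectors(embedded, 4, True).shape)         # torch.Size([4, 768])
print(init_vectors(embedded, 4, False).abs().sum())  # tensor(0.) - zeroed vectors
```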
From d48dcbd2b29eab492d53d78f482356d78e5beb19 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 09:53:35 +0100
Subject: Add zero vector feature to hints.js
Also add the note that some tokens might be skipped. Not everyone is aware of this.
---
javascript/hints.js | 1 +
1 file changed, 1 insertion(+)
diff --git a/javascript/hints.js b/javascript/hints.js
index 856e1389..244bfde2 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -92,6 +92,7 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
+ "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
--
cgit v1.2.3
From 5623a3e7b1beed61f3ae6829a05b7b861d70e203 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 12 Jan 2023 19:47:33 +0300
Subject: fix send to inpaint sending you to wrong place
---
javascript/ui.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index ee226927..a41dd26f 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -54,7 +54,7 @@ function switch_to_img2img(){
function switch_to_inpaint(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
- gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
return args_to_array(arguments);
}
--
cgit v1.2.3
From 6ffefdcc9f47b66cbc543690d97cbf8327f4ba58 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 12 Jan 2023 19:47:44 +0300
Subject: fix js errors when restarting UI
---
script.js | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/script.js b/script.js
index 0e117d06..21960d91 100644
--- a/script.js
+++ b/script.js
@@ -1,5 +1,6 @@
function gradioApp() {
- const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
+ const elems = document.getElementsByTagName('gradio-app')
+ const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
return !!gradioShadowRoot ? gradioShadowRoot : document;
}
--
cgit v1.2.3
From 88416ab5ff787eec3b9962b43b5e544bb75fbad6 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Thu, 12 Jan 2023 13:46:59 -0800
Subject: Fix extension parameters not being saved to last used parameters
---
modules/processing.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index f04a0e1e..ae04cab7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
p.scripts.process(p)
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
infotexts = []
output_images = []
--
cgit v1.2.3
From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Thu, 12 Jan 2023 13:50:09 -0800
Subject: Add script callback for fixing infotext parameters
---
modules/generation_parameters_copypaste.py | 3 ++-
modules/script_callbacks.py | 20 +++++++++++++++++++-
2 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 620aa606..593d99ef 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -7,7 +7,7 @@ from pathlib import Path
import gradio as gr
from modules.shared import script_path
-from modules import shared, ui_tempdir
+from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
@@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
prompt = file.read()
params = parse_generation_parameters(prompt)
+ script_callbacks.infotext_pasted_callback(prompt, params)
res = []
for output, key in paste_fields:
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 608c5300..a9e19236 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,7 +2,7 @@ import sys
import traceback
from collections import namedtuple
import inspect
-from typing import Optional
+from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
@@ -71,6 +71,7 @@ callback_map = dict(
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
+ callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
)
@@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
+def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+ for c in callback_map['callbacks_infotext_pasted']:
+ try:
+ c.callback(infotext, params)
+ except Exception:
+ report_exception(c, 'infotext_pasted')
+
+
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
try:
@@ -290,6 +299,15 @@ def on_image_grid(callback):
add_callback(callback_map['callbacks_image_grid'], callback)
+def on_infotext_pasted(callback):
+ """register a function to be called before applying an infotext.
+ The callback is called with two arguments:
+ - infotext: str - raw infotext.
+ - result: Dict[str, any] - parsed infotext parameters.
+ """
+ add_callback(callback_map['callbacks_infotext_pasted'], callback)
+
+
def on_script_unloaded(callback):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
--
cgit v1.2.3
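Editor's note: a hedged sketch of how an extension might hook the new callback. Only `on_infotext_pasted` and its `(infotext, params)` signature come from the diff above; the key rename is a made-up example:

```python
from modules import script_callbacks

def migrate_old_params(infotext, params):
    # example migration: keep old infotexts pasting correctly after a
    # hypothetical extension renamed its parameter
    if "OldKey" in params and "NewKey" not in params:
        params["NewKey"] = params.pop("OldKey")

script_callbacks.on_infotext_pasted(migrate_old_params)
```

Because the callback runs after `parse_generation_parameters` but before the values are written into UI components, mutating `params` in place is enough to fix up what gets pasted.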
From 0b262802b86a55c4f71faf377f2cb1aee2960b63 Mon Sep 17 00:00:00 2001
From: Josh R
Date: Thu, 12 Jan 2023 17:31:05 -0800
Subject: add gradient settings to training settings log files
---
modules/textual_inversion/logging.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 8b1981d5..31e50b64 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,7 +2,7 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
--
cgit v1.2.3
From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 13 Jan 2023 14:32:15 +0300
Subject: print bucket sizes for training without resizing images #6620; fix
 an error when generating a picture with embedding in it
---
modules/textual_inversion/dataset.py | 16 ++++++++++++++++
modules/textual_inversion/image_embedding.py | 4 ++--
modules/textual_inversion/textual_inversion.py | 2 +-
3 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index b47414f3..d31963d4 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
+
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
@@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ super().__init__(data_source)
+
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
@@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
+
def __len__(self):
return self.len
+
def __iter__(self):
b = self.batch_size
+
for g in self.groups:
shuffle(g)
+
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
+
shuffle(batches)
+
yield from batches
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index ea653806..5593f88c 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
- data_np_low.resize(next_size)
+ data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d))
- data_np_high.resize(next_size)
+ data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 853246a6..e23906ca 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
--
cgit v1.2.3
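Editor's note: a standalone sketch of the bucketing idea behind the `GroupedBatchSampler` changes above (illustrative only, not the webui class). Images are grouped by resolution so each training batch stays shape-uniform, and the new bucket report simply prints the size of each group:

```python
from collections import defaultdict
from random import shuffle

def make_buckets(sizes, batch_size):
    groups = defaultdict(list)
    for idx, wh in enumerate(sizes):
        groups[wh].append(idx)          # bucket image indices by (width, height)
    batches = []
    for ids in groups.values():
        shuffle(ids)
        # only full batches here; the real sampler also fills the remainder
        # with randomly drawn batches, weighted by each group's leftover
        batches.extend(ids[i*batch_size:(i+1)*batch_size]
                       for i in range(len(ids) // batch_size))
    shuffle(batches)
    return batches

sizes = [(512, 512)] * 5 + [(512, 768)] * 3
print(make_buckets(sizes, batch_size=2))  # e.g. [[4, 2], [5, 7], [0, 3]]
```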
From 82725f0ac439f7e3b67858d55900e95330bbd326 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 13 Jan 2023 15:04:37 +0300
Subject: fix a bug caused by merge
---
modules/textual_inversion/textual_inversion.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 85210b0e..6939efcc 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,6 +11,7 @@ import datetime
import csv
import safetensors.torch
+import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
--
cgit v1.2.3
From d753a9df952ea640acbce724e8153356c8b68424 Mon Sep 17 00:00:00 2001
From: Zaprudin Aleksey
Date: Fri, 13 Jan 2023 22:25:33 +0500
Subject: fix progress bar behavior for "Prompts from file or textbox" script
---
scripts/prompts_from_file.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 2751f98a..1fe10a7c 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -146,7 +146,7 @@ class Script(scripts.Script):
else:
args = {"prompt": line}
- n_iter = args.get("n_iter", 1)
+ n_iter = args.get("n_iter", p.n_iter)
if n_iter != 1:
job_count += n_iter
else:
--
cgit v1.2.3
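Editor's note: a minimal sketch of the job counting this fix addresses (illustrative; `ui_n_iter` stands in for `p.n_iter`). Before the fix, lines without an explicit `n_iter` were counted as one job even when the UI batch count was higher, so the progress bar ran ahead of the actual work:

```python
def count_jobs(parsed_lines, ui_n_iter):
    job_count = 0
    for args in parsed_lines:
        n_iter = args.get("n_iter", ui_n_iter)  # the fix: default to p.n_iter
        job_count += n_iter if n_iter != 1 else 1
    return job_count

lines = [{"prompt": "a cat"}, {"prompt": "a dog", "n_iter": 3}]
print(count_jobs(lines, ui_n_iter=2))  # 5 jobs, not the 4 the old default gave
```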
From cbf4b3472b1da35937ff12c06072214a2e5cbad7 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 13 Jan 2023 19:18:56 +0100
Subject: Automatic launch argument for AMD GPUs
This commit adds a few lines to detect whether the system has an AMD GPU and sets an environment variable needed for torch to recognize the GPU.
---
webui.sh | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index c4d6521d..23629ef9 100755
--- a/webui.sh
+++ b/webui.sh
@@ -165,5 +165,11 @@ else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
- exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ gpu_info=$(lspci | grep VGA)
+ if echo "$gpu_info" | grep -q "AMD"
+ then
+ HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ else
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ fi
fi
--
cgit v1.2.3
From eaebcf638391071172d504568d661931f7e3c740 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 13 Jan 2023 19:20:18 +0100
Subject: GPU detection script
This commit adds a script that detects which GPU is currently in use on Windows and Linux.
---
detection.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 45 insertions(+)
create mode 100644 detection.py
diff --git a/detection.py b/detection.py
new file mode 100644
index 00000000..eb4db0df
--- /dev/null
+++ b/detection.py
@@ -0,0 +1,45 @@
+# This script detects which GPU is currently used in Windows and Linux
+import os
+import sys
+
+def check_gpu():
+ # First, check if the `lspci` command is available
+ if not os.system("which lspci > /dev/null") == 0:
+ # If the `lspci` command is not available, try the `dxdiag` command on Windows
+ if os.name == "nt":
+ # On Windows, run the `dxdiag` command and check the output for the "Card name" field
+ # Create the dxdiag.txt file
+ os.system("dxdiag /t dxdiag.txt")
+
+ # Read the dxdiag.txt file
+ with open("dxdiag.txt", "r") as f:
+ output = f.read()
+
+ if "Card name" in output:
+ card_name_start = output.index("Card name: ") + len("Card name: ")
+ card_name_end = output.index("\n", card_name_start)
+ card_name = output[card_name_start:card_name_end]
+ else:
+ card_name = "Unknown"
+ print(f"Card name: {card_name}")
+ os.remove("dxdiag.txt")
+ if "AMD" in card_name:
+ return "AMD"
+ elif "Intel" in card_name:
+ return "Intel"
+ elif "NVIDIA" in card_name:
+ return "NVIDIA"
+ else:
+ return "Unknown"
+ else:
+ # If the `lspci` command is available, use it to get the GPU vendor and model information
+ output = os.popen("lspci | grep -i vga").read()
+ if "AMD" in output:
+ return "AMD"
+ elif "Intel" in output:
+ return "Intel"
+ elif "NVIDIA" in output:
+ return "NVIDIA"
+ else:
+ return "Unknown"
+
--
cgit v1.2.3
From a407c9f0147c779865c940cbf62c7019dbc1f7b4 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 13 Jan 2023 19:22:23 +0100
Subject: Automatic torch install for amd on linux
This commit allows the launch script to automatically download ROCm's torch version for AMD GPUs using an external GPU detection script. It also prints the operating system and GPU in use.
---
launch.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/launch.py b/launch.py
index bcbb792c..668548f1 100644
--- a/launch.py
+++ b/launch.py
@@ -7,6 +7,7 @@ import shlex
import platform
import argparse
import json
+import detection
dir_repos = "repositories"
dir_extensions = "extensions"
@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
+# Get the GPU vendor and the operating system
+gpu = detection.check_gpu()
+if os.name == "posix":
+ os_name = platform.uname().system
+else:
+ os_name = os.name
def commit_hash():
global stored_commit_hash
@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment():
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+ if gpu == "AMD" and os_name !="nt":
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
+ else:
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -295,6 +306,8 @@ def tests(test_dir):
def start():
+ print(f"Operating System: {os_name}")
+ print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv:
--
cgit v1.2.3
From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 09:56:59 +0300
Subject: change hash to sha256
---
.gitignore | 1 +
modules/api/api.py | 2 +-
modules/api/models.py | 3 +-
modules/hashes.py | 72 +++++++++++++++
modules/hypernetworks/hypernetwork.py | 4 +-
modules/sd_models.py | 116 ++++++++++++++++---------
modules/shared.py | 2 +-
modules/textual_inversion/textual_inversion.py | 6 +-
webui.py | 2 +
9 files changed, 158 insertions(+), 50 deletions(-)
create mode 100644 modules/hashes.py
diff --git a/.gitignore b/.gitignore
index 21fa26a7..0b1d17ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
+/cache.json
diff --git a/modules/api/api.py b/modules/api/api.py
index 5767ba90..9814bbc2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -371,7 +371,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/api/models.py b/modules/api/models.py
index c78095ca..1eb1fcf1 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -224,7 +224,8 @@ class UpscalerItem(BaseModel):
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
- hash: str = Field(title="Hash")
+ hash: Optional[str] = Field(title="Short hash")
+ sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
diff --git a/modules/hashes.py b/modules/hashes.py
new file mode 100644
index 00000000..ebfbd90c
--- /dev/null
+++ b/modules/hashes.py
@@ -0,0 +1,72 @@
+import hashlib
+import json
+import os.path
+
+import filelock
+
+
+cache_filename = "cache.json"
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title in hashes:
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime <= cached_mtime and cached_sha256 is not None:
+ return cached_sha256
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": ondisk_mtime,
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 83cbb4f0..9b5f2e79 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c466f273..7babb9ae 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,17 +14,56 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
+checkpoint_alisases = {}
checkpoints_loaded = collections.OrderedDict()
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.filename = filename
+ abspath = os.path.abspath(filename)
+
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
+ else:
+ name = os.path.basename(filename)
+
+ if name.startswith("\\") or name.startswith("/"):
+ name = name[1:]
+
+ self.title = name
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ self.hash = model_hash(filename)
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
+ self.shorthash = None
+ self.sha256 = None
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for id in self.ids:
+ checkpoint_alisases[id] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, self.title)
+ self.shorthash = self.sha256[0:10]
+
+ if self.shorthash not in self.ids:
+ self.ids += [self.shorthash, self.sha256]
+ self.register()
+
+ return self.shorthash
+
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +82,14 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- convert = lambda name: int(name) if name.isdigit() else name.lower()
- alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def checkpoint_tiles():
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +105,38 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
+ checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
- def modeltitle(path, shorthash):
- abspath = os.path.abspath(path)
-
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
- return f'{name} [{shorthash}]', shortname
-
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
- h = model_hash(cmd_ckpt)
- title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.data['sd_model_checkpoint'] = title
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
+ checkpoint_info.register()
+
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
for filename in model_list:
- h = model_hash(filename)
- title, short_model_name = modeltitle(filename, h)
+ checkpoint_info = CheckpointInfo(filename)
+ checkpoint_info.register()
+
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+def get_closet_checkpoint_match(search_string):
+ checkpoint_info = checkpoint_alisases.get(search_string, None)
+ if checkpoint_info is not None:
+ return
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
-def get_closet_checkpoint_match(searchString):
- applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable) > 0:
- return applicable[0]
return None
def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +152,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
- checkpoint_file = checkpoint_info.filename
- sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- sd = read_state_dict(checkpoint_file)
+ sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+ vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
diff --git a/modules/shared.py b/modules/shared.py
index b90ded52..d74c069d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6939efcc..63935878 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
@@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
- embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint = checkpoint.shorthash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
diff --git a/webui.py b/webui.py
index 47d372c7..1fff80da 100644
--- a/webui.py
+++ b/webui.py
@@ -78,6 +78,8 @@ def initialize():
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
+ shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
--
cgit v1.2.3
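Editor's note: a hedged usage sketch of the new `modules/hashes.py` cache; the file path is hypothetical, while `sha256()`, the title-based cache key, and the mtime check come from the diff above (the next commits namespace the key, e.g. `"checkpoint/" + title`):

```python
from modules import hashes

# first call streams the file in 4 KiB chunks and records {mtime, sha256}
# in cache.json under the given title; later calls return the cached digest
# as long as the file's modification time hasn't moved past the cached one
digest = hashes.sha256("models/Stable-diffusion/model.ckpt", "model.ckpt")
shorthash = digest[0:10]  # the short form used by CheckpointInfo.calculate_shorthash
print(shorthash)
```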
From f9ac3352cb66ce2bc0aa4325130fc7267fb35e4f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 10:25:21 +0300
Subject: change hypernets to use sha256 hashes
---
modules/hypernetworks/hypernetwork.py | 40 ++++++++++++++++++++---------------
modules/processing.py | 2 +-
modules/sd_models.py | 2 +-
modules/shared.py | 1 +
4 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 9b5f2e79..3aebefa8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -225,7 +225,7 @@ class Hypernetwork:
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['hash'] = self.shorthash()
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
@@ -237,32 +237,33 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- print(self.layer_structure)
- optional_info = state_dict.get('optional_info', None)
- if optional_info is not None:
- print(f"INFO:\n {optional_info}\n")
- self.optional_info = optional_info
+ self.optional_info = state_dict.get('optional_info', None)
self.activation_func = state_dict.get('activation_func', None)
- print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
- print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
- print(f"Layer norm is set to {self.add_layer_norm}")
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
- print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
if self.dropout_structure is None:
- print("Using previous dropout structure")
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- print(f"Dropout structure is set to {self.dropout_structure}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
- if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
@@ -289,6 +290,11 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10]
+
def list_hypernetworks(path):
res = {}
@@ -296,7 +302,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
diff --git a/modules/processing.py b/modules/processing.py
index ae04cab7..849f6b19 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7babb9ae..8f00191c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -125,7 +125,7 @@ def list_models():
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_alisases.get(search_string, None)
if checkpoint_info is not None:
- return
+ return checkpoint_info
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
if found:
diff --git a/modules/shared.py b/modules/shared.py
index d74c069d..a6c61db3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
}))
options_templates.update(options_section(('training', "Training"), {
--
cgit v1.2.3
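Editor's note: a standalone sketch of the hash-gated optimizer resume used above (illustrative names; the real logic lives in `Hypernetwork.save`/`load`). The `.optim` file records which weights file it belongs to, so a stale or mismatched optimizer state is discarded rather than applied:

```python
def load_optimizer_state(optim_dict, current_shorthash):
    # resume only if the saved state matches the current weights file
    if optim_dict.get('hash', None) == current_shorthash:
        return optim_dict.get('optimizer_state_dict', None)
    return None  # weights changed since the optimizer state was saved

saved = {'hash': 'abc123def0', 'optimizer_state_dict': {'step': 500}}
print(load_optimizer_state(saved, 'abc123def0'))  # {'step': 500}
print(load_optimizer_state(saved, 'ffff000000'))  # None
```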
From febd2b722e80959b89a0e5966a159b4eb430c5a5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 13:37:55 +0300
Subject: update key to use with checkpoints' sha256 in cache
---
modules/sd_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8f00191c..1fe6d11b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -54,7 +54,7 @@ class CheckpointInfo:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, self.title)
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
--
cgit v1.2.3
From 6eb72fd13f34d94d5459290dd1a0bf0e9ddeda82 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 13:38:10 +0300
Subject: bump gradio to 3.16.1
---
modules/ui.py | 13 ++++++-----
requirements.txt | 2 +-
requirements_versions.txt | 2 +-
style.css | 57 ++++++++++++++++++++++++++++++++---------------
webui.py | 3 +--
5 files changed, 49 insertions(+), 28 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index e86a624b..202e84e5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -605,7 +605,7 @@ def create_ui():
setup_progressbar(progressbar, txt2img_preview, 'txt2img')
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -794,7 +794,7 @@ def create_ui():
setup_progressbar(progressbar, img2img_preview, 'img2img')
with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img"):
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
@@ -1026,7 +1026,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
+ with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', elem_id="extras_single_tab"):
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
@@ -1127,8 +1127,8 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
+ with gr.Column(variant='compact'):
+ gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
with FormRow():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
@@ -1150,7 +1150,8 @@ def create_ui():
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+ with gr.Row():
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
diff --git a/requirements.txt b/requirements.txt
index e1dbf8e5..6cdea781 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
-gradio==3.15.0
+gradio==3.16.1
invisible-watermark
numpy
omegaconf
diff --git a/requirements_versions.txt b/requirements_versions.txt
index d2899292..cc06d2b4 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.15.0
+gradio==3.16.1
numpy==1.23.3
Pillow==9.4.0
realesrgan==0.3.0
diff --git a/style.css b/style.css
index ffd6307f..14b15191 100644
--- a/style.css
+++ b/style.css
@@ -20,7 +20,7 @@
padding-right: 0.25em;
margin: 0.1em 0;
opacity: 0%;
- cursor: default;
+ cursor: default;
}
.output-html p {margin: 0 0.5em;}
@@ -221,7 +221,10 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55);
- box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
+ box-shadow: none;
+ border: 1px solid rgba(128, 128, 128, 0.1);
+ border-radius: 6px;
+ padding: 0.1em 0.5em;
}
#txt2img_column_batch, #img2img_column_batch{
@@ -371,7 +374,7 @@ input[type="range"]{
grid-area: tile;
}
-.modalClose,
+.modalClose,
.modalZoom,
.modalTileImage {
color: white;
@@ -509,29 +512,20 @@ input[type="range"]{
}
#quicksettings > div{
- border: none;
- background: none;
- flex: unset;
- gap: 1em;
-}
-
-#quicksettings > div > div{
- max-width: 32em;
+ max-width: 24em;
min-width: 24em;
padding: 0;
+ border: none;
+ box-shadow: none;
+ background: none;
}
-#quicksettings > div > div > div > div > label > span {
+#quicksettings > div > div > div > label > span {
position: relative;
margin-right: 9em;
margin-bottom: -1em;
}
-#quicksettings > div > div > label > span {
- position: relative;
- margin-bottom: -1em;
-}
-
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -666,7 +660,10 @@ footer {
font-weight: bold;
}
-#txt2img_checkboxes > div > div{
+#txt2img_checkboxes, #img2img_checkboxes{
+ margin-bottom: 0.5em;
+}
+#txt2img_checkboxes > div > div, #img2img_checkboxes > div > div{
flex: 0;
white-space: nowrap;
min-width: auto;
@@ -676,6 +673,30 @@ footer {
opacity: 0.5;
}
+.gr-compact {
+ border: none;
+ padding-top: 1em;
+}
+
+.dark .gr-compact{
+ background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+ margin-left: 0.8em;
+}
+
+.gr-compact > *{
+ margin-top: 0.5em !important;
+}
+
+.gr-compact .gr-block, .gr-compact .gr-form{
+ border: none;
+ box-shadow: none;
+}
+
+.gr-compact .gr-box{
+ border-radius: .5rem !important;
+ border-width: 1px !important;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
diff --git a/webui.py b/webui.py
index 1fff80da..84159515 100644
--- a/webui.py
+++ b/webui.py
@@ -157,7 +157,7 @@ def webui():
shared.demo = modules.ui.create_ui()
- app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
+ app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
@@ -185,7 +185,6 @@ def webui():
create_api(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
- modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
print('Restarting UI...')
--
cgit v1.2.3
From 5f8685237ed6427c9a8e502124074c740ea7696a Mon Sep 17 00:00:00 2001
From: bbc_mc
Date: Sat, 14 Jan 2023 20:00:00 +0900
Subject: Exclude clip index from merge
---
modules/extras.py | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/modules/extras.py b/modules/extras.py
index a03d558e..22668fcd 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -326,8 +326,14 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
print("Merging...")
+ chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
+
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
a = theta_0[key]
b = theta_1[key]
@@ -352,6 +358,10 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
+
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
--
cgit v1.2.3
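Editor's note: a minimal sketch of the merge loop after this fix (illustrative plain-float state dicts; the real code iterates model tensors). `position_ids` is an integer index buffer, so interpolating it corrupts the CLIP text encoder; the skip list carries it over from model A untouched:

```python
SKIP_ON_MERGE = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]

def weighted_sum_merge(theta_0, theta_1, multiplier):
    # mirror of the loop above: only keys present in both models are blended,
    # and entries on the skip list are left as-is
    for key in theta_0:
        if key in SKIP_ON_MERGE or key not in theta_1:
            continue
        # "Weighted sum" per hints.js: Result = A * (1 - M) + B * M
        theta_0[key] = theta_0[key] * (1 - multiplier) + theta_1[key] * multiplier
    return theta_0

a = {"model.w": 1.0, "cond_stage_model.transformer.text_model.embeddings.position_ids": 7}
b = {"model.w": 3.0, "cond_stage_model.transformer.text_model.embeddings.position_ids": 9}
print(weighted_sum_merge(a, b, 0.5))  # model.w blended to 2.0, position_ids stays 7
```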
From 54fa77facc1849fbbfe61c1ca6d99b117d609d67 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 12:10:45 +0100
Subject: Fix detection script on macos
This fixes the script on macos
---
detection.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/detection.py b/detection.py
index eb4db0df..442c4be5 100644
--- a/detection.py
+++ b/detection.py
@@ -31,6 +31,8 @@ def check_gpu():
return "NVIDIA"
else:
return "Unknown"
+ else:
+ return "Unknown"
else:
# If the `lspci` command is available, use it to get the GPU vendor and model information
output = os.popen("lspci | grep -i vga").read()
--
cgit v1.2.3
From 865228a83736bea9ede33e98041f2a7d0ca5daaa Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 14:56:39 +0300
Subject: change style dropdowns to multiselect
---
javascript/localization.js | 6 ++----
modules/img2img.py | 4 ++--
modules/styles.py | 12 ++++++++---
modules/txt2img.py | 4 ++--
modules/ui.py | 53 ++++++++++++++++++++++++++--------------------
style.css | 15 ++++++++++---
6 files changed, 57 insertions(+), 37 deletions(-)
diff --git a/javascript/localization.js b/javascript/localization.js
index f92d2d24..bf9e1506 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -10,10 +10,8 @@ ignore_ids_for_localization={
modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
- txt2img_style_index: 'OPTION',
- txt2img_style2_index: 'OPTION',
- img2img_style_index: 'OPTION',
- img2img_style2_index: 'OPTION',
+ txt2img_styles: 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
diff --git a/modules/img2img.py b/modules/img2img.py
index f62783c6..79382cc1 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -101,7 +101,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/styles.py b/modules/styles.py
index ce6e71ca..990d5623 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles):
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
- self.styles = {"None": self.no_style}
+ self.styles = {}
+ self.path = path
- if not os.path.exists(path):
+ self.reload()
+
+ def reload(self):
+ self.styles.clear()
+
+ if not os.path.exists(self.path):
return
- with open(path, "r", encoding="utf-8-sig", newline='') as file:
+ with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 38b5f591..5a71793b 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,13 +8,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py
index 202e84e5..db198a47 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -180,7 +180,7 @@ def add_style(name: str, prompt: str, negative_prompt: str):
# reserialize all styles every time we save them
shared.prompt_styles.save_styles(shared.styles_filename)
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
@@ -197,11 +197,11 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}"
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+def apply_styles(prompt, prompt_neg, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
def interrogate(image):
@@ -374,13 +374,10 @@ def create_toprow(is_img2img):
)
with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
+ create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(*args, **kwargs):
@@ -590,7 +587,7 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -684,8 +681,7 @@ def create_ui():
inputs=[
txt2img_prompt,
txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
+ txt2img_prompt_styles,
steps,
sampler_index,
restore_faces,
@@ -780,7 +776,7 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -921,8 +917,7 @@ def create_ui():
dummy_component,
img2img_prompt,
img2img_negative_prompt,
- img2img_prompt_style,
- img2img_prompt_style2,
+ img2img_prompt_styles,
init_img,
sketch,
init_img_with_mask,
@@ -977,7 +972,7 @@ def create_ui():
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
@@ -987,15 +982,15 @@ def create_ui():
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
)
- for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
_js=js_func,
- inputs=[prompt, negative_prompt, style1, style2],
- outputs=[prompt, negative_prompt, style1, style2],
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1530,6 +1525,7 @@ def create_ui():
previous_section = None
current_tab = None
+ current_row = None
with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
@@ -1538,10 +1534,14 @@ def create_ui():
elem_id, text = item.section
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
+ gr.Group()
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
previous_section = item.section
@@ -1556,6 +1556,7 @@ def create_ui():
components.append(component)
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
with gr.TabItem("Actions"):
@@ -1774,7 +1775,13 @@ def create_ui():
apply_field(x, 'value')
if type(x) == gr.Dropdown:
- apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+ def check_dropdown(val):
+ if x.multiselect:
+ return all([value in x.choices for value in val])
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
diff --git a/style.css b/style.css
index 14b15191..c1ebd501 100644
--- a/style.css
+++ b/style.css
@@ -114,6 +114,7 @@
min-width: unset !important;
flex-grow: 0 !important;
padding: 0.4em 0;
+ gap: 0;
}
#roll_col > button {
@@ -141,10 +142,14 @@
min-width: 8em !important;
}
-#txt2img_style_index, #txt2img_style2_index, #img2img_style_index, #img2img_style2_index{
+#txt2img_styles, #img2img_styles{
margin-top: 1em;
}
+#txt2img_styles ul, #img2img_styles ul{
+ max-height: 35em;
+}
+
.gr-form{
background: transparent;
}
@@ -154,10 +159,14 @@
margin-bottom: 0;
}
-#toprow div{
+#toprow div.gr-box, #toprow div.gr-form{
border: none;
gap: 0;
background: transparent;
+ box-shadow: none;
+}
+#toprow div{
+ gap: 0;
}
#resize_mode{
@@ -615,7 +624,7 @@ canvas[key="mask"] {
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
- margin: 0.55em 0;
+ margin: 1.6em 0 0 0;
}
#quicksettings .gr-button-tool{
--
cgit v1.2.3
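With the two fixed style slots replaced by one multiselect dropdown, downstream code receives a plain list of style names. Roughly how modules/styles.py folds such a list into a prompt; style_texts here is an ordinary dict standing in for StyleDatabase, so treat this as a sketch rather than the exact implementation:

def apply_styles_to_prompt(prompt, styles, style_texts):
    for name in styles:                 # any number of styles, not exactly two
        text = style_texts.get(name, "")
        if "{prompt}" in text:
            prompt = text.replace("{prompt}", prompt)  # template chooses placement
        elif text:
            prompt = f"{prompt}, {text}"               # otherwise append
    return prompt

print(apply_styles_to_prompt("a cat", ["oil", "hdr"],
                             {"oil": "{prompt}, oil on canvas", "hdr": "hdr, 8k"}))
# a cat, oil on canvas, hdr, 8k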
From 08c6f009a5ee92dd3218a942c08e8337c26352be Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 15:55:40 +0300
Subject: load hashes from cache for checkpoints that have them; add checkpoint
hash to footer
---
javascript/ui.js | 25 ++++++++++++++++---------
modules/hashes.py | 26 +++++++++++++++++++-------
modules/sd_models.py | 9 ++++++---
modules/shared.py | 1 +
modules/ui.py | 2 ++
script.js | 4 ++++
6 files changed, 48 insertions(+), 19 deletions(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index a41dd26f..1e04a8f4 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
opts = {}
-function apply_settings(jsdata){
- console.log(jsdata)
-
- opts = JSON.parse(jsdata)
-
- return jsdata
-}
-
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -160,7 +152,7 @@ onUiUpdate(function(){
textarea = json_elem.querySelector('textarea')
jsdata = textarea.value
opts = JSON.parse(jsdata)
-
+ executeCallbacks(optionsChangedCallbacks);
Object.defineProperty(textarea, 'value', {
set: function(newValue) {
@@ -171,6 +163,8 @@ onUiUpdate(function(){
if (oldValue != newValue) {
opts = JSON.parse(textarea.value)
}
+
+ executeCallbacks(optionsChangedCallbacks);
},
get: function() {
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
@@ -201,6 +195,19 @@ onUiUpdate(function(){
}
})
+
+onOptionsChanged(function(){
+ elem = gradioApp().getElementById('sd_checkpoint_hash')
+ sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
+ shorthash = sd_checkpoint_hash.substr(0,10)
+
+ if(elem && elem.textContent != shorthash){
+ elem.textContent = shorthash
+ elem.title = sd_checkpoint_hash
+ elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
+ }
+})
+
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
diff --git a/modules/hashes.py b/modules/hashes.py
index ebfbd90c..14231771 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -42,23 +42,35 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()
-def sha256(filename, title):
+def sha256_from_cache(filename, title):
hashes = cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
- if title in hashes:
- cached_sha256 = hashes[title].get("sha256", None)
- cached_mtime = hashes[title].get("mtime", 0)
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
- if ondisk_mtime <= cached_mtime and cached_sha256 is not None:
- return cached_sha256
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
print(f"Calculating sha256 for {filename}: ", end='')
sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
- "mtime": ondisk_mtime,
+ "mtime": os.path.getmtime(filename),
"sha256": sha256_value,
}
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1fe6d11b..e5a0bc63 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -44,9 +44,11 @@ class CheckpointInfo:
self.title = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
- self.shorthash = None
- self.sha256 = None
+
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -269,6 +271,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model.logvar = model.logvar.to(devices.device) # fix for training
diff --git a/modules/shared.py b/modules/shared.py
index a6c61db3..c9988d4d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -458,6 +458,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
diff --git a/modules/ui.py b/modules/ui.py
index e86a624b..2625ae32 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1841,4 +1841,6 @@ xformers: {xformers_version}
gradio: {gr.__version__}
•
commit: {short_commit}
+ •
+checkpoint: N/A
"""
diff --git a/script.js b/script.js
index 21960d91..3345e32b 100644
--- a/script.js
+++ b/script.js
@@ -14,6 +14,7 @@ function get_uiCurrentTabContent() {
uiUpdateCallbacks = []
uiTabChangeCallbacks = []
+optionsChangedCallbacks = []
let uiCurrentTab = null
function onUiUpdate(callback){
@@ -22,6 +23,9 @@ function onUiUpdate(callback){
function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback)
}
+function onOptionsChanged(callback){
+ optionsChangedCallbacks.push(callback)
+}
function runCallback(x, m){
try {
--
cgit v1.2.3
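The refactor splits the pure cache lookup out of sha256() so CheckpointInfo can probe the cache cheaply at startup without ever hashing a file. The invariant is that a cached digest is only trusted while the file's mtime has not moved past the recorded one; condensed below, with a plain dict standing in for cache("hashes"):

import os

def sha256_from_cache(filename, title, hashes):
    entry = hashes.get(title)
    if entry is None:
        return None                                    # never hashed before
    if os.path.getmtime(filename) > entry.get("mtime", 0):
        return None                                    # file changed; force a rehash
    return entry.get("sha256")

A miss returns None, and only then does the caller pay for a full read of the checkpoint to recompute the digest.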
From f94a215abed85b34ae978853078812801d3e7738 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 16:29:23 +0300
Subject: add an option to choose what you want to see in live preview (Live
preview subject) and move live preview settings to their own tab
---
modules/sd_samplers.py | 15 ++++++++++-----
modules/shared.py | 13 +++++++++----
modules/ui_progress.py | 2 +-
3 files changed, 20 insertions(+), 10 deletions(-)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 01221b89..7616fded 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None):
def store_latent(decoded):
state.current_latent = decoded
- if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
shared.state.current_image = sample_to_image(decoded)
@@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
@@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-
+
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ if opts.live_preview_content == "Prompt":
+ store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ store_latent(x_out[-uncond.shape[0]:])
+
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
@@ -423,7 +427,8 @@ class KDiffusionSampler:
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
- store_latent(latent)
+ if opts.live_preview_content == "Combined":
+ store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
diff --git a/modules/shared.py b/modules/shared.py
index c9988d4d..e0ec3136 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -176,7 +176,7 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
@@ -224,7 +224,7 @@ class State:
if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
self.do_set_current_image()
def do_set_current_image(self):
@@ -423,8 +423,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@@ -444,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), {
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
+options_templates.update(options_section(('ui', "Live previews"), {
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+}))
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
index 592fda55..7cd312e4 100644
--- a/modules/ui_progress.py
+++ b/modules/ui_progress.py
@@ -52,7 +52,7 @@ def check_progress_call(id_part):
image = gr.update(visible=False)
preview_visibility = gr.update(visible=False)
- if opts.show_progress_every_n_steps != 0:
+ if opts.live_previews_enable:
shared.state.set_current_image()
image = shared.state.current_image
--
cgit v1.2.3
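The new live_preview_content option decides which slice of the denoiser output feeds the preview. In CFGDenoiser's batch layout the conditioned samples come first and the last uncond.shape[0] entries are the unconditioned ones, which is what the two slices in the hunk select; "Combined" is instead stored later in the sampler callback from d["denoised"]. A condensed restatement with illustrative names:

def pick_preview_latent(live_preview_content, x_out, uncond_count):
    if live_preview_content == "Prompt":
        return x_out[0:uncond_count]    # preview the conditioned side
    if live_preview_content == "Negative prompt":
        return x_out[-uncond_count:]    # preview the unconditioned side
    return None                         # "Combined": handled by callback_state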
From 69781031e7473e020b3af4461fdceb20130e56ab Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 16:45:39 +0300
Subject: simplify expression in prompts from file script
---
scripts/prompts_from_file.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 1fe10a7c..f3e711d7 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -146,11 +146,7 @@ class Script(scripts.Script):
else:
args = {"prompt": line}
- n_iter = args.get("n_iter", p.n_iter)
- if n_iter != 1:
- job_count += n_iter
- else:
- job_count += 1
+ job_count += args.get("n_iter", p.n_iter)
jobs.append(args)
--
cgit v1.2.3
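The removed branch special-cased n_iter == 1, but adding 1 and adding n_iter are the same thing in that case, so the whole if/else collapses into one expression:

def job_increment(args, default_n_iter):
    return args.get("n_iter", default_n_iter)

assert job_increment({"n_iter": 1}, 4) == 1   # old branch: job_count += 1
assert job_increment({"n_iter": 3}, 4) == 3   # old branch: job_count += n_iter
assert job_increment({}, 4) == 4              # falls back to p.n_iter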
From 934cba0f4ca3e80a2079a657ebb6ca8c1ee2d10b Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 15:43:29 +0100
Subject: Delete detection.py
---
detection.py | 47 -----------------------------------------------
1 file changed, 47 deletions(-)
delete mode 100644 detection.py
diff --git a/detection.py b/detection.py
deleted file mode 100644
index 442c4be5..00000000
--- a/detection.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# This script detects which GPU is currently used in Windows and Linux
-import os
-import sys
-
-def check_gpu():
- # First, check if the `lspci` command is available
- if not os.system("which lspci > /dev/null") == 0:
- # If the `lspci` command is not available, try the `dxdiag` command on Windows
- if os.name == "nt":
- # On Windows, run the `dxdiag` command and check the output for the "Card name" field
- # Create the dxdiag.txt file
- os.system("dxdiag /t dxdiag.txt")
-
- # Read the dxdiag.txt file
- with open("dxdiag.txt", "r") as f:
- output = f.read()
-
- if "Card name" in output:
- card_name_start = output.index("Card name: ") + len("Card name: ")
- card_name_end = output.index("\n", card_name_start)
- card_name = output[card_name_start:card_name_end]
- else:
- card_name = "Unknown"
- print(f"Card name: {card_name}")
- os.remove("dxdiag.txt")
- if "AMD" in card_name:
- return "AMD"
- elif "Intel" in card_name:
- return "Intel"
- elif "NVIDIA" in card_name:
- return "NVIDIA"
- else:
- return "Unknown"
- else:
- return "Unknown"
- else:
- # If the `lspci` command is available, use it to get the GPU vendor and model information
- output = os.popen("lspci | grep -i vga").read()
- if "AMD" in output:
- return "AMD"
- elif "Intel" in output:
- return "Intel"
- elif "NVIDIA" in output:
- return "NVIDIA"
- else:
- return "Unknown"
-
--
cgit v1.2.3
From 629612935372aa7faeb764b5660431e49b93de24 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 15:45:07 +0100
Subject: Revert detection code
---
launch.py | 15 +--------------
1 file changed, 1 insertion(+), 14 deletions(-)
diff --git a/launch.py b/launch.py
index 668548f1..bcbb792c 100644
--- a/launch.py
+++ b/launch.py
@@ -7,7 +7,6 @@ import shlex
import platform
import argparse
import json
-import detection
dir_repos = "repositories"
dir_extensions = "extensions"
@@ -16,12 +15,6 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
-# Get the GPU vendor and the operating system
-gpu = detection.check_gpu()
-if os.name == "posix":
- os_name = platform.uname().system
-else:
- os_name = os.name
def commit_hash():
global stored_commit_hash
@@ -180,11 +173,7 @@ def run_extensions_installers(settings_file):
def prepare_environment():
- if gpu == "AMD" and os_name !="nt":
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
- else:
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
-
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -306,8 +295,6 @@ def tests(test_dir):
def start():
- print(f"Operating System: {os_name}")
- print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv:
--
cgit v1.2.3
From 6192a222bf7131771a2cd7655a64a5b24a1e6e2e Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 15:46:23 +0100
Subject: Export TORCH_COMMAND for AMD from the webui
---
webui.sh | 1 +
1 file changed, 1 insertion(+)
diff --git a/webui.sh b/webui.sh
index 23629ef9..35f52f2a 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,6 +168,7 @@ else
gpu_info=$(lspci | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
else
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
--
cgit v1.2.3
From c4ba34928ec7f977585494f0fa5925496c887698 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 15:58:50 +0100
Subject: Quick format fix
---
webui.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index 35f52f2a..fcba6b7d 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,7 +168,7 @@ else
gpu_info=$(lspci | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
else
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
--
cgit v1.2.3
From fad850fc3d33e7cda2ce4b3a32ab7976c313db53 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sat, 14 Jan 2023 11:18:05 -0500
Subject: add server_start to shared.state
---
modules/shared.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/modules/shared.py b/modules/shared.py
index e0ec3136..ef93637c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -168,6 +168,7 @@ class State:
textinfo = None
time_start = None
need_restart = False
+ server_start = None
def skip(self):
self.skipped = True
@@ -241,6 +242,7 @@ class State:
state = State()
+state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
--
cgit v1.2.3
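Recording the startup timestamp on shared.state gives any consumer a way to derive server uptime; the helper below is illustrative, not something this commit adds:

import time

server_start = time.time()   # what shared.py now stores in state.server_start

def uptime_seconds():
    return time.time() - server_start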
From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 19:56:09 +0300
Subject: fix bug with "Ignore selected VAE for..." option completely disabling
VAE selection; rework VAE resolving code to be simpler
---
modules/sd_models.py | 6 +-
modules/sd_vae.py | 194 ++++++++++++++++++++-------------------------------
modules/shared.py | 4 +-
scripts/xy_grid.py | 27 ++++---
4 files changed, 95 insertions(+), 136 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e5a0bc63..6a681cef 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
- sd_vae.load_vae(model, vae_file)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
def enable_midas_autodownload():
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0a49daa1..6ea92711 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -9,23 +9,9 @@ import glob
from copy import deepcopy
-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
-vae_dir = "VAE"
-vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
-
-
+vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-
-default_vae_dict = {"auto": "auto", "None": None, None: None}
-default_vae_list = ["auto", "None"]
-
-
-default_vae_values = [default_vae_dict[x] for x in default_vae_list]
-vae_dict = dict(default_vae_dict)
-vae_list = list(default_vae_list)
-first_load = True
+vae_dict = {}
base_vae = None
@@ -64,100 +50,69 @@ def restore_base_vae(model):
def get_filename(filepath):
- return os.path.splitext(os.path.basename(filepath))[0]
-
-
-def refresh_vae_list(vae_path=vae_path, model_path=model_path):
- global vae_dict, vae_list
- res = {}
- candidates = [
- *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
+ return os.path.basename(filepath)
+
+
+def refresh_vae_list():
+ vae_dict.clear()
+
+ paths = [
+ os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
+ os.path.join(sd_models.model_path, '**/*.vae.pt'),
+ os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
+ os.path.join(vae_path, '**/*.ckpt'),
+ os.path.join(vae_path, '**/*.pt'),
+ os.path.join(vae_path, '**/*.safetensors'),
]
- if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
- candidates.append(shared.cmd_opts.vae_path)
+
+ if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
+ ]
+
+ candidates = []
+ for path in paths:
+ candidates += glob.iglob(path, recursive=True)
+
for filepath in candidates:
name = get_filename(filepath)
- res[name] = filepath
- vae_list.clear()
- vae_list.extend(default_vae_list)
- vae_list.extend(list(res.keys()))
- vae_dict.clear()
- vae_dict.update(res)
- vae_dict.update(default_vae_dict)
- return vae_list
-
-
-def get_vae_from_settings(vae_file="auto"):
- # else, we load from settings, if not set to be default
- if vae_file == "auto" and shared.opts.sd_vae is not None:
- # if saved VAE settings isn't recognized, fallback to auto
- vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
- # if VAE selected but not found, fallback to auto
- if vae_file not in default_vae_values and not os.path.isfile(vae_file):
- vae_file = "auto"
- print(f"Selected VAE doesn't exist: {vae_file}")
- return vae_file
-
-
-def resolve_vae(checkpoint_file=None, vae_file="auto"):
- global first_load, vae_dict, vae_list
-
- # if vae_file argument is provided, it takes priority, but not saved
- if vae_file and vae_file not in default_vae_list:
- if not os.path.isfile(vae_file):
- print(f"VAE provided as function argument doesn't exist: {vae_file}")
- vae_file = "auto"
- # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
- if first_load and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- shared.opts.data['sd_vae'] = get_filename(vae_file)
- else:
- print(f"VAE provided as command line argument doesn't exist: {vae_file}")
- # fallback to selector in settings, if vae selector not set to act as default fallback
- if not shared.opts.sd_vae_as_default:
- vae_file = get_vae_from_settings(vae_file)
- # vae-path cmd arg takes priority for auto
- if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- print(f"Using VAE provided as command line argument: {vae_file}")
- # if still not found, try look for ".vae.pt" beside model
- model_path = os.path.splitext(checkpoint_file)[0]
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.pt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.ckpt" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.ckpt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.safetensors" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.safetensors"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # No more fallbacks for auto
- if vae_file == "auto":
- vae_file = None
- # Last check, just because
- if vae_file and not os.path.exists(vae_file):
- vae_file = None
-
- return vae_file
-
-
-def load_vae(model, vae_file=None):
- global first_load, vae_dict, vae_list, loaded_vae_file
+ vae_dict[name] = filepath
+
+
+def find_vae_near_checkpoint(checkpoint_file):
+ checkpoint_path = os.path.splitext(checkpoint_file)[0]
+ for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+ if os.path.isfile(vae_location):
+ return vae_location
+
+ return None
+
+
+def resolve_vae(checkpoint_file):
+ if shared.cmd_opts.vae_path is not None:
+ return shared.cmd_opts.vae_path, 'from commandline argument'
+
+ vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+ return vae_near_checkpoint, 'found near the checkpoint'
+
+ if shared.opts.sd_vae == "None":
+ return None, None
+
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+ if vae_from_options is not None:
+ return vae_from_options, 'specified in settings'
+
+ if shared.opts.sd_vae != "Automatic":
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+ return None, None
+
+
+def load_vae(model, vae_file=None, vae_source="from unknown source"):
+ global vae_dict, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
- print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
- assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
- print(f"Loading VAE weights from: {vae_file}")
+ assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
+ print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
@@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
- vae_list.append(vae_opt)
+
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
- first_load = False
-
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +164,10 @@ def clear_loaded_vae():
loaded_vae_file = None
-def reload_vae_weights(sd_model=None, vae_file="auto"):
+unspecified = object()
+
+
+def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model:
@@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
- vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ if vae_file == unspecified:
+ vae_file, vae_source = resolve_vae(checkpoint_file)
+ else:
+ vae_source = "from function argument"
if loaded_vae_file == vae_file:
return
@@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_vae(sd_model, vae_file)
+ load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print("VAE Weights loaded.")
+ print("VAE weights loaded.")
return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index e0ec3136..9756adea 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index f04d9b7e..bd3087d4 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):
def find_vae(name: str):
- if name.lower() in ['auto', 'none']:
- return name
+ if name.lower() in ['auto', 'automatic']:
+ return modules.sd_vae.unspecified
+ if name.lower() == 'none':
+ return None
else:
- vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
- found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
- if found:
- return found[0]
+ choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+ if len(choices) == 0:
+ print(f"No VAE found for {name}; using automatic")
+ return modules.sd_vae.unspecified
else:
- return 'auto'
+ return modules.sd_vae.vae_dict[choices[0]]
def apply_vae(p, x, xs):
- if x.lower().strip() == 'none':
- modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
- else:
- found = find_vae(x)
- if found:
- v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
+ modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model)
- modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
+
+ opts.data["sd_vae"] = self.vae
+ modules.sd_vae.reload_vae_weights(self.model)
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
--
cgit v1.2.3
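The rework replaces the chain of "auto" fallbacks with a single resolve_vae() that returns both a path and a human-readable source, with a fixed precedence: the --vae-path commandline argument wins outright, then a .vae.* file sitting next to the checkpoint (when sd_vae_as_default allows it), then the named entry from settings. Note the hunk still compares shared.opts.sd_vae against "auto" in one place although the option's values became "Automatic"/"None"; the follow-up "typo?" commit below corrects that, and this condensed restatement already uses the corrected value. Plain arguments stand in for the shared.* globals the real function reads:

def resolve_vae_sketch(cmd_vae_path, near_checkpoint_vae, sd_vae_setting,
                       vae_as_default, vae_dict):
    if cmd_vae_path is not None:
        return cmd_vae_path, 'from commandline argument'
    if near_checkpoint_vae and (vae_as_default or sd_vae_setting == "Automatic"):
        return near_checkpoint_vae, 'found near the checkpoint'
    if sd_vae_setting == "None":
        return None, None
    chosen = vae_dict.get(sd_vae_setting)
    if chosen is not None:
        return chosen, 'specified in settings'
    return None, None   # unknown name: fall back to no external VAE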
From f8c512478568293155539f616dce26c5e4495055 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 20:00:12 +0300
Subject: typo?
---
modules/sd_vae.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 6ea92711..add5cecf 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -95,7 +95,7 @@ def resolve_vae(checkpoint_file):
return shared.cmd_opts.vae_path, 'from commandline argument'
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
- if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
return vae_near_checkpoint, 'found near the checkpoint'
if shared.opts.sd_vae == "None":
--
cgit v1.2.3
From 86359535d6fb0899fa9e838d27f2006b929331d5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 22:43:01 +0300
Subject: add buttons to copy images between img2img tabs
---
javascript/ui.js | 21 +++++++++++++++++++--
modules/ui.py | 42 +++++++++++++++++++++++++++++++++++++++++-
style.css | 18 ++++++++++++++++++
3 files changed, 78 insertions(+), 3 deletions(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index 1e04a8f4..f8279124 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -45,10 +45,27 @@ function switch_to_txt2img(){
return args_to_array(arguments);
}
-function switch_to_img2img(){
+function switch_to_img2img_tab(no){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
- gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
+ gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
+}
+function switch_to_img2img(){
+ switch_to_img2img_tab(0);
+ return args_to_array(arguments);
+}
+
+function switch_to_sketch(){
+ switch_to_img2img_tab(1);
+ return args_to_array(arguments);
+}
+
+function switch_to_inpaint(){
+ switch_to_img2img_tab(2);
+ return args_to_array(arguments);
+}
+function switch_to_inpaint_sketch(){
+ switch_to_img2img_tab(3);
return args_to_array(arguments);
}
diff --git a/modules/ui.py b/modules/ui.py
index 2625ae32..2425c66f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -795,19 +795,39 @@ def create_ui():
with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
+ copy_image_buttons = []
+ copy_image_destinations = {}
+
+ def add_copy_image_controls(tab_name, elem):
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ if name == tab_name:
+ gr.Button(title, interactive=False)
+ copy_image_destinations[name] = elem
+ continue
+
+ button = gr.Button(title)
+ copy_image_buttons.append((button, name, elem))
+
with gr.Tabs(elem_id="mode_img2img"):
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
@@ -824,10 +844,29 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
")
+ gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
+ inputs=[],
+ outputs=[],
+ )
+
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
@@ -856,6 +895,7 @@ def create_ui():
outputs=[inpaint_controls, mask_alpha],
)
+
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
diff --git a/style.css b/style.css
index ffd6307f..2d484e06 100644
--- a/style.css
+++ b/style.css
@@ -676,6 +676,24 @@ footer {
opacity: 0.5;
}
+#mode_img2img > div > div{
+ gap: 0 !important;
+}
+
+[id*='img2img_copy_to_'] {
+ border: none;
+}
+
+[id*='img2img_copy_to_'] > button {
+}
+
+[id*='img2img_label_copy_to_'] {
+ font-size: 1.0em;
+ font-weight: bold;
+ text-align: center;
+ line-height: 2.4em;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
--
cgit v1.2.3
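Each copy button gets two click handlers: a Python one that moves the image value between components (unwrapping the {"image": ..., "mask": ...} dicts that sketch tools return), and a JS-only one that flips the visible tab client-side. A minimal self-contained sketch of the same wiring, assuming the gradio 3.x API the webui targeted at the time (Image(tool=...), the _js keyword):

import gradio as gr

def copy_image(img):
    # sketch/inpaint editors return a dict; plain image inputs return the image
    return img["image"] if isinstance(img, dict) and "image" in img else img

with gr.Blocks() as demo:
    src = gr.Image(tool="sketch")
    dst = gr.Image()
    btn = gr.Button("Copy")
    btn.click(fn=copy_image, inputs=[src], outputs=[dst])   # move the data server-side
    btn.click(fn=lambda: None, _js="switch_to_img2img",     # flip the tab in the browser
              inputs=[], outputs=[])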
From 2e172cf831a928223e93803b94896325bd4c22a7 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 22:25:32 +0100
Subject: Only set TORCH_COMMAND if it wasn't already set by webui-user
---
webui.sh | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index fcba6b7d..35542ed6 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,7 +168,10 @@ else
gpu_info=$(lspci | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+ if [ -z ${TORCH_COMMAND+x} ]
+ then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+ fi
HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
else
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
--
cgit v1.2.3
From ba077e2110cab891a46d14665fb161ce0669f31e Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 23:19:52 +0100
Subject: Fix TORCH_COMMAND check
---
webui.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index 35542ed6..6e07778f 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,7 +168,7 @@ else
gpu_info=$(lspci | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
- if [ -z ${TORCH_COMMAND+x} ]
+ if [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
fi
--
cgit v1.2.3
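The two shell tests differ on empty strings: [ -z ${TORCH_COMMAND+x} ] is true only when the variable is completely unset (the +x expansion produces "x" whenever it is set, even to ""), while [[ -z "${TORCH_COMMAND}" ]] treats unset and empty alike, which matches the intent of "only default it if the user provided nothing". The same distinction in Python terms:

import os

ROCM_DEFAULT = "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"

def torch_command_fixed():
    # [[ -z "${TORCH_COMMAND}" ]]: unset OR empty falls back to the default
    return os.environ.get("TORCH_COMMAND") or ROCM_DEFAULT

def torch_command_before_fix():
    # [ -z ${TORCH_COMMAND+x} ]: only a truly unset variable falls back;
    # an empty string would have been kept and broken the pip invocation
    return os.environ["TORCH_COMMAND"] if "TORCH_COMMAND" in os.environ else ROCM_DEFAULT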
From cbbdfc3609097fb8b31e32387396ee1ae299fc6f Mon Sep 17 00:00:00 2001
From: Josh R
Date: Sat, 14 Jan 2023 14:45:08 -0800
Subject: Fix Aspect Ratio Overlay / AROverlay to work with Inpaint & Sketch
---
javascript/aspectRatioOverlay.js | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 66f26a22..0f164b82 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -21,11 +21,16 @@ function dimensionChange(e, is_width, is_height){
var targetElement = null;
var tabIndex = get_tab_index('mode_img2img')
- if(tabIndex == 0){
+ if(tabIndex == 0){ // img2img
targetElement = gradioApp().querySelector('div[data-testid=image] img');
- } else if(tabIndex == 1){
+ } else if(tabIndex == 1){ //Sketch
+ targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
+ } else if(tabIndex == 2){ // Inpaint
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
+ } else if(tabIndex == 3){ // Inpaint sketch
+ targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
}
+
if(targetElement){
--
cgit v1.2.3
From 9ef41df6f9043d58fbbeea1f06be8e5c8622248b Mon Sep 17 00:00:00 2001
From: Josh R
Date: Sat, 14 Jan 2023 15:26:45 -0800
Subject: add inpaint masking controls to the orderable sections so their
position can be changed in settings
---
modules/shared.py | 1 +
modules/ui.py | 58 +++++++++++++++++++++++++++----------------------------
2 files changed, 30 insertions(+), 29 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 51df056c..7ce8003f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -116,6 +116,7 @@ restricted_opts = {
}
ui_reorder_categories = [
+ "masking",
"sampler",
"dimensions",
"cfg",
diff --git a/modules/ui.py b/modules/ui.py
index 2425c66f..174930ab 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -867,35 +867,6 @@ def create_ui():
outputs=[],
)
- with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
-
- with FormRow():
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
- inputs=[],
- outputs=[inpaint_controls, mask_alpha],
- )
-
-
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
@@ -937,6 +908,35 @@ def create_ui():
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ elif category == "masking":
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
--
cgit v1.2.3
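The diff above does not change the masking controls themselves; it moves them out of the fixed layout and into the `elif category == "masking":` branch of the category loop in create_ui(), so their placement follows the `ui_reorder_categories` list that users can rearrange in settings. A minimal sketch of that pattern, with hypothetical section builders standing in for the real FormGroup code:

    ui_reorder_categories = ["masking", "sampler", "dimensions", "cfg"]

    def build_img2img_sections(user_order=None):
        # settings may supply a reordered list; otherwise use the default order
        for category in (user_order or ui_reorder_categories):
            if category == "masking":
                print("building inpaint mask controls")  # placeholder for the FormGroup above
            elif category == "sampler":
                print("building sampler controls")
            # ...one branch per category, as modules/ui.py does

    build_img2img_sections(["sampler", "masking", "dimensions", "cfg"])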
From d97f467c0d27695d23edad5e4f8898a57e0ccb00 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 09:24:21 +0300
Subject: add license file
---
LICENSE.txt | 663 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 663 insertions(+)
create mode 100644 LICENSE.txt
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 00000000..14577543
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,663 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (c) 2023 AUTOMATIC1111
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
--
cgit v1.2.3
From eef1990a5e6c41ecb6943ff5529316ad5ededb2a Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sun, 15 Jan 2023 08:13:33 -0500
Subject: Fix Approx NN on devices other than CUDA
---
modules/sd_vae_approx.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0a58542d..0027343a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -36,7 +36,7 @@ def model():
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
--
cgit v1.2.3
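By default torch.load restores tensors to the device recorded in the checkpoint, so a VAE-approx checkpoint saved on CUDA fails to load on CPU or MPS backends; map_location='cpu' remaps the storages first, and the subsequent .to(devices.device, devices.dtype) call moves the model where it belongs. A sketch of the pattern, assuming a placeholder checkpoint path and a stand-in module:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(4, 4)  # stand-in for VAEApprox

    # remap CUDA-saved tensors to CPU when no CUDA device is available;
    # None keeps whatever placement the checkpoint recorded
    state = torch.load("model.pt", map_location=None if device.type == "cuda" else "cpu")
    model.load_state_dict(state)
    model.eval()
    model.to(device)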
From f0312565e5b4d56a421af889a9a8eaea0ba92959 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sun, 15 Jan 2023 09:42:34 -0500
Subject: increase hash block size
---
modules/hashes.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/hashes.py b/modules/hashes.py
index 14231771..b85a7580 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -34,9 +34,10 @@ def cache(subsection):
def calculate_sha256(filename):
hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
with open(filename, "rb") as f:
- for chunk in iter(lambda: f.read(4096), b""):
+ for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
--
cgit v1.2.3
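The digest does not depend on how the file is chunked; only the number of Python-level read/update iterations changes, and 1 MiB reads make roughly 256 times fewer of them than 4 KiB reads on multi-gigabyte model files. A short sketch demonstrating the equivalence (the checkpoint path is hypothetical):

    import hashlib

    def sha256_of(filename, blksize):
        h = hashlib.sha256()
        with open(filename, "rb") as f:
            # iter(callable, sentinel) calls f.read(blksize) until it returns b""
            for chunk in iter(lambda: f.read(blksize), b""):
                h.update(chunk)
        return h.hexdigest()

    # same digest either way, far fewer loop iterations with the bigger block:
    # assert sha256_of("some_model.ckpt", 4096) == sha256_of("some_model.ckpt", 1024 * 1024)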
From d8b90ac121cbf0c18b1dc9d56a5e1d14ca51e74e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 18:50:56 +0300
Subject: big rework of progressbar/preview system to allow multiple users to
 submit prompts at the same time and not get previews of each other
---
javascript/progressbar.js | 249 ++++++++++++++++---------
javascript/textualInversion.js | 13 +-
javascript/ui.js | 33 +++-
modules/call_queue.py | 19 +-
modules/hypernetworks/hypernetwork.py | 6 +-
modules/img2img.py | 2 +-
modules/progress.py | 96 ++++++++++
modules/sd_samplers.py | 2 +-
modules/shared.py | 16 +-
modules/textual_inversion/preprocess.py | 2 +-
modules/textual_inversion/textual_inversion.py | 6 +-
modules/txt2img.py | 2 +-
modules/ui.py | 41 ++--
modules/ui_progress.py | 101 ----------
style.css | 74 +++++---
webui.py | 3 +
16 files changed, 390 insertions(+), 275 deletions(-)
create mode 100644 modules/progress.py
delete mode 100644 modules/ui_progress.py
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index d6323ed9..b7524ef7 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -1,82 +1,25 @@
// code related to showing and updating progressbar shown as the image is being made
-global_progressbars = {}
-galleries = {}
-galleryObservers = {}
-
-// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
-timeoutIds = {}
-function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
- // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
- // every time you use gr.HTML(elem_id='xxx'), so we handle this here
- var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
- var progressbarParent
- if(progressbar){
- progressbarParent = gradioApp().querySelector("#"+id_progressbar)
- } else{
- progressbar = gradioApp().getElementById(id_progressbar)
- progressbarParent = null
- }
- var skip = id_skip ? gradioApp().getElementById(id_skip) : null
- var interrupt = gradioApp().getElementById(id_interrupt)
-
- if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
- if(progressbar.innerText){
- let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
- if(document.title != newtitle){
- document.title = newtitle;
- }
- }else{
- let newtitle = 'Stable Diffusion'
- if(document.title != newtitle){
- document.title = newtitle;
- }
- }
- }
-
- if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
- global_progressbars[id_progressbar] = progressbar
-
- var mutationObserver = new MutationObserver(function(m){
- if(timeoutIds[id_part]) return;
-
- preview = gradioApp().getElementById(id_preview)
- gallery = gradioApp().getElementById(id_gallery)
+galleries = {}
+storedGallerySelections = {}
+galleryObservers = {}
- if(preview != null && gallery != null){
- preview.style.width = gallery.clientWidth + "px"
- preview.style.height = gallery.clientHeight + "px"
- if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
+function rememberGallerySelection(id_gallery){
+ storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
+}
- //only watch gallery if there is a generation process going on
- check_gallery(id_gallery);
+function getGallerySelectedIndex(id_gallery){
+ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
+ let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
- var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
- if(progressDiv){
- timeoutIds[id_part] = window.setTimeout(function() {
- timeoutIds[id_part] = null
- requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
- }, 500)
- } else{
- if (skip) {
- skip.style.display = "none"
- }
- interrupt.style.display = "none"
+ let currentlySelectedIndex = -1
+ galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
- //disconnect observer once generation finished, so user can close selected image if they want
- if (galleryObservers[id_gallery]) {
- galleryObservers[id_gallery].disconnect();
- galleries[id_gallery] = null;
- }
- }
- }
-
- });
- mutationObserver.observe( progressbar, { childList:true, subtree:true })
- }
+ return currentlySelectedIndex
}
+// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
function check_gallery(id_gallery){
let gallery = gradioApp().getElementById(id_gallery)
// if gallery has no change, no need to setting up observer again.
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){
if(galleryObservers[id_gallery]){
galleryObservers[id_gallery].disconnect();
}
- let prevSelectedIndex = selected_gallery_index();
+
+ storedGallerySelections[id_gallery] = -1
+
galleryObservers[id_gallery] = new MutationObserver(function (){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
+ let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
+ prevSelectedIndex = storedGallerySelections[id_gallery]
+ storedGallerySelections[id_gallery] = -1
+
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
// automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement;
@@ -120,30 +69,150 @@ function check_gallery(id_gallery){
}
onUiUpdate(function(){
- check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
- check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
- check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
+ check_gallery('txt2img_gallery')
+ check_gallery('img2img_gallery')
})
-function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
- btn = gradioApp().getElementById(id_part+"_check_progress");
- if(btn==null) return;
-
- btn.click();
- var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
- var skip = id_skip ? gradioApp().getElementById(id_skip) : null
- var interrupt = gradioApp().getElementById(id_interrupt)
- if(progressDiv && interrupt){
- if (skip) {
- skip.style.display = "block"
+function request(url, data, handler, errorHandler){
+ var xhr = new XMLHttpRequest();
+ var url = url;
+ xhr.open("POST", url, true);
+ xhr.setRequestHeader("Content-Type", "application/json");
+ xhr.onreadystatechange = function () {
+ if (xhr.readyState === 4) {
+ if (xhr.status === 200) {
+ var js = JSON.parse(xhr.responseText);
+ handler(js)
+ } else{
+ errorHandler()
+ }
}
- interrupt.style.display = "block"
+ };
+ var js = JSON.stringify(data);
+ xhr.send(js);
+}
+
+function pad2(x){
+ return x<10 ? '0'+x : x
+}
+
+function formatTime(secs){
+ if(secs > 3600){
+ return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
+ } else if(secs > 60){
+ return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
+ } else{
+ return Math.floor(secs) + "s"
}
}
-function requestProgress(id_part){
- btn = gradioApp().getElementById(id_part+"_check_progress_initial");
- if(btn==null) return;
+function randomId(){
+ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+}
+
+// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
+// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
+// calls onProgress every time there is a progress update
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+ var dateStart = new Date()
+ var wasEverActive = false
+ var parentProgressbar = progressbarContainer.parentNode
+ var parentGallery = gallery.parentNode
+
+ var divProgress = document.createElement('div')
+ divProgress.className='progressDiv'
+ var divInner = document.createElement('div')
+ divInner.className='progress'
+
+ divProgress.appendChild(divInner)
+ parentProgressbar.insertBefore(divProgress, progressbarContainer)
+
+ var livePreview = document.createElement('div')
+ livePreview.className='livePreview'
+ parentGallery.insertBefore(livePreview, gallery)
+
+ var removeProgressBar = function(){
+ parentProgressbar.removeChild(divProgress)
+ parentGallery.removeChild(livePreview)
+ atEnd()
+ }
+
+ var fun = function(id_task, id_live_preview){
+ request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
+ console.log(res)
+
+ if(res.completed){
+ removeProgressBar()
+ return
+ }
+
+ var rect = progressbarContainer.getBoundingClientRect()
+
+ if(rect.width){
+ divProgress.style.width = rect.width + "px";
+ }
+
+ progressText = ""
+
+ divInner.style.width = ((res.progress || 0) * 100.0) + '%'
+
+ if(res.progress > 0){
+ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
+ }
+
+ if(res.eta){
+ progressText += " ETA: " + formatTime(res.eta)
+ } else if(res.textinfo){
+ progressText += " " + res.textinfo
+ }
+
+ divInner.textContent = progressText
+
+ var elapsedFromStart = (new Date() - dateStart) / 1000
+
+ if(res.active) wasEverActive = true;
+
+ if(! res.active && wasEverActive){
+ removeProgressBar()
+ return
+ }
+
+ if(elapsedFromStart > 5 && !res.queued && !res.active){
+ removeProgressBar()
+ return
+ }
+
+
+ if(res.live_preview){
+ var img = new Image();
+ img.onload = function() {
+ var rect = gallery.getBoundingClientRect()
+ if(rect.width){
+ livePreview.style.width = rect.width + "px"
+ livePreview.style.height = rect.height + "px"
+ }
+
+ livePreview.innerHTML = ''
+ livePreview.appendChild(img)
+ if(livePreview.childElementCount > 2){
+ livePreview.removeChild(livePreview.firstElementChild)
+ }
+ }
+ img.src = res.live_preview;
+ }
+
+
+ if(onProgress){
+ onProgress(res)
+ }
+
+ setTimeout(() => {
+ fun(id_task, res.id_live_preview);
+ }, 500)
+ }, function(){
+ removeProgressBar()
+ })
+ }
- btn.click();
+ fun(id_task, 0)
}
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
index 8061be08..0354b860 100644
--- a/javascript/textualInversion.js
+++ b/javascript/textualInversion.js
@@ -1,8 +1,17 @@
+
function start_training_textual_inversion(){
- requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
- return args_to_array(arguments)
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
+ gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
+ })
+
+ var res = args_to_array(arguments)
+
+ res[0] = id
+
+ return res
}
diff --git a/javascript/ui.js b/javascript/ui.js
index f8279124..ecf97cb3 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -126,18 +126,41 @@ function create_submit_args(args){
return res
}
+function showSubmitButtons(tabname, show){
+ gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
+ gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
+}
+
function submit(){
- requestProgress('txt2img')
+ rememberGallerySelection('txt2img_gallery')
+ showSubmitButtons('txt2img', false)
+
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
+ showSubmitButtons('txt2img', true)
+
+ })
- return create_submit_args(arguments)
+ var res = create_submit_args(arguments)
+
+ res[0] = id
+
+ return res
}
function submit_img2img(){
- requestProgress('img2img')
+ rememberGallerySelection('img2img_gallery')
+ showSubmitButtons('img2img', false)
+
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
+ showSubmitButtons('img2img', true)
+ })
- res = create_submit_args(arguments)
+ var res = create_submit_args(arguments)
- res[0] = get_tab_index('mode_img2img')
+ res[0] = id
+ res[1] = get_tab_index('mode_img2img')
return res
}
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 4cd49533..92097c15 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -4,7 +4,7 @@ import threading
import traceback
import time
-from modules import shared
+from modules import shared, progress
queue_lock = threading.Lock()
@@ -22,12 +22,23 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
- shared.state.begin()
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
with queue_lock:
- res = func(*args, **kwargs)
+ shared.state.begin()
+ progress.start_task(id_task)
+
+ try:
+ res = func(*args, **kwargs)
+ finally:
+ progress.finish_task(id_task)
- shared.state.end()
+ shared.state.end()
return res
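The wrapper now recognizes a job id purely by shape: any first string argument of the form "task(...)" is queued and tracked through the new progress module, which is why the JS helpers prepend randomId() to the submitted arguments. A sketch of that contract; the Python id generator here only approximates the JS randomId():

    import random
    import string

    def random_task_id():
        # approximates randomId() from javascript/progressbar.js
        body = "".join(random.choices(string.ascii_lowercase + string.digits, k=15))
        return f"task({body})"

    def looks_like_task_id(arg):
        # the same shape check used in wrap_gradio_gpu_call above
        return type(arg) == str and arg[0:5] == "task(" and arg[-1] == ")"

    assert looks_like_task_id(random_task_id())
    assert not looks_like_task_id("a plain prompt")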
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3aebefa8..ae6af516 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
torch.cuda.set_rng_state_all(cuda_rng_state)
hypernetwork.train()
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/img2img.py b/modules/img2img.py
index f62783c6..f4a03c57 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
diff --git a/modules/progress.py b/modules/progress.py
new file mode 100644
index 00000000..3327b883
--- /dev/null
+++ b/modules/progress.py
@@ -0,0 +1,96 @@
+import base64
+import io
+import time
+
+import gradio as gr
+from pydantic import BaseModel, Field
+
+from modules.shared import opts
+
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+
+
+def start_task(id_task):
+ global current_task
+
+ current_task = id_task
+ pending_tasks.pop(id_task, None)
+
+
+def finish_task(id_task):
+ global current_task
+
+ if current_task == id_task:
+ current_task = None
+
+ finished_tasks.append(id_task)
+ if len(finished_tasks) > 16:
+ finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+ pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
+
+
+class ProgressResponse(BaseModel):
+ active: bool = Field(title="Whether the task is being worked on right now")
+ queued: bool = Field(title="Whether the task is in queue")
+ completed: bool = Field(title="Whether the task has already finished")
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ eta: float = Field(default=None, title="ETA in secs")
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+ active = req.id_task == current_task
+ queued = req.id_task in pending_tasks
+ completed = req.id_task in finished_tasks
+
+ if not active:
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+ progress = 0
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ progress = min(progress, 1)
+
+ elapsed_since_start = time.time() - shared.state.time_start
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+ id_live_preview = req.id_live_preview
+ shared.state.set_current_image()
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+ image.save(buffered, format="png")
+ live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+ id_live_preview = shared.state.id_live_preview
+ else:
+ live_preview = None
+ else:
+ live_preview = None
+
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 7616fded..76e0e0d5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -140,7 +140,7 @@ def store_latent(decoded):
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
- shared.state.current_image = sample_to_image(decoded)
+ shared.state.assign_current_image(sample_to_image(decoded))
class InterruptedException(BaseException):
diff --git a/modules/shared.py b/modules/shared.py
index 51df056c..de99aca9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -152,6 +152,7 @@ def reload_hypernetworks():
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
skipped = False
interrupted = False
@@ -165,6 +166,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ id_live_preview = 0
textinfo = None
time_start = None
need_restart = False
@@ -207,6 +209,7 @@ class State:
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
+ self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
@@ -220,8 +223,8 @@ class State:
devices.torch_gc()
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
@@ -234,12 +237,16 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
- self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
- self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
+ def assign_current_image(self, image):
+ self.current_image = image
+ self.id_live_preview += 1
+
state = State()
state.server_start = time.time()
@@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
}))
options_templates.update(options_section(('ui', "User interface"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3c1042ad..64abff4d 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 63935878..7e4a6d24 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
@@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 38b5f591..ca5d4550 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
diff --git a/modules/ui.py b/modules/ui.py
index 2425c66f..ff33236b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -356,7 +356,7 @@ def create_toprow(is_img2img):
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
- with gr.Row():
+ with gr.Row(elem_id=f"{id_part}_generate_box"):
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
@@ -384,9 +384,7 @@ def create_toprow(is_img2img):
def setup_progressbar(*args, **kwargs):
- import modules.ui_progress
-
- modules.ui_progress.setup_progressbar(*args, **kwargs)
+ pass
def apply_setting(key, value):
@@ -479,8 +477,8 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel'):
- with gr.Group():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
generation_info = None
@@ -595,15 +593,6 @@ def create_ui():
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -682,6 +671,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
+ dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_style,
@@ -782,16 +772,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
@@ -958,6 +939,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
+ dummy_component,
dummy_component,
img2img_prompt,
img2img_negative_prompt,
@@ -1335,15 +1317,11 @@ def create_ui():
script_callbacks.ui_train_tabs_callback(params)
- with gr.Column():
- progressbar = gr.HTML(elem_id="ti_progressbar")
+ with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
- setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
@@ -1384,6 +1362,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
process_src,
process_dst,
process_width,
@@ -1411,6 +1390,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_embedding_name,
embedding_learn_rate,
batch_size,
@@ -1443,6 +1423,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
deleted file mode 100644
index 7cd312e4..00000000
--- a/modules/ui_progress.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import time
-
-import gradio as gr
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-def calc_time_left(progress, threshold, label, force_display, show_eta):
- if progress == 0:
- return ""
- else:
- time_since_start = time.time() - shared.state.time_start
- eta = (time_since_start/progress)
- eta_relative = eta-time_since_start
- if (eta_relative > threshold and show_eta) or force_display:
- if eta_relative > 3600:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
- elif eta_relative > 60:
- return label + time.strftime('%M:%S', time.gmtime(eta_relative))
- else:
- return label + time.strftime('%Ss', time.gmtime(eta_relative))
- else:
- return ""
-
-
-def check_progress_call(id_part):
- if shared.state.job_count == 0:
- return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
- progress = 0
-
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
- # Show progress percentage and time left at the same moment, and base it also on steps done
- show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
- time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
- if time_left != "":
- shared.state.time_left_force_display = True
-
- progress = min(progress, 1)
-
- progressbar = ""
- if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
-
- image = gr.update(visible=False)
- preview_visibility = gr.update(visible=False)
-
- if opts.live_previews_enable:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr.update(visible=True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr.update(visible=False)
-
-    return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
diff --git a/style.css b/style.css
index 2d484e06..786b71d1 100644
--- a/style.css
+++ b/style.css
@@ -305,26 +305,42 @@ input[type="range"]{
}
.progressDiv{
- width: 100%;
- height: 20px;
- background: #b4c0cc;
- border-radius: 8px;
+ position: absolute;
+ height: 20px;
+ top: -20px;
+ background: #b4c0cc;
+ border-radius: 8px !important;
}
.dark .progressDiv{
- background: #424c5b;
+ background: #424c5b;
}
.progressDiv .progress{
- width: 0%;
- height: 20px;
- background: #0060df;
- color: white;
- font-weight: bold;
- line-height: 20px;
- padding: 0 8px 0 0;
- text-align: right;
- border-radius: 8px;
+ width: 0%;
+ height: 20px;
+ background: #0060df;
+ color: white;
+ font-weight: bold;
+ line-height: 20px;
+ padding: 0 8px 0 0;
+ text-align: right;
+ border-radius: 8px;
+ overflow: visible;
+ white-space: nowrap;
+}
+
+.livePreview{
+ position: absolute;
+ z-index: 300;
+ background-color: white;
+ margin: -4px;
+}
+
+.livePreview img{
+ object-fit: contain;
+ width: 100%;
+ height: 100%;
}
#lightboxModal{
@@ -450,23 +466,25 @@ input[type="range"]{
display:none
}
-#txt2img_interrupt, #img2img_interrupt{
- position: absolute;
- width: 50%;
- height: 72px;
- background: #b4c0cc;
- border-radius: 0px;
- display: none;
+#txt2img_generate_box, #img2img_generate_box{
+ position: relative;
+}
+
+#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{
+ position: absolute;
+ width: 50%;
+ height: 100%;
+ background: #b4c0cc;
+ display: none;
}
+#txt2img_interrupt, #img2img_interrupt{
+ right: 0;
+ border-radius: 0 0.5rem 0.5rem 0;
+}
#txt2img_skip, #img2img_skip{
- position: absolute;
- width: 50%;
- right: 0px;
- height: 72px;
- background: #b4c0cc;
- border-radius: 0px;
- display: none;
+ left: 0;
+ border-radius: 0.5rem 0 0 0.5rem;
}
.red {
diff --git a/webui.py b/webui.py
index 1fff80da..4624fe18 100644
--- a/webui.py
+++ b/webui.py
@@ -34,6 +34,7 @@ import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
+import modules.progress
import modules.ui
from modules import modelloader
@@ -181,6 +182,8 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000)
+ modules.progress.setup_progress_api(app)
+
if launch_api:
create_api(app)
--
cgit v1.2.3
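
For anyone wiring a custom frontend against the new /internal/progress route: the request carries the task id plus the id of the last live preview already received, and the server only ships a new live_preview data URI once its internal id_live_preview has advanced past yours. A minimal polling sketch, assuming the WebUI listens on the default 127.0.0.1:7860 and the requests package is installed (poll_progress and the port are illustrative, not part of the codebase):

    import time
    import requests

    def poll_progress(id_task, url="http://127.0.0.1:7860/internal/progress"):
        id_live_preview = -1
        while True:
            res = requests.post(url, json={"id_task": id_task,
                                           "id_live_preview": id_live_preview}).json()
            if res["completed"]:
                return
            if res.get("progress") is not None:
                print(f"{res['progress'] * 100:.0f}%, eta: {res.get('eta')}")
            # echo back the preview id so the server only sends new images
            id_live_preview = res.get("id_live_preview", id_live_preview)
            time.sleep(0.5)
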
From 388708f7b13dfbc890135cad678bfbcebd7baf37 Mon Sep 17 00:00:00 2001
From: pangbo13 <373108669@qq.com>
Date: Mon, 16 Jan 2023 00:56:24 +0800
Subject: fix when show_progress_every_n_steps == -1
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/shared.py b/modules/shared.py
index de99aca9..f857ccde 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -228,7 +228,7 @@ class State:
if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
--
cgit v1.2.3
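
Why the extra guard is needed: with the slider at -1, the delta test sampling_step - current_image_sampling_step >= -1 is always true, so previews were regenerated on every sampling step instead of only after the batch. A condensed sketch of the fixed predicate (function name is illustrative):

    # Sketch of the fixed condition; names mirror the State fields above.
    def should_update_preview(step, last_step, n, live_previews_enable=True):
        return step - last_step >= n and live_previews_enable and n != -1

    assert should_update_preview(5, 4, 1)       # normal case: update every step
    assert not should_update_preview(5, 4, -1)  # -1 now means "never during sampling"
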
From 16f410893eb96c7810cbbd812541ba35e0e92524 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Mon, 16 Jan 2023 02:08:47 +0900
Subject: fix missing 'mean loss' for tensorboard integration
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index ae6af516..bbd1f673 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -644,7 +644,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
-
+ mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
--
cgit v1.2.3
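
The restored expression is a length-weighted global mean over every per-image loss list, not a mean of per-image means. A worked check with hypothetical values (note that the follow-up commits below replace loss_dict, which was no longer being populated at this point, with a separate loss log):

    # Hypothetical values; the expression weights each list by its length.
    loss_dict = {"a.png": [1.0, 3.0], "b.png": [5.0]}
    mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
    assert mean_loss == 3.0   # (1 + 3 + 5) / 3; a mean of per-key means would give 3.5
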
From f6aac4c65a681383616f6e72e3865002600f476f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 20:20:29 +0300
Subject: eliminate flicker for live previews
---
javascript/progressbar.js | 14 +++++++-------
style.css | 1 +
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index b7524ef7..8f22c018 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -184,15 +184,15 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
if(res.live_preview){
+
+ var rect = gallery.getBoundingClientRect()
+ if(rect.width){
+ livePreview.style.width = rect.width + "px"
+ livePreview.style.height = rect.height + "px"
+ }
+
var img = new Image();
img.onload = function() {
- var rect = gallery.getBoundingClientRect()
- if(rect.width){
- livePreview.style.width = rect.width + "px"
- livePreview.style.height = rect.height + "px"
- }
-
- livePreview.innerHTML = ''
livePreview.appendChild(img)
if(livePreview.childElementCount > 2){
livePreview.removeChild(livePreview.firstElementChild)
diff --git a/style.css b/style.css
index 786b71d1..5bf1c6f9 100644
--- a/style.css
+++ b/style.css
@@ -338,6 +338,7 @@ input[type="range"]{
}
.livePreview img{
+ position: absolute;
object-fit: contain;
width: 100%;
height: 100%;
--
cgit v1.2.3
From a534bdfc801e0c83e378dfaa2d04cf865d7109f9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 20:27:39 +0300
Subject: add setting for progressbar update period
---
javascript/progressbar.js | 2 +-
modules/shared.py | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 8f22c018..59173c83 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -208,7 +208,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
setTimeout(() => {
fun(id_task, res.id_live_preview);
- }, 500)
+ }, opts.live_preview_refresh_period || 500)
}, function(){
removeProgressBar()
})
diff --git a/modules/shared.py b/modules/shared.py
index de99aca9..3483db1c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -455,6 +455,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
--
cgit v1.2.3
From b6ce041cdf722b400df9b5eac306d0cb049923d7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 20:29:48 +0300
Subject: put interrupt and skip buttons back where they were
---
modules/ui.py | 2 +-
style.css | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index ff33236b..7a357f9a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -357,8 +357,8 @@ def create_toprow(is_img2img):
with gr.Column(scale=1):
with gr.Row(elem_id=f"{id_part}_generate_box"):
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
diff --git a/style.css b/style.css
index 5bf1c6f9..750fe315 100644
--- a/style.css
+++ b/style.css
@@ -480,13 +480,13 @@ input[type="range"]{
}
#txt2img_interrupt, #img2img_interrupt{
- right: 0;
- border-radius: 0 0.5rem 0.5rem 0;
-}
-#txt2img_skip, #img2img_skip{
left: 0;
border-radius: 0.5rem 0 0 0.5rem;
}
+#txt2img_skip, #img2img_skip{
+ right: 0;
+ border-radius: 0 0.5rem 0.5rem 0;
+}
.red {
color: red;
--
cgit v1.2.3
From 110d1a2d598bcfacffe3d524df1a3422b4cbd8ec Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sun, 15 Jan 2023 12:41:00 -0500
Subject: add fields to settings file
---
modules/textual_inversion/logging.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 31e50b64..734a4b6f 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,7 +2,7 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "latent_sampling_method"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
--
cgit v1.2.3
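
The added keys land in plain set literals, so even a key repeated in saved_params_shared would be harmless, and the later union into saved_params_all dedupes across the three groups as well. A quick sketch:

    # Duplicate literals in a set collapse silently, and | unions dedupe too.
    saved_params_shared = {"learn_rate", "gradient_step", "gradient_step"}
    saved_params_ti = {"embedding_name", "learn_rate"}
    assert saved_params_shared | saved_params_ti == {"learn_rate", "gradient_step", "embedding_name"}
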
From 598f7fcd84f655dd204ad5e258dc1c41cc806cde Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Mon, 16 Jan 2023 02:46:21 +0900
Subject: Fix loss_dict problem
---
modules/hypernetworks/hypernetwork.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index bbd1f673..438e3e9f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -561,6 +561,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ loss_logging = []
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
@@ -601,6 +602,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
+ loss_logging.append(loss.item())
del x
del c
@@ -644,7 +646,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
- mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
+ mean_loss = sum(loss_logging) / len(loss_logging)
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
--
cgit v1.2.3
From 13445738d974edcca5ff2f4f8f3833c1f3433e5e Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Mon, 16 Jan 2023 03:02:54 +0900
Subject: Fix tensorboard related functions
---
modules/hypernetworks/hypernetwork.py | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 438e3e9f..c963fc40 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -561,7 +561,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
- loss_logging = []
+    loss_logging = deque(maxlen=len(ds) * 3)  # this should be a configurable parameter; it keeps 3 * epoch (dataset size) entries
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
@@ -602,7 +602,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
- loss_logging.append(loss.item())
del x
del c
@@ -612,7 +611,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+ loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
@@ -690,9 +689,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
-
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -703,7 +699,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
hypernetwork.train()
if image is not None:
shared.state.assign_current_image(image)
-
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
+ f"Validation at epoch {epoch_num}", image,
+ hypernetwork.step)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
--
cgit v1.2.3
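
Taken together, these two fixes log one value per optimizer step (after gradient accumulation) into a deque bounded at three epochs' worth of entries, so the tensorboard mean covers a sliding window rather than the whole run. A sketch of the bounded-log behaviour with a hypothetical maxlen:

    from collections import deque

    # Once maxlen is reached, appending evicts the oldest entry,
    # so the mean is over a sliding window rather than the whole run.
    loss_logging = deque(maxlen=3)
    for step_loss in [0.9, 0.7, 0.5, 0.3]:
        loss_logging.append(step_loss)

    assert list(loss_logging) == [0.7, 0.5, 0.3]
    mean_loss = sum(loss_logging) / len(loss_logging)   # 0.5
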
From 8e2aeee4a127b295bfc880800e4a312e0f049b85 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 22:29:53 +0300
Subject: add BREAK keyword to end current text chunk and start the next
---
modules/prompt_parser.py | 7 ++++++-
modules/sd_hijack_clip.py | 17 +++++++++++++----
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 870218db..69665372 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -274,6 +274,7 @@ re_attention = re.compile(r"""
:
""", re.X)
+re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
"""
@@ -339,7 +340,11 @@ def parse_prompt_attention(text):
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
- res.append([text, 1.0])
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 852afc66..9fa5c5c5 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
token_count = 0
last_comma = -1
- def next_chunk():
- """puts current chunk into the list of results and produces the next one - empty"""
+ def next_chunk(is_last=False):
+ """puts current chunk into the list of results and produces the next one - empty;
+        if is_last is true, tokens at the end won't add to token_count"""
nonlocal token_count
nonlocal last_comma
nonlocal chunk
- token_count += len(chunk.tokens)
+ if is_last:
+ token_count += len(chunk.tokens)
+ else:
+ token_count += self.chunk_length
+
to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed):
+ if text == 'BREAK' and weight == -1:
+ next_chunk()
+ continue
+
position = 0
while position < len(tokens):
token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk()
+ next_chunk(is_last=True)
return chunks, token_count
--
cgit v1.2.3
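
The parser half of the feature splits each plain-text span on a standalone BREAK and marks the boundary with a ["BREAK", -1] entry, which the CLIP hijack then turns into an early next_chunk() call, padding the current 75-token chunk and starting a fresh one. A self-contained sketch of the splitting step, reusing the regex from the diff:

    import re

    # Same pattern as the diff: a standalone BREAK word plus surrounding whitespace.
    re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

    def split_on_break(text):
        res = []
        for i, part in enumerate(re.split(re_break, text)):
            if i > 0:
                res.append(["BREAK", -1])   # weight -1 marks the chunk boundary
            res.append([part, 1.0])
        return res

    assert split_on_break("a photo of a cat BREAK oil painting") == [
        ["a photo of a cat", 1.0], ["BREAK", -1], ["oil painting", 1.0],
    ]
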
From db9b11617997ad02e5eb68be306078b3b8d3e2cf Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Sun, 15 Jan 2023 23:13:58 +0300
Subject: fix paths with parentheses
---
webui.bat | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.bat b/webui.bat
index e6a7a429..3165b94d 100644
--- a/webui.bat
+++ b/webui.bat
@@ -1,7 +1,7 @@
@echo off
if not defined PYTHON (set PYTHON=python)
-if not defined VENV_DIR (set VENV_DIR=%~dp0%venv)
+if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
set ERROR_REPORTING=FALSE
--
cgit v1.2.3
From fc25af3939f0366f147892a8ae5b9da56495352b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 23:22:51 +0300
Subject: remove unneeded log from progressbar
---
javascript/progressbar.js | 2 --
1 file changed, 2 deletions(-)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 59173c83..5072c13f 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -139,8 +139,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var fun = function(id_task, id_live_preview){
request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
- console.log(res)
-
if(res.completed){
removeProgressBar()
return
--
cgit v1.2.3
From 89314e79da21ac71ad3133ccf5ac3e85d4c24052 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 23:23:16 +0300
Subject: fix an error that happens when you send an empty image from txt2img
to img2img
---
modules/generation_parameters_copypaste.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 593d99ef..a381ff59 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -37,6 +37,9 @@ def quote(text):
def image_from_url_text(filedata):
+ if filedata is None:
+ return None
+
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
--
cgit v1.2.3
From 3db22e6ee45193559a2c3ba44ab672b067245f99 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 23:32:38 +0300
Subject: rename masking to inpaint in UI; make inpaint go to the right place
for users who don't have it in config string
---
modules/shared.py | 2 +-
modules/ui.py | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 3bdc375b..f06ae610 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -116,7 +116,7 @@ restricted_opts = {
}
ui_reorder_categories = [
- "masking",
+ "inpaint",
"sampler",
"dimensions",
"cfg",
diff --git a/modules/ui.py b/modules/ui.py
index b3d4af3e..20b66165 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -570,9 +570,9 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -889,7 +889,7 @@ def create_ui():
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
- elif category == "masking":
+ elif category == "inpaint":
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
--
cgit v1.2.3
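
The sort-key arithmetic in ordered_ui_categories interleaves two orderings: categories named in the user's ui_reorder string get odd keys from their position there, while unlisted ones keep even keys from their default position, so the renamed inpaint category still sorts near its intended place for users whose saved config string predates it. A sketch with hypothetical category lists:

    ui_reorder = "sampler, dimensions"                   # user's saved config string
    default_order = ["inpaint", "sampler", "dimensions", "cfg"]

    user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(ui_reorder.split(","))}
    ordered = [c for _, c in sorted(enumerate(default_order),
                                    key=lambda x: user_order.get(x[1], x[0] * 2))]
    # "inpaint" is missing from the config string but still lands up front.
    assert ordered == ["inpaint", "sampler", "dimensions", "cfg"]
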
From 9a43acf94ead6bc15da2782c39ab5a3107c3f06c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 23:37:34 +0300
Subject: add background color for live previews in dark mode
---
style.css | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/style.css b/style.css
index 750fe315..78fa9838 100644
--- a/style.css
+++ b/style.css
@@ -337,6 +337,10 @@ input[type="range"]{
margin: -4px;
}
+.dark .livePreview{
+ background-color: rgb(17 24 39 / var(--tw-bg-opacity));
+}
+
.livePreview img{
position: absolute;
object-fit: contain;
--
cgit v1.2.3
From 3f887f7f61d69fa699a272166b79fdb787e9ce1d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 00:44:46 +0300
Subject: support old configs that say "auto" for sd_vae; change
sd_vae_as_default to True by default as it's a more sensible setting
---
modules/sd_vae.py | 6 ++++--
modules/shared.py | 2 +-
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index add5cecf..e9c6bb40 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -94,8 +94,10 @@ def resolve_vae(checkpoint_file):
if shared.cmd_opts.vae_path is not None:
return shared.cmd_opts.vae_path, 'from commandline argument'
+ is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
- if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
return vae_near_checkpoint, 'found near the checkpoint'
if shared.opts.sd_vae == "None":
@@ -105,7 +107,7 @@ def resolve_vae(checkpoint_file):
if vae_from_options is not None:
return vae_from_options, 'specified in settings'
- if shared.opts.sd_vae != "Automatic":
+ if is_automatic:
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
return None, None
diff --git a/modules/shared.py b/modules/shared.py
index f06ae610..c5fc250e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -394,7 +394,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
- "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
--
cgit v1.2.3
From ff6a5bcec1ce25aa8f08b157ea957d764be23d8d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 01:28:20 +0300
Subject: bugfix for previous commit
---
modules/sd_vae.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index e9c6bb40..b2af2ce7 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -107,7 +107,7 @@ def resolve_vae(checkpoint_file):
if vae_from_options is not None:
return vae_from_options, 'specified in settings'
- if is_automatic:
+ if not is_automatic:
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
return None, None
--
cgit v1.2.3
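
With the bugfix applied, the resolution order reads: an explicit --vae-path wins; a .vae.pt found near the checkpoint is used when sd_vae_as_default is on or the setting is automatic; "None" disables the VAE; a named VAE resolves from the dict; and the warning now fires only for a concrete name that failed to resolve, never for Automatic/auto. A condensed sketch of that flow (function and parameter names here are illustrative):

    def resolve_vae_sketch(sd_vae, vae_near_ckpt, vae_dict, as_default=True, cmd_path=None):
        if cmd_path is not None:
            return cmd_path, 'from commandline argument'
        is_automatic = sd_vae in {"Automatic", "auto"}   # "auto" kept for old configs
        if vae_near_ckpt is not None and (as_default or is_automatic):
            return vae_near_ckpt, 'found near the checkpoint'
        if sd_vae == "None":
            return None, None
        if sd_vae in vae_dict:
            return vae_dict[sd_vae], 'specified in settings'
        if not is_automatic:                             # the one-line fix above
            print(f"Couldn't find VAE named {sd_vae}; using None instead")
        return None, None
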
From f202ff1901c27d1f82d5e2684dba9e1ed24ffdf2 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Sun, 15 Jan 2023 19:43:34 -0800
Subject: Make XY grid cancellation much faster
---
scripts/xy_grid.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index bd3087d4..13a3a046 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -406,6 +406,9 @@ class Script(scripts.Script):
grid_infotext = [None]
def cell(x, y):
+ if shared.state.interrupted:
+ return Processed(p, [], p.seed, "")
+
pc = copy(p)
x_opt.apply(pc, x, xs)
y_opt.apply(pc, y, ys)
--
cgit v1.2.3
From 029260b4ca7267d7a75319dbc11bca2a8c52774e Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Sun, 15 Jan 2023 21:40:57 -0800
Subject: Optimize XY grid to run slower axes fewer times
---
scripts/xy_grid.py | 123 ++++++++++++++++++++++++++++++-----------------------
1 file changed, 70 insertions(+), 53 deletions(-)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 13a3a046..074ee919 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -175,76 +175,87 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
+AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
+AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
axis_options = [
- AxisOption("Nothing", str, do_nothing, format_nothing, None),
- AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
- AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
- AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
- AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
- AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
- AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
- AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
- AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
- AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
- AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
- AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
- AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
- AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
- AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
- AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
- AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
- AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
- AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
- AxisOption("VAE", str, apply_vae, format_value_add_label, None),
- AxisOption("Styles", str, apply_styles, format_value_add_label, None),
+ AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
+ AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
+ AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
+ AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
+ AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
+ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
+ AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
+ AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
+ AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
+ AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
+ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
+ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
+ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
+ AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
+ AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
+ AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
+ AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
+ AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
+ AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
]
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
# Temporary list of all the images that are generated to be populated into the grid.
# Will be filled with empty images for any individual step that fails to process properly
- image_cache = []
+ image_cache = [None] * (len(xs) * len(ys))
processed_result = None
cell_mode = "P"
- cell_size = (1,1)
+ cell_size = (1, 1)
state.job_count = len(xs) * len(ys) * p.n_iter
- for iy, y in enumerate(ys):
+ def process_cell(x, y, ix, iy):
+ nonlocal image_cache, processed_result, cell_mode, cell_size
+
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed: Processed = cell(x, y)
+
+ try:
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+
+ image_cache[ix + iy * len(xs)] = processed_image
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
+ except:
+ image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
+
+ if swap_axes_processing_order:
for ix, x in enumerate(xs):
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
-
- processed:Processed = cell(x, y)
- try:
- # this dereference will throw an exception if the image was not processed
- # (this happens in cases such as if the user stops the process from the UI)
- processed_image = processed.images[0]
-
- if processed_result is None:
- # Use our first valid processed result as a template container to hold our full results
- processed_result = copy(processed)
- cell_mode = processed_image.mode
- cell_size = processed_image.size
- processed_result.images = [Image.new(cell_mode, cell_size)]
-
- image_cache.append(processed_image)
- if include_lone_images:
- processed_result.images.append(processed_image)
- processed_result.all_prompts.append(processed.prompt)
- processed_result.all_seeds.append(processed.seed)
- processed_result.infotexts.append(processed.infotexts[0])
- except:
- image_cache.append(Image.new(cell_mode, cell_size))
+ for iy, y in enumerate(ys):
+ process_cell(x, y, ix, iy)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, ix, iy)
if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
@@ -405,6 +416,11 @@ class Script(scripts.Script):
grid_infotext = [None]
+ # If one of the axes is very slow to change between (like SD model
+ # checkpoint), then make sure it is in the outer iteration of the nested
+ # `for` loop.
+ swap_axes_processing_order = x_opt.cost > y_opt.cost
+
def cell(x, y):
if shared.state.interrupted:
return Processed(p, [], p.seed, "")
@@ -443,7 +459,8 @@ class Script(scripts.Script):
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell,
draw_legend=draw_legend,
- include_lone_images=include_lone_images
+ include_lone_images=include_lone_images,
+ swap_axes_processing_order=swap_axes_processing_order
)
if opts.grid_save:
--
cgit v1.2.3
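
The point of the new cost field: whichever axis is expensive to switch between (checkpoint 1.0, VAE 0.7, hypernetwork 0.2) should vary in the outer loop, so the expensive reload happens once per row or column rather than once per cell. A sketch of the saving with hypothetical axis sizes:

    # With a costly X axis (e.g. checkpoint) and a cheap Y axis (e.g. seed),
    # swapping the processing order cuts switches from len(xs)*len(ys) to len(xs).
    def count_switches(xs, ys, x_cost, y_cost):
        switches = 0
        if x_cost > y_cost:               # swap_axes_processing_order
            for x in xs:
                switches += 1             # one expensive switch per outer step
                for y in ys:
                    pass                  # cheap per-cell work
        else:
            for y in ys:
                for x in xs:
                    switches += 1         # expensive switch inside the inner loop
        return switches

    assert count_switches(["m1", "m2"], range(10), x_cost=1.0, y_cost=0) == 2
    assert count_switches(["m1", "m2"], range(10), x_cost=0, y_cost=0) == 20
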
From 2144c2eb7f5842caed1227d4ec7e659c79a84ce9 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Sun, 15 Jan 2023 21:41:58 -0800
Subject: Add swap axes button for XY Grid
---
scripts/xy_grid.py | 26 ++++++++++++++++++++------
style.css | 10 ++++++++++
2 files changed, 30 insertions(+), 6 deletions(-)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 13a3a046..99a660c1 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -23,6 +23,9 @@ import os
import re
+up_down_arrow_symbol = "\u2195\ufe0f"
+
+
def apply_field(field):
def fun(p, x, xs):
setattr(p, field, x)
@@ -293,17 +296,28 @@ class Script(scripts.Script):
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
with gr.Row():
- x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
- x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
-
- with gr.Row():
- y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
- y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
+ with gr.Column(scale=1, elem_id="xy_grid_button_column"):
+ swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes")
+ with gr.Column(scale=19):
+ with gr.Row():
+ x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
+ x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+
+ with gr.Row():
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
+ y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ def swap_axes(x_type, x_values, y_type, y_values):
+ nonlocal current_axis_options
+ return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values
+
+ swap_args = [x_type, x_values, y_type, y_values]
+ swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
+
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
diff --git a/style.css b/style.css
index 78fa9838..1fddfcc2 100644
--- a/style.css
+++ b/style.css
@@ -717,6 +717,16 @@ footer {
line-height: 2.4em;
}
+#xy_grid_button_column {
+ min-width: 38px !important;
+}
+
+#xy_grid_button_column button {
+ height: 100%;
+ margin-bottom: 0.7em;
+ margin-left: 1em;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
--
cgit v1.2.3
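The swap handler in the patch above leans on a Gradio pattern: a callback may list the same components as both inputs and outputs, so returning the current values in swapped order rewrites the widgets in place. A minimal sketch under that assumption (component names here are illustrative, not the webui's):

    import gradio as gr

    with gr.Blocks() as demo:
        x_values = gr.Textbox(label="X values")
        y_values = gr.Textbox(label="Y values")
        swap_button = gr.Button(value="\u2195\ufe0f")

        # same component list as inputs and outputs: returning the values
        # in reverse order swaps the two fields
        swap_args = [x_values, y_values]
        swap_button.click(lambda x, y: (y, x), inputs=swap_args, outputs=swap_args)

    # demo.launch()  # uncomment to try it locally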
From 972f5785073b8ba5957add72debd74fc56ee9329 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 09:27:52 +0300
Subject: fix problems related to checkpoint/VAE switching in XY plot
---
scripts/xy_grid.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 13a3a046..0cdfa952 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -263,14 +263,12 @@ class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
- self.model = shared.sd_model
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
- modules.sd_models.reload_model_weights(self.model)
-
opts.data["sd_vae"] = self.vae
- modules.sd_vae.reload_vae_weights(self.model)
+ modules.sd_models.reload_model_weights()
+ modules.sd_vae.reload_vae_weights()
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
--
cgit v1.2.3
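The corrected helper no longer pins a model object across the grid; it snapshots option values on entry and lets the reload functions resolve the current state on exit. The underlying save-and-restore shape is an ordinary context manager; a toy sketch with a plain dict standing in for opts:

    opts = {"sd_vae": "auto", "CLIP_stop_at_last_layers": 1}

    class SettingsStack:
        """Snapshot selected settings on enter, restore them on exit."""
        def __init__(self, keys):
            self.keys = keys

        def __enter__(self):
            self.saved = {k: opts[k] for k in self.keys}
            return self

        def __exit__(self, exc_type, exc_value, tb):
            opts.update(self.saved)
            # the real helper additionally reloads model/VAE weights here
            # so the restored values actually take effect

    with SettingsStack(["sd_vae"]):
        opts["sd_vae"] = "other.vae.pt"   # an XY cell switches the VAE
    assert opts["sd_vae"] == "auto"       # restored afterwards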
From 064983c0adb00cd9e88d2f06f66c9a1d5bc116c3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 12:56:30 +0300
Subject: return an option to hide progressbar
---
javascript/progressbar.js | 1 +
modules/shared.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 5072c13f..da6709bc 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -121,6 +121,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var divProgress = document.createElement('div')
divProgress.className='progressDiv'
+ divProgress.style.display = opts.show_progressbar ? "" : "none"
var divInner = document.createElement('div')
divInner.className='progress'
diff --git a/modules/shared.py b/modules/shared.py
index c5fc250e..483c4c62 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -451,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
}))
options_templates.update(options_section(('ui', "Live previews"), {
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
--
cgit v1.2.3
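Settings declared through OptionInfo like this end up in the opts object that is also shipped to the javascript side, which is how the progressbar code above can consult opts.show_progressbar. A toy sketch of that registry idea (not the real class):

    class OptionInfo:
        """Default value plus UI metadata for one setting (toy version)."""
        def __init__(self, default, label, component=None, component_args=None):
            self.default = default
            self.label = label
            self.component = component            # e.g. gr.Slider
            self.component_args = component_args  # e.g. {"minimum": -1, ...}

    options_templates = {
        "show_progressbar": OptionInfo(True, "Show progressbar"),
    }

    # current values start from the defaults; the settings UI mutates them
    opts = {name: info.default for name, info in options_templates.items()}
    print(opts["show_progressbar"])  # True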
From 55947857f035040d00249f02b17e39370033a99b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 17:36:56 +0300
Subject: add a button for XY Plot to fill in available values for axes that
support this
---
javascript/hints.js | 1 +
scripts/xy_grid.py | 101 ++++++++++++++++++++++++++++++++++------------------
style.css | 12 +------
3 files changed, 68 insertions(+), 46 deletions(-)
diff --git a/javascript/hints.js b/javascript/hints.js
index 244bfde2..fa5e5ae8 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -20,6 +20,7 @@ titles = {
"\u{1f4be}": "Save style",
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
+ "\u{1f4d2}": "Paste available values into the field",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index e06c11cb..bf4ba92f 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, paths, sd_samplers, processing
+from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
@@ -22,8 +22,9 @@ import glob
import os
import re
+from modules.ui_components import ToolButton
-up_down_arrow_symbol = "\u2195\ufe0f"
+fill_values_symbol = "\U0001f4d2" # 📒
def apply_field(field):
@@ -178,34 +179,49 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
+
+class AxisOption:
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ self.label = label
+ self.type = type
+ self.apply = apply
+ self.format_value = format_value
+ self.confirm = confirm
+ self.cost = cost
+ self.choices = choices
+ self.is_img2img = False
+
+
+class AxisOptionImg2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        self.is_img2img = True
axis_options = [
- AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
- AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
- AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
- AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
- AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
- AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
- AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
- AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
- AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
- AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
- AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
- AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
- AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
- AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
- AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
- AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
- AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
- AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
- AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
- AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
- AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
+ AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
+ AxisOption("Seed", int, apply_field("seed")),
+ AxisOption("Var. seed", int, apply_field("subseed")),
+ AxisOption("Var. strength", float, apply_field("subseed_strength")),
+ AxisOption("Steps", int, apply_field("steps")),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale")),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
+ AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
+ AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
+ AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
+ AxisOption("Sigma Churn", float, apply_field("s_churn")),
+ AxisOption("Sigma min", float, apply_field("s_tmin")),
+ AxisOption("Sigma max", float, apply_field("s_tmax")),
+ AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Eta", float, apply_field("eta")),
+ AxisOption("Clip skip", int, apply_clip_skip),
+ AxisOption("Denoising", float, apply_field("denoising_strength")),
+ AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
+ AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
+ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
+ AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
]
@@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
- return Processed()
+ return Processed(p, [])
grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend:
@@ -302,23 +318,25 @@ class Script(scripts.Script):
return "X/Y plot"
def ui(self, is_img2img):
- current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
+ current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]
with gr.Row():
- with gr.Column(scale=1, elem_id="xy_grid_button_column"):
- swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes")
with gr.Column(scale=19):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+ fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
-
- draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
- include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
- no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
+
+ with gr.Row(variant="compact"):
+ draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+ include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
def swap_axes(x_type, x_values, y_type, y_values):
nonlocal current_axis_options
@@ -327,6 +345,19 @@ class Script(scripts.Script):
swap_args = [x_type, x_values, y_type, y_values]
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
+ def fill(x_type):
+ axis = axis_options[x_type]
+ return ", ".join(axis.choices()) if axis.choices else gr.update()
+
+ fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
+ fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
+
+ def select_axis(x_type):
+ return gr.Button.update(visible=axis_options[x_type].choices is not None)
+
+ x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
+ y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
+
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
diff --git a/style.css b/style.css
index 1fddfcc2..97f9402a 100644
--- a/style.css
+++ b/style.css
@@ -644,7 +644,7 @@ canvas[key="mask"] {
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
- margin: 0.55em 0;
+ margin: 0.55em 0.7em 0.55em 0;
}
#quicksettings .gr-button-tool{
@@ -717,16 +717,6 @@ footer {
line-height: 2.4em;
}
-#xy_grid_button_column {
- min-width: 38px !important;
-}
-
-#xy_grid_button_column button {
- height: 100%;
- margin-bottom: 0.7em;
- margin-left: 1em;
-}
-
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
--
cgit v1.2.3
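The refactor above makes choices a zero-argument callable, so the list of valid values is computed lazily when the fill button is clicked and picks up checkpoints or VAEs added after startup. A condensed sketch of the fill logic under that design:

    class AxisOption:
        def __init__(self, label, choices=None, cost=0.0):
            self.label = label
            self.choices = choices  # None, or a callable evaluated lazily
            self.cost = cost

    samplers = ["Euler a", "DDIM"]  # stand-in for sd_samplers.samplers

    axis_options = [
        AxisOption("Nothing"),
        AxisOption("Sampler", choices=lambda: list(samplers)),
    ]

    def fill(axis_index):
        axis = axis_options[axis_index]
        # only axes that declare choices can be auto-filled; the UI hides
        # the fill button for the others
        return ", ".join(axis.choices()) if axis.choices else None

    print(fill(1))  # "Euler a, DDIM"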
From 52f6e94338f31c286361802b08ee5210b8244141 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 20:13:23 +0300
Subject: add --skip-install option to prevent running pip in launch.py and
speedup launch a bit
---
launch.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/launch.py b/launch.py
index bcbb792c..715427fd 100644
--- a/launch.py
+++ b/launch.py
@@ -14,6 +14,7 @@ python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
+skip_install = False
def commit_hash():
@@ -89,6 +90,9 @@ def run_python(code, desc=None, errdesc=None):
def run_pip(args, desc=None):
+ if skip_install:
+ return
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
@@ -173,6 +177,8 @@ def run_extensions_installers(settings_file):
def prepare_environment():
+ global skip_install
+
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -206,6 +212,7 @@ def prepare_environment():
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
+ sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv
--
cgit v1.2.3
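extract_arg removes a flag from sys.argv before argparse ever sees it and reports whether it was present; with skip_install set, run_pip simply returns early. A self-contained sketch of such a helper (the real signature may differ):

    def extract_arg(args, name):
        # return args without the flag, plus whether the flag was present
        return [x for x in args if x != name], name in args

    argv = ["launch.py", "--skip-install", "--xformers"]
    argv, skip_install = extract_arg(argv, "--skip-install")
    print(argv)          # ['launch.py', '--xformers']
    print(skip_install)  # True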
From 9991967f40120b88a1dc925fdf7d747d5e016888 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 22:59:46 +0300
Subject: Add a check and explanation for tensor with all NaNs.
---
modules/devices.py | 28 ++++++++++++++++++++++++++++
modules/processing.py | 3 +++
modules/sd_samplers.py | 2 ++
3 files changed, 33 insertions(+)
diff --git a/modules/devices.py b/modules/devices.py
index caeb0276..6f034948 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -106,6 +106,33 @@ def autocast(disable=False):
return torch.autocast("cuda")
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ from modules import shared
+
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ if not shared.cmd_opts.no_half:
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ raise NansException(message)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -156,3 +183,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
+
diff --git a/modules/processing.py b/modules/processing.py
index 849f6b19..ab7b3b7d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -608,6 +608,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ for x in x_samples_ddim:
+ devices.test_for_nans(x, "vae")
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 76e0e0d5..6261d1f7 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -351,6 +351,8 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ devices.test_for_nans(x_out, "unet")
+
if opts.live_preview_content == "Prompt":
store_latent(x_out[0:uncond.shape[0]])
elif opts.live_preview_content == "Negative prompt":
--
cgit v1.2.3
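The guard added here deliberately tests whether all elements are NaN: a fully-NaN tensor is the signature of a numeric blow-up (typically fp16 running out of range) in the unet or VAE, not ordinary image noise. The core of the check in isolation:

    import torch

    class NansException(Exception):
        pass

    def test_for_nans(x, where):
        # a tensor that is entirely NaN indicates the computation blew up;
        # a few legitimate NaNs elsewhere do not trigger false alarms
        if not torch.all(torch.isnan(x)).item():
            return
        raise NansException(f"A tensor with all NaNs was produced in {where}.")

    test_for_nans(torch.ones(4), "unet")                    # passes silently
    # test_for_nans(torch.full((4,), float("nan")), "vae")  # would raise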
From e0e80050091ea7f58ae17c69f31d1b5de5e0ae20 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 16 Jan 2023 23:09:08 +0300
Subject: make StableDiffusionProcessing class not hold a reference to
shared.sd_model object
---
modules/processing.py | 9 +++++----
scripts/xy_grid.py | 1 -
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index ab7b3b7d..9c3673de 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return image_conditioning
-class StableDiffusionProcessing():
+class StableDiffusionProcessing:
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
"""
@@ -102,7 +102,6 @@ class StableDiffusionProcessing():
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
- self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
@@ -156,6 +155,10 @@ class StableDiffusionProcessing():
self.all_subseeds = None
self.iteration = 0
+ @property
+ def sd_model(self):
+ return shared.sd_model
+
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@@ -236,7 +239,6 @@ class StableDiffusionProcessing():
raise NotImplementedError()
def close(self):
- self.sd_model = None
self.sampler = None
@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model
- p.sd_model = shared.sd_model
if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index bf4ba92f..6629f5d5 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs):
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
- p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):
--
cgit v1.2.3
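Making sd_model a read-only property means every access resolves the live shared.sd_model, so a checkpoint reload in the middle of a job can never leave a processing object holding stale weights. The pattern in isolation, with a toy stand-in for the shared module:

    class shared:                    # stand-in for modules.shared
        sd_model = "v1-5.ckpt"

    class Processing:
        @property
        def sd_model(self):
            # always resolve through the shared module instead of caching
            # a reference at construction time
            return shared.sd_model

    p = Processing()
    shared.sd_model = "v2-1.ckpt"    # e.g. an XY-plot checkpoint switch
    print(p.sd_model)                # 'v2-1.ckpt' - never stale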
From c091cf1b4acd2047644d3571bcbfd81c81b4c3af Mon Sep 17 00:00:00 2001
From: Nick Petalas
Date: Thu, 22 Dec 2022 22:28:10 +0000
Subject: upgrading torch, torchvision, xformers (windows) to use u117
---
launch.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/launch.py b/launch.py
index 715427fd..f51f23f7 100644
--- a/launch.py
+++ b/launch.py
@@ -179,7 +179,7 @@ def run_extensions_installers(settings_file):
def prepare_environment():
global skip_install
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -187,8 +187,6 @@ def prepare_environment():
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
- xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
-
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
@@ -239,7 +237,7 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+ run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
--
cgit v1.2.3
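Because prepare_environment reads TORCH_COMMAND through os.environ.get, the new cu117 default can still be overridden per machine without editing launch.py. The override pattern, reduced to its essence:

    import os

    DEFAULT = ("pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 "
               "--extra-index-url https://download.pytorch.org/whl/cu117")

    # the environment wins over the baked-in default
    torch_command = os.environ.get('TORCH_COMMAND', DEFAULT)
    print(torch_command)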
From eb2223340cfdd58efaa157662c279fbf6c90c7d9 Mon Sep 17 00:00:00 2001
From: fuggy <45698918+nonetrix@users.noreply.github.com>
Date: Mon, 16 Jan 2023 21:50:30 -0600
Subject: Fix typo
---
modules/errors.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/errors.py b/modules/errors.py
index a668c014..a10e8708 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -19,7 +19,7 @@ def display(e: Exception, task):
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
-The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
--
cgit v1.2.3
From c361b89026442f3412162657f330d500b803e052 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 17 Jan 2023 11:04:56 +0300
Subject: disable the new NaN check for the CI
---
launch.py | 2 ++
modules/devices.py | 3 +++
modules/shared.py | 1 +
3 files changed, 6 insertions(+)
diff --git a/launch.py b/launch.py
index 715427fd..5afb2956 100644
--- a/launch.py
+++ b/launch.py
@@ -286,6 +286,8 @@ def tests(test_dir):
sys.argv.append("./test/test_files/empty.pt")
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
+ if "--disable-nan-check" not in sys.argv:
+ sys.argv.append("--disable-nan-check")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
diff --git a/modules/devices.py b/modules/devices.py
index 6f034948..206184fb 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -113,6 +113,9 @@ class NansException(Exception):
def test_for_nans(x, where):
from modules import shared
+ if shared.cmd_opts.disable_nan_check:
+ return
+
if not torch.all(torch.isnan(x)).item():
return
diff --git a/modules/shared.py b/modules/shared.py
index 483c4c62..a708f23c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -64,6 +64,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
--
cgit v1.2.3
From 4688bfff55dd6607e6608524fb219f97dc6fe8bb Mon Sep 17 00:00:00 2001
From: dan
Date: Tue, 17 Jan 2023 17:16:43 +0800
Subject: Add auto-sized cropping UI
---
modules/textual_inversion/preprocess.py | 38 ++++++++++++++++++++++++++++++---
modules/ui.py | 28 +++++++++++++++++++++++-
2 files changed, 62 insertions(+), 4 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 64abff4d..86c1cd33 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -20,7 +20,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -109,8 +109,32 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio):
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
+# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
+ try:
+ w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: ((objective=='Maximize area')*wh[0]*wh[1], -err(*wh))
+ )
+ except ValueError:
+ return
+ return center_crop(image, w, h)
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -194,6 +218,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
+ if process_multicrop:
+ cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ if cropped is not None:
+ save_pic(cropped, index, params, existing_caption=existing_caption)
+ else:
+ print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
diff --git a/modules/ui.py b/modules/ui.py
index 20b66165..bbce9acd 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1226,6 +1226,7 @@ def create_ui():
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
@@ -1238,7 +1239,19 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
+
+ with gr.Column(visible=False) as process_multicrop_col:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
+ process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
+ with gr.Row():
+ process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
+ process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
+ with gr.Row():
+ process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
+ process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1260,6 +1273,12 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ process_multicrop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_multicrop],
+ outputs=[process_multicrop_col],
+ )
+
def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates])
@@ -1379,6 +1398,13 @@ def create_ui():
process_focal_crop_entropy_weight,
process_focal_crop_edges_weight,
process_focal_crop_debug,
+ process_multicrop,
+ process_multicrop_mindim,
+ process_multicrop_maxdim,
+ process_multicrop_minarea,
+ process_multicrop_maxarea,
+ process_multicrop_objective,
+ process_multicrop_threshold,
],
outputs=[
ti_output,
--
cgit v1.2.3
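The err lambda above scores a candidate (w, h) by how far its aspect ratio deviates from the source image's, folding the ratio of the two ratios into [0, 1]; multicrop_pic then searches every 64-pixel-aligned size inside the dimension and area bounds. A readable restatement of that search (pure arithmetic, no PIL required; the sample values are illustrative):

    def aspect_error(iw, ih, w, h):
        # 0.0 for a perfect aspect match, approaching 1.0 as ratios diverge
        r = (iw / ih) / (w / h)
        return 1 - (r if r < 1 else 1 / r)

    def pick_crop(iw, ih, mindim, maxdim, minarea, maxarea, threshold,
                  maximize_area=True):
        candidates = [
            (w, h)
            for w in range(mindim, maxdim + 1, 64)
            for h in range(mindim, maxdim + 1, 64)
            if minarea <= w * h <= maxarea
            and aspect_error(iw, ih, w, h) <= threshold
        ]
        if not candidates:
            return None  # caller falls back to the default resize and logs a skip
        # prefer the largest area (or the lowest error), error as tie-breaker
        return max(candidates, key=lambda wh: (maximize_area * wh[0] * wh[1],
                                               -aspect_error(iw, ih, *wh)))

    print(pick_crop(1920, 1080, 384, 768, 64*64, 640*640, 0.1))  # (768, 448)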
From aede265f1d6d512ca9e51a305e98a96a215366c4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 17 Jan 2023 13:57:55 +0300
Subject: Fix unable to find Real-ESRGAN model info error (AttributeError:
'NoneType' object has no attribute 'data_path') #6841 #5170
---
modules/realesrgan_model.py | 12 ++++--------
modules/upscaler.py | 1 +
2 files changed, 5 insertions(+), 8 deletions(-)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 3ac0b97a..47f70251 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler):
return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler):
def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 231680cb..a5bf5acb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -95,6 +95,7 @@ class UpscalerData:
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
+ self.local_data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
--
cgit v1.2.3
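next(iter([...]), None) in the fixed load_model is a compact first-match-or-None lookup, and storing the downloaded file in a separate local_data_path keeps the original URL in data_path so the lookup still works on later calls. The idiom on its own (a plain generator avoids building the intermediate list):

    class Scaler:
        def __init__(self, data_path):
            self.data_path = data_path

    scalers = [Scaler("https://example.com/a.pth"),
               Scaler("https://example.com/b.pth")]

    # first scaler whose data_path matches, or None when nothing matches
    info = next((s for s in scalers
                 if s.data_path == "https://example.com/b.pth"), None)
    print(info is not None)  # True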
From 38b7186e6e3a4dffc93225308b822f0dae43a47d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 17 Jan 2023 14:15:47 +0300
Subject: update sending input event in javascript to not cause an exception in
 the browser https://github.com/gradio-app/gradio/issues/2981
---
javascript/edit-attention.js | 5 ++---
javascript/extensions.js | 2 +-
javascript/ui.js | 8 ++++++++
3 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index b947cbec..ccc8344a 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -69,7 +69,6 @@ addEventListener('keydown', (event) => {
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
}
- // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
- // internal Svelte data binding remains in sync.
- target.dispatchEvent(new Event("input", { bubbles: true }));
+
+ updateInput(target)
});
diff --git a/javascript/extensions.js b/javascript/extensions.js
index 59179ca6..ac6e35b9 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -29,7 +29,7 @@ function install_extension_from_index(button, url){
textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
- textarea.dispatchEvent(new Event("input", { bubbles: true }))
+ updateInput(textarea)
gradioApp().querySelector('#install_extension_button').click()
}
diff --git a/javascript/ui.js b/javascript/ui.js
index ecf97cb3..954beadd 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -278,3 +278,11 @@ function restart_reload(){
return []
}
+
+// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
+// will only be visible on the web page and will not be sent to python.
+function updateInput(target){
+ let e = new Event("input", { bubbles: true })
+ Object.defineProperty(e, "target", {value: target})
+ target.dispatchEvent(e);
+}
--
cgit v1.2.3
From 6e08da2c315c346225aa834017f4e32cfc0de200 Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Tue, 17 Jan 2023 23:50:41 +0900
Subject: Add `--vae-dir` argument
---
modules/sd_vae.py | 7 +++++++
modules/shared.py | 1 +
2 files changed, 8 insertions(+)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index b2af2ce7..da1bf15c 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -72,6 +72,13 @@ def refresh_vae_list():
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
]
+ if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
+ ]
+
candidates = []
for path in paths:
candidates += glob.iglob(path, recursive=True)
diff --git a/modules/shared.py b/modules/shared.py
index a708f23c..a1345ad3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -26,6 +26,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
--
cgit v1.2.3
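The new patterns rely on glob's recursive mode, where ** matches nested subdirectories, and iglob yields matches lazily. A small sketch of collecting candidate VAE files that way (the directory here is illustrative):

    import glob
    import os

    vae_dir = "models/VAE"  # the real value comes from --vae-dir

    patterns = [os.path.join(vae_dir, "**/*.ckpt"),
                os.path.join(vae_dir, "**/*.pt"),
                os.path.join(vae_dir, "**/*.safetensors")]

    candidates = []
    for pattern in patterns:
        # recursive=True is what makes '**' descend into subdirectories
        candidates += glob.iglob(pattern, recursive=True)

    print(sorted(candidates))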
From 5e15a0b422981c0b5484885d0b4d28af6913c76f Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 17 Jan 2023 11:42:44 -0500
Subject: Changed params.txt save to after manual init call
---
modules/processing.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index 9c3673de..4a1f033e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -538,10 +538,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process(p)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
infotexts = []
output_images = []
@@ -572,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
--
cgit v1.2.3
From 3a0d6b77295162146d0a8d04278804334da6f1b4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 17 Jan 2023 23:54:23 +0300
Subject: make it so that PNG images with EXIF do not lose parameters in PNG
info tab
---
modules/images.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/modules/images.py b/modules/images.py
index c3a5fc8b..3b1c5f34 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -605,8 +605,9 @@ def read_info_from_image(image):
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
--
cgit v1.2.3
From d906f87043d809e6d4d8de3c9926e184169b330f Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Wed, 18 Jan 2023 07:52:10 +0900
Subject: fix typo
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/shared.py b/modules/shared.py
index a1345ad3..a42279ec 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -26,7 +26,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
-parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
--
cgit v1.2.3
From a255dac4f8c5ee11c15b634563d3df513f1834b4 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Thu, 12 Jan 2023 08:00:38 -0500
Subject: Fix cumsum for MPS in newer torch
The prior fix assumed that testing int16 was enough to determine if a fix is needed, but a recent fix for cumsum has int16 working but not bool.
---
modules/devices.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index caeb0276..ac3ae0c9 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -139,8 +139,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
- if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+ if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -151,8 +153,9 @@ if has_mps():
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix
elif version.parse(torch.__version__) > version.parse("1.13.1"):
- if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
- torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
- torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+ cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+ torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+ torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
--
cgit v1.2.3
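The repaired logic probes, once at startup, whether the installed torch build miscomputes cumsum on MPS for int16 and for bool independently, and reroutes only the dtypes that fail. A reduced sketch of that probe-then-patch structure (int16 only; on CPU the probe simply passes and nothing is rerouted):

    import torch

    orig_cumsum = torch.cumsum

    def cumsum_is_ok(dtype):
        # compare the op under test against a known-good answer; broken MPS
        # builds return wrong values for some integer dtypes
        got = orig_cumsum(torch.tensor([1, 1], dtype=dtype), 0)
        return torch.equal(got.to(torch.int64), torch.tensor([1, 2]))

    CUMSUM_INT16_OK = cumsum_is_ok(torch.int16)

    def cumsum_fixed(input, *args, **kwargs):
        # upcast around the broken dtype, then cast the result to int64
        if input.dtype == torch.int16 and not CUMSUM_INT16_OK:
            return orig_cumsum(input.to(torch.int32), *args, **kwargs).to(torch.int64)
        return orig_cumsum(input, *args, **kwargs)

    torch.cumsum = cumsum_fixed  # patched in place, mirroring the commit
    print(torch.cumsum(torch.tensor([1, 1], dtype=torch.int16), 0))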
From dac59b9b073f86508d3ec787ff731af2e101fbcc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 06:13:45 +0300
Subject: return progress percentage to title bar
---
javascript/progressbar.js | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index da6709bc..b8473ebf 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -106,6 +106,19 @@ function formatTime(secs){
}
}
+function setTitle(progress){
+ var title = 'Stable Diffusion'
+
+ if(opts.show_progress_in_title && progress){
+ title = '[' + progress.trim() + '] ' + title;
+ }
+
+ if(document.title != title){
+ document.title = title;
+ }
+}
+
+
function randomId(){
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
}
@@ -133,6 +146,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
parentGallery.insertBefore(livePreview, gallery)
var removeProgressBar = function(){
+ setTitle("")
parentProgressbar.removeChild(divProgress)
parentGallery.removeChild(livePreview)
atEnd()
@@ -165,6 +179,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
progressText += " " + res.textinfo
}
+ setTitle(progressText)
divInner.textContent = progressText
var elapsedFromStart = (new Date() - dateStart) / 1000
--
cgit v1.2.3
From d8f8bcb821fa62e943eb95ee05b8a949317326fe Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 13:20:47 +0300
Subject: enable progressbar without gallery
---
javascript/progressbar.js | 24 +++++++++++++++---------
style.css | 19 +++----------------
2 files changed, 18 insertions(+), 25 deletions(-)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index b8473ebf..18c771a2 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -130,7 +130,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var dateStart = new Date()
var wasEverActive = false
var parentProgressbar = progressbarContainer.parentNode
- var parentGallery = gallery.parentNode
+ var parentGallery = gallery ? gallery.parentNode : null
var divProgress = document.createElement('div')
divProgress.className='progressDiv'
@@ -141,14 +141,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.appendChild(divInner)
parentProgressbar.insertBefore(divProgress, progressbarContainer)
- var livePreview = document.createElement('div')
- livePreview.className='livePreview'
- parentGallery.insertBefore(livePreview, gallery)
+ if(parentGallery){
+ var livePreview = document.createElement('div')
+ livePreview.className='livePreview'
+ parentGallery.insertBefore(livePreview, gallery)
+ }
var removeProgressBar = function(){
setTitle("")
parentProgressbar.removeChild(divProgress)
- parentGallery.removeChild(livePreview)
+ if(parentGallery) parentGallery.removeChild(livePreview)
atEnd()
}
@@ -168,6 +170,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
progressText = ""
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
+ divInner.style.background = res.progress ? "" : "transparent"
if(res.progress > 0){
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
@@ -175,11 +178,15 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
if(res.eta){
progressText += " ETA: " + formatTime(res.eta)
- } else if(res.textinfo){
- progressText += " " + res.textinfo
}
+
setTitle(progressText)
+
+ if(res.textinfo && res.textinfo.indexOf("\n") == -1){
+ progressText = res.textinfo + " " + progressText
+ }
+
divInner.textContent = progressText
var elapsedFromStart = (new Date() - dateStart) / 1000
@@ -197,8 +204,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
}
- if(res.live_preview){
-
+ if(res.live_preview && gallery){
var rect = gallery.getBoundingClientRect()
if(rect.width){
livePreview.style.width = rect.width + "px"
diff --git a/style.css b/style.css
index 97f9402a..b1d47df6 100644
--- a/style.css
+++ b/style.css
@@ -290,26 +290,12 @@ input[type="range"]{
min-height: unset !important;
}
-#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
- position: absolute;
- z-index: 1000;
- right: 0;
- padding-left: 5px;
- padding-right: 5px;
- display: block;
-}
-
-#txt2img_progress_row, #img2img_progress_row{
- margin-bottom: 10px;
- margin-top: -18px;
-}
-
.progressDiv{
position: absolute;
height: 20px;
top: -20px;
background: #b4c0cc;
- border-radius: 8px !important;
+ border-radius: 3px !important;
}
.dark .progressDiv{
@@ -325,9 +311,10 @@ input[type="range"]{
line-height: 20px;
padding: 0 8px 0 0;
text-align: right;
- border-radius: 8px;
+ border-radius: 3px;
overflow: visible;
white-space: nowrap;
+ padding: 0 0.5em;
}
.livePreview{
--
cgit v1.2.3
From 0c5913b9c28017523011ac6bf83b38ed5de8c11f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 14:14:50 +0300
Subject: re-enable image dragging on non-firefox browsers
---
javascript/imageviewer.js | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 1f29ad7b..aac2ee82 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -148,7 +148,15 @@ function showGalleryImage() {
if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer'
e.style.userSelect='none'
- e.addEventListener('mousedown', function (evt) {
+
+            var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+
+            // For Firefox, listening on click first switches to the next image and then shows the lightbox.
+            // If you know how to fix this without switching to the mousedown event, please do.
+            // For other browsers the event is click, to make it possible to drag the picture.
+ var event = isFirefox ? 'mousedown' : 'click'
+
+ e.addEventListener(event, function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
--
cgit v1.2.3
From 6faae2323963f9b0e0086a85b9d0472a24fbaa73 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 14:33:09 +0300
Subject: repair broken quicksettings when some form-requiring options are
added to it
---
modules/ui.py | 2 +-
style.css | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index e1f98d23..6d70a795 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1659,7 +1659,7 @@ def create_ui():
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings"):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
diff --git a/style.css b/style.css
index b6239142..fb58b6c3 100644
--- a/style.css
+++ b/style.css
@@ -530,7 +530,7 @@ input[type="range"]{
gap: 0.4em;
}
-#quicksettings > div{
+#quicksettings > div, #quicksettings > fieldset{
max-width: 24em;
min-width: 24em;
padding: 0;
--
cgit v1.2.3
From 05a779b0cd7cfe98d525e70362154a8a4d8b5e09 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Wed, 18 Jan 2023 09:47:38 -0500
Subject: fix syntax error
---
javascript/localization.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/javascript/localization.js b/javascript/localization.js
index bf9e1506..1a5a1dbb 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -11,7 +11,7 @@ ignore_ids_for_localization={
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
- img2img_styles 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
--
cgit v1.2.3
From 8683427bd9315d2fda0d2f9644c8b1f6a182da55 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Wed, 18 Jan 2023 20:16:52 +0300
Subject: Process interrogation on all img2img subtabs
---
javascript/ui.js | 7 +++++++
modules/ui.py | 50 +++++++++++++++++++++++++++++++++++++++++++-------
2 files changed, 50 insertions(+), 7 deletions(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index 954beadd..7d3d57a3 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -109,6 +109,13 @@ function get_extras_tab_index(){
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}
+function get_img2img_tab_index() {
+ let res = args_to_array(arguments)
+ res.splice(-2)
+ res[0] = get_tab_index('mode_img2img')
+ return res
+}
+
function create_submit_args(args){
res = []
for(var i=0;i<args.length;i++){
--
cgit v1.2.3
Date: Wed, 18 Jan 2023 20:29:44 +0300
Subject: make live previews not obscure multiselect dropdowns
---
style.css | 1 +
1 file changed, 1 insertion(+)
diff --git a/style.css b/style.css
index fb58b6c3..61279a19 100644
--- a/style.css
+++ b/style.css
@@ -148,6 +148,7 @@
#txt2img_styles ul, #img2img_styles ul{
max-height: 35em;
+ z-index: 2000;
}
.gr-form{
--
cgit v1.2.3
From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 23:04:24 +0300
Subject: add option to show/hide warnings; removed hiding warnings from LDSR;
 fixed/reworked a few places that produced warnings
---
extensions-builtin/LDSR/ldsr_model_arch.py | 3 --
javascript/localization.js | 2 +-
modules/hypernetworks/hypernetwork.py | 7 ++++-
modules/sd_hijack.py | 8 ------
modules/sd_hijack_checkpoint.py | 38 +++++++++++++++++++++++++-
modules/shared.py | 1 +
modules/textual_inversion/textual_inversion.py | 6 +++-
modules/ui.py | 31 ++++++++++++---------
scripts/prompts_from_file.py | 2 +-
style.css | 5 ++--
10 files changed, 71 insertions(+), 32 deletions(-)
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 0ad49f4e..bc11cc6e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,7 +1,6 @@
import os
import gc
import time
-import warnings
import numpy as np
import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack
-warnings.filterwarnings("ignore", category=UserWarning)
-
cached_ldsr_model: torch.nn.Module = None
diff --git a/javascript/localization.js b/javascript/localization.js
index bf9e1506..1a5a1dbb 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -11,7 +11,7 @@ ignore_ids_for_localization={
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
- img2img_styles 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c963fc40..74e78582 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b0d95af..870eba88 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
-
-
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -106,8 +100,6 @@ class StableDiffusionModelHijack:
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
-
- fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
index 5712972f..2604d969 100644
--- a/modules/sd_hijack_checkpoint.py
+++ b/modules/sd_hijack_checkpoint.py
@@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.openaimodel
+
+
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
+
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
+
def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb)
\ No newline at end of file
+ return checkpoint(self._forward, x, emb)
+
+
+stored = []
+
+
+def add():
+ if len(stored) != 0:
+ return
+
+ stored.extend([
+ ldm.modules.attention.BasicTransformerBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
+ ])
+
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
+
+
+def remove():
+ if len(stored) == 0:
+ return
+
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
+
+ stored.clear()
+
diff --git a/modules/shared.py b/modules/shared.py
index a708f23c..ddb97f99 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7e4a6d24..5a7be422 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
diff --git a/modules/ui.py b/modules/ui.py
index 6d70a795..25818fb0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -11,6 +11,7 @@ import tempfile
import time
import traceback
from functools import partial, reduce
+import warnings
import gradio as gr
import gradio.routes
@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -417,17 +420,16 @@ def apply_setting(key, value):
return value
-def update_generation_info(args):
- generation_info, html_info, img_index = args
+def update_generation_info(generation_info, html_info, img_index):
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
- return html_info
+ return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
@@ -508,10 +510,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
)
save.click(
@@ -526,7 +527,8 @@ Requested path was: {f}
outputs=[
download_files,
html_log,
- ]
+ ],
+ show_progress=False,
)
save_zip.click(
@@ -588,7 +590,7 @@ def create_ui():
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
@@ -768,7 +770,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -1768,7 +1770,10 @@ def create_ui():
if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index f3e711d7..76dc5778 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -116,7 +116,7 @@ class Script(scripts.Script):
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
- file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file"))
+ file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
diff --git a/style.css b/style.css
index 61279a19..0845519a 100644
--- a/style.css
+++ b/style.css
@@ -299,9 +299,8 @@ input[type="range"]{
}
/* more gradio's garbage cleanup */
-.min-h-\[4rem\] {
- min-height: unset !important;
-}
+.min-h-\[4rem\] { min-height: unset !important; }
+.min-h-\[6rem\] { min-height: unset !important; }
.progressDiv{
position: absolute;
--
cgit v1.2.3
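The warnings commit above also reworks sd_hijack_checkpoint from a one-way fix_checkpoint() into a reversible add()/remove() pair: the original forward methods are stashed in a module-level list before patching and restored from it afterwards, so gradient checkpointing is only hijacked while training runs. A minimal, self-contained Python sketch of that save/patch/restore pattern (Target and patched_forward are illustrative stand-ins, not webui code):

    stored = []

    class Target:
        def forward(self, x):
            return x + 1

    def patched_forward(self, x):
        return x * 2

    def add():
        # patch only once, remembering the original so remove() can undo it
        if stored:
            return
        stored.append(Target.forward)
        Target.forward = patched_forward

    def remove():
        if not stored:
            return
        Target.forward = stored[0]
        stored.clear()

    t = Target()
    add()
    assert t.forward(3) == 6   # patched
    remove()
    assert t.forward(3) == 4   # original restored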
From b186d44dcd0df9d127a663b297334a5bd8258b58 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 23:20:23 +0300
Subject: use DDIM in hires fix if the sampler is PLMS
---
modules/processing.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index 9c3673de..8c18ac53 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -857,7 +857,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
--
cgit v1.2.3
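The hires-fix change is a plain name substitution performed just before the img2img sampler is created, because PLMS has no img2img path. A tiny sketch of the guard (the function name is illustrative):

    def pick_img2img_sampler(requested: str) -> str:
        # PLMS cannot do img2img, so silently fall back to DDIM
        return 'DDIM' if requested == 'PLMS' else requested

    assert pick_img2img_sampler('PLMS') == 'DDIM'
    assert pick_img2img_sampler('Euler a') == 'Euler a'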
From bb0978ecfd3177d0bfd7cacd1ac8796d7eec2d79 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 00:44:51 +0300
Subject: fix hires fix ui weirdness caused by gradio update
---
modules/ui.py | 6 +++---
style.css | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index 8b7f1dfb..09a3c92e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -638,7 +638,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
+ with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -646,12 +646,12 @@ def create_ui():
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with FormRow(elem_id="txt2img_hires_fix_row2"):
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
diff --git a/style.css b/style.css
index 0845519a..a6abd93d 100644
--- a/style.css
+++ b/style.css
@@ -686,7 +686,7 @@ footer {
#txt2img_checkboxes, #img2img_checkboxes{
margin-bottom: 0.5em;
}
-#txt2img_checkboxes > div > div, #img2img_checkboxes > div > div{
+#txt2img_checkboxes > div, #img2img_checkboxes > div{
flex: 0;
white-space: nowrap;
min-width: auto;
--
cgit v1.2.3
From 956263b8a4f0393dcb47ed497f367717add4f0e9 Mon Sep 17 00:00:00 2001
From: facu
Date: Wed, 18 Jan 2023 19:15:53 -0300
Subject: fix error when using lspci on macOS
---
webui.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index 6e07778f..1edf921d 100755
--- a/webui.sh
+++ b/webui.sh
@@ -165,7 +165,7 @@ else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
- gpu_info=$(lspci | grep VGA)
+ gpu_info=$(lspci 2>/dev/null | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
if [[ -z "${TORCH_COMMAND}" ]]
--
cgit v1.2.3
From 99207bc816d027b522e1c49001748c63fd426b53 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Wed, 18 Jan 2023 19:13:15 -0500
Subject: check model name values are set before merging
---
modules/extras.py | 22 ++++++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 22668fcd..29eb1f07 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -287,10 +287,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
+ if not primary_model_name:
+ shared.state.textinfo = "Failed: Merging requires a primary model."
+ shared.state.end()
+ return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if not secondary_model_name:
+ shared.state.textinfo = "Failed: Merging requires a secondary model."
+ shared.state.end()
+ return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
- result_is_inpainting_model = False
theta_funcs = {
"Weighted sum": (None, weighted_sum),
@@ -298,10 +307,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
}
theta_func1, theta_func2 = theta_funcs[interp_method]
- if theta_func1 and not tertiary_model_info:
+ tertiary_model_info = None
+ if theta_func1 and not tertiary_model_name:
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
shared.state.end()
- return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ else:
+ tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
+
+ result_is_inpainting_model = False
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
--
cgit v1.2.3
From 26a6a78b16f88a6f88f4cca3f378db3b83fc94f8 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Wed, 18 Jan 2023 21:21:52 -0500
Subject: only lookup tertiary model if theta_func1 is set
---
modules/extras.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 29eb1f07..88eea22e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -307,13 +307,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
}
theta_func1, theta_func2 = theta_funcs[interp_method]
- tertiary_model_info = None
if theta_func1 and not tertiary_model_name:
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
shared.state.end()
return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
- else:
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
+
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
--
cgit v1.2.3
From 308b51012a5def38edb1c2e127e736c43aa6e1a3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 08:41:37 +0300
Subject: fix an unlikely division by 0 error
---
modules/progress.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/modules/progress.py b/modules/progress.py
index 3327b883..f9e005d3 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -67,10 +67,13 @@ def progressapi(req: ProgressRequest):
progress = 0
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ job_count, job_no = shared.state.job_count, shared.state.job_no
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
+
+ if job_count > 0:
+ progress += job_no / job_count
+ if sampling_steps > 0:
+ progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1)
--
cgit v1.2.3
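The point of the division-by-zero fix is to read the shared counters once into locals, so the zero check and the later division see the same values even if another thread resets the state in between. A standalone sketch of the resulting progress formula (State is a stand-in for shared.state; the extra job_count guard on the second term matches the follow-up fix a few commits below):

    from dataclasses import dataclass

    @dataclass
    class State:
        job_count: int = 0
        job_no: int = 0
        sampling_steps: int = 0
        sampling_step: int = 0

    def compute_progress(state: State) -> float:
        # snapshot the counters so a concurrent reset cannot zero a divisor
        job_count, job_no = state.job_count, state.job_no
        sampling_steps, sampling_step = state.sampling_steps, state.sampling_step

        progress = 0.0
        if job_count > 0:
            progress += job_no / job_count
            if sampling_steps > 0:
                progress += 1 / job_count * sampling_step / sampling_steps
        return min(progress, 1.0)

    print(compute_progress(State(job_count=2, job_no=1, sampling_steps=20, sampling_step=5)))  # 0.625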
From 7cfc6450305125683799208fb7bc27c0b12586b3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 08:53:50 +0300
Subject: eliminate repetition of code in #6910
---
modules/extras.py | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 88eea22e..367c15cc 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -278,6 +278,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
shared.state.begin()
shared.state.job = 'model-merge'
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [message, *[gr.update() for _ in range(4)]]
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -288,16 +293,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
return theta0 + (alpha * theta1_2_diff)
if not primary_model_name:
- shared.state.textinfo = "Failed: Merging requires a primary model."
- shared.state.end()
- return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return fail("Failed: Merging requires a primary model.")
primary_model_info = sd_models.checkpoints_list[primary_model_name]
if not secondary_model_name:
- shared.state.textinfo = "Failed: Merging requires a secondary model."
- shared.state.end()
- return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return fail("Failed: Merging requires a secondary model.")
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
@@ -308,9 +309,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_name:
- shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
- shared.state.end()
- return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
--
cgit v1.2.3
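The repetition being eliminated is the trio of "set textinfo, end the job, return the message plus four dropdown updates"; a local closure gives one place for both the side effects and the return shape Gradio expects. A minimal sketch with stand-ins for shared.state and gr.update():

    class State:
        textinfo = ""
        def end(self):
            pass

    state = State()

    def gr_update():
        return {"__type__": "update"}  # stand-in for gr.update()

    def run_merge(primary, secondary):
        def fail(message):
            # one place for failure side effects and the 5-element return shape
            state.textinfo = message
            state.end()
            return [message, *[gr_update() for _ in range(4)]]

        if not primary:
            return fail("Failed: Merging requires a primary model.")
        if not secondary:
            return fail("Failed: Merging requires a secondary model.")
        return ["ok", *[gr_update() for _ in range(4)]]

    print(run_merge(None, "b")[0])  # Failed: Merging requires a primary model.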
From c7e50425f63c07242068f8dcccce70a4ef28a17f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 09:25:37 +0300
Subject: add progress bar to modelmerger
---
javascript/ui.js | 11 +++++++++++
modules/extras.py | 18 +++++++++++++++---
modules/progress.py | 2 +-
modules/ui.py | 13 ++++++++-----
style.css | 5 +++++
5 files changed, 40 insertions(+), 9 deletions(-)
diff --git a/javascript/ui.js b/javascript/ui.js
index 7d3d57a3..428375d4 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -172,6 +172,17 @@ function submit_img2img(){
return res
}
+function modelmerger(){
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
+
+ gradioApp().getElementById('modelmerger_result').innerHTML = ''
+
+ var res = create_submit_args(arguments)
+ res[0] = id
+ return res
+}
+
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:')
diff --git a/modules/extras.py b/modules/extras.py
index 367c15cc..034f28e4 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
+ shared.state.job_count = 1
def fail(message):
shared.state.textinfo = message
shared.state.end()
- return [message, *[gr.update() for _ in range(4)]]
+ return [*[gr.update() for _ in range(4)], message]
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1:
+ shared.state.job_count += 1
+
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
if key in theta_2:
@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
del theta_2
+ shared.state.nextjob()
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+ shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
if save_as_half:
theta_0[key] = theta_0[key].half()
+ shared.state.sampling_step += 1
+
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/progress.py b/modules/progress.py
index f9e005d3..c69ecf3d 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
if job_count > 0:
progress += job_no / job_count
- if sampling_steps > 0:
+ if sampling_steps > 0 and job_count > 0:
progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1)
diff --git a/modules/ui.py b/modules/ui.py
index 09a3c92e..aeee7853 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1208,8 +1208,9 @@ def create_ui():
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
- with gr.Column(variant='panel'):
- submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
@@ -1753,12 +1754,14 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
modelmerger_merge.click(
- fn=modelmerger,
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
inputs=[
+ dummy_component,
primary_model_name,
secondary_model_name,
tertiary_model_name,
@@ -1770,11 +1773,11 @@ def create_ui():
config_source,
],
outputs=[
- submit_result,
primary_model_name,
secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'],
+ modelmerger_result,
]
)
diff --git a/style.css b/style.css
index a6abd93d..32ba4753 100644
--- a/style.css
+++ b/style.css
@@ -737,6 +737,11 @@ footer {
line-height: 2.4em;
}
+#modelmerger_results_container{
+ margin-top: 1em;
+ overflow: visible;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
--
cgit v1.2.3
From 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 10:39:51 +0300
Subject: checkpoint merger tab: allow baking in a VAE; do not save config if
 it's the default; change the file naming scheme; allow just saving A without
 any merging; some stylistic UI changes
---
javascript/hints.js | 1 +
javascript/ui.js | 2 -
modules/extras.py | 112 +++++++++++++++++++++++++++++++---------------------
modules/sd_vae.py | 9 ++++-
modules/shared.py | 3 +-
modules/ui.py | 17 ++++++--
style.css | 15 ++++---
7 files changed, 101 insertions(+), 58 deletions(-)
diff --git a/javascript/hints.js b/javascript/hints.js
index fa5e5ae8..e746e20d 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -92,6 +92,7 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
+ "No interpolation": "Result = A",
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
diff --git a/javascript/ui.js b/javascript/ui.js
index 428375d4..37788a3e 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -176,8 +176,6 @@ function modelmerger(){
var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
- gradioApp().getElementById('modelmerger_result').innerHTML = ''
-
var res = create_submit_args(arguments)
res[0] = id
return res
diff --git a/modules/extras.py b/modules/extras.py
index 034f28e4..fe701a0e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
-from modules import processing, shared, images, devices, sd_models, sd_samplers
+from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
- shared.state.job_count = 1
def fail(message):
shared.state.textinfo = message
@@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
+ def filename_weighed_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_differnece():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
+
+ theta_funcs = {
+ "Weighted sum": (filename_weighed_sum, None, weighted_sum),
+ "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
+ }
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
+
if not primary_model_name:
return fail("Failed: Merging requires a primary model.")
primary_model_info = sd_models.checkpoints_list[primary_model_name]
- if not secondary_model_name:
+ if theta_func2 and not secondary_model_name:
return fail("Failed: Merging requires a secondary model.")
-
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
- }
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
if theta_func1 and not tertiary_model_name:
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
-
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ if theta_func2:
+ shared.state.textinfo = "Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
- shared.state.job_count += 1
-
+ shared.state.textinfo = "Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
@@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
- chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
+ shared.state.textinfo = 'Merging A and B'
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge:
continue
@@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latent space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.sampling_step += 1
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
+ del theta_1
+
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
- if key in chckpoint_dict_skip_on_merge:
- continue
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
- del theta_1
+ del vae_dict
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
-
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index da1bf15c..4ce238b8 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None
+def load_vae_dict(filename, map_location):
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
+
+
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
# save_settings = False
@@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
- vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
diff --git a/modules/shared.py b/modules/shared.py
index 77e5e91c..29b28bff 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path
demo = None
+sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
diff --git a/modules/ui.py b/modules/ui.py
index aeee7853..4e381a49 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -1185,7 +1185,7 @@ def create_ui():
with gr.Column(variant='compact'):
gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
- with FormRow():
+ with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1197,13 +1197,20 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+ with FormRow():
+ with gr.Column():
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+ with gr.Column():
+ with FormRow():
+ bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@@ -1757,6 +1764,7 @@ def create_ui():
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
@@ -1771,6 +1779,7 @@ def create_ui():
custom_name,
checkpoint_format,
config_source,
+ bake_in_vae,
],
outputs=[
primary_model_name,
diff --git a/style.css b/style.css
index 32ba4753..c10e32a1 100644
--- a/style.css
+++ b/style.css
@@ -641,6 +641,16 @@ canvas[key="mask"] {
margin: 0.6em 0em 0.55em 0;
}
+#modelmerger_results_container{
+ margin-top: 1em;
+ overflow: visible;
+}
+
+#modelmerger_models{
+ gap: 0;
+}
+
+
#quicksettings .gr-button-tool{
margin: 0;
}
@@ -737,11 +747,6 @@ footer {
line-height: 2.4em;
}
-#modelmerger_results_container{
- margin-top: 1em;
- overflow: visible;
-}
-
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
--
cgit v1.2.3
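Baking in a VAE is a key-prefix rewrite: every key of the VAE state dict appears in the full checkpoint as first_stage_model.<key>, so the merger overwrites those entries in place (halving them when requested). A sketch over plain tensors (the dictionary contents are illustrative):

    import torch

    def bake_in_vae(theta_0: dict, vae_dict: dict, save_as_half: bool) -> dict:
        # the checkpoint stores its VAE under the 'first_stage_model.' prefix
        for key, value in vae_dict.items():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = value.half() if save_as_half else value
        return theta_0

    ckpt = {'first_stage_model.encoder.w': torch.zeros(2), 'model.diffusion.w': torch.ones(2)}
    vae = {'encoder.w': torch.full((2,), 3.0)}
    baked = bake_in_vae(ckpt, vae, save_as_half=True)
    print(baked['first_stage_model.encoder.w'].dtype)  # torch.float16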
From 54674674b813894b908283531ddaab4ccfeac721 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 12:12:09 +0300
Subject: allow saving at half precision when there is only one checkpoint in
 merger tab
---
modules/extras.py | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index fe701a0e..d03f976e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -278,6 +278,13 @@ def create_config(ckpt_result, config_source, a, b, c):
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -400,8 +407,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
else:
theta_0[key] = theta_func2(a, b, multiplier)
- if save_as_half:
- theta_0[key] = theta_0[key].half()
+ theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1
@@ -416,10 +422,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
for key in vae_dict.keys():
theta_0_key = 'first_stage_model.' + key
if theta_0_key in theta_0:
- theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
del vae_dict
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = filename_generator() if custom_name == '' else custom_name
--
cgit v1.2.3
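to_half() only downcasts tensors whose dtype is float32, which is what makes "Save as float16" safe in the no-merge path: integer tensors such as position ids pass through untouched. A short sketch:

    import torch

    def to_half(tensor: torch.Tensor, enable: bool) -> torch.Tensor:
        # only downcast float32 weights; integer tensors pass through unchanged
        if enable and tensor.dtype == torch.float:
            return tensor.half()
        return tensor

    print(to_half(torch.zeros(2, dtype=torch.float32), True).dtype)  # torch.float16
    print(to_half(torch.zeros(2, dtype=torch.int64), True).dtype)    # torch.int64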
From 18a09c7e0032e2e655269e8e2b4f1ca6ed0cc7d3 Mon Sep 17 00:00:00 2001
From: dan
Date: Thu, 19 Jan 2023 17:36:23 +0800
Subject: Simplification and bugfix
---
modules/textual_inversion/preprocess.py | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 86c1cd33..454dcc36 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,13 +124,11 @@ def center_crop(image: Image, w: int, h: int):
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
iw, ih = image.size
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
- try:
- w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
- if minarea <= w * h <= maxarea and err(w, h) <= threshold),
- key= lambda wh: ((objective=='Maximize area')*wh[0]*wh[1], -err(*wh))
- )
- except ValueError:
- return
+ w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
+ default=None
+ )
return center_crop(image, w, h)
--
cgit v1.2.3
From 2985b317d719f0f0580d2ff93f3008ccabb9c251 Mon Sep 17 00:00:00 2001
From: dan
Date: Thu, 19 Jan 2023 17:39:30 +0800
Subject: Fix of fix
---
modules/textual_inversion/preprocess.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 454dcc36..c0ac11d3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,12 +124,12 @@ def center_crop(image: Image, w: int, h: int):
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
iw, ih = image.size
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
- w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
default=None
)
- return center_crop(image, w, h)
+ return wh and center_crop(image, *wh)
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
--
cgit v1.2.3
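After these two fixes, candidate selection is a single max() over all 64-aligned (w, h) pairs inside the dimension and area limits: default=None turns an empty candidate set into None instead of a ValueError, and slicing the key tuple with [::1] or [::-1] decides whether area or aspect-ratio error dominates. A standalone sketch of that selection (bounds are example values):

    def pick_crop(iw, ih, mindim=512, maxdim=768, minarea=512*512, maxarea=768*768,
                  objective='Maximize area', threshold=0.2):
        # aspect-ratio error: 0 when w/h matches iw/ih, approaching 1 otherwise
        err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
        return max(
            ((w, h)
             for w in range(mindim, maxdim + 1, 64)
             for h in range(mindim, maxdim + 1, 64)
             if minarea <= w * h <= maxarea and err(w, h) <= threshold),
            # compare (area, -error), or (-error, area) when minimizing error
            key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
            default=None,
        )

    print(pick_crop(1024, 768))                             # best 64-aligned (w, h) within limits
    print(pick_crop(1024, 768, mindim=4096, maxdim=4096))   # None: no candidate fits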
From b271e22f7ac1b2cabca8985b1e4437ab685a2c21 Mon Sep 17 00:00:00 2001
From: vt-idiot <81622808+vt-idiot@users.noreply.github.com>
Date: Thu, 19 Jan 2023 06:12:19 -0500
Subject: Update shared.py
`Witdth/Height` was driving me insane. -> `Width/Height`
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/shared.py b/modules/shared.py
index 29b28bff..2f366454 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
- "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
--
cgit v1.2.3
From c12d7ddd725c485682c1caa025627c9ee936d743 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 15:58:32 +0300
Subject: add handling to some places in javascript that can potentially cause
issues #6898
---
.../javascript/prompt-bracket-checker.js | 10 ++++++----
javascript/progressbar.js | 9 +++++++--
2 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index eccfb0f9..251a1f57 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -93,10 +93,12 @@ function checkBrackets(evt) {
}
var shadowRootLoaded = setInterval(function() {
- var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
- if(shadowTextArea.length < 1) {
- return false;
- }
+ var shadowRoot = document.querySelector('gradio-app').shadowRoot;
+ if(!shadowRoot) return false;
+
+ var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+ if(shadowTextArea.length < 1) return false;
+
clearInterval(shadowRootLoaded);
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 18c771a2..2514d2e2 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -81,8 +81,13 @@ function request(url, data, handler, errorHandler){
xhr.onreadystatechange = function () {
if (xhr.readyState === 4) {
if (xhr.status === 200) {
- var js = JSON.parse(xhr.responseText);
- handler(js)
+ try {
+ var js = JSON.parse(xhr.responseText);
+ handler(js)
+ } catch (error) {
+ console.error(error);
+ errorHandler()
+ }
} else{
errorHandler()
}
--
cgit v1.2.3
From 81276cde90ebecfab317cc62a0100d298c3c43c4 Mon Sep 17 00:00:00 2001
From: poiuty
Date: Thu, 19 Jan 2023 16:56:45 +0300
Subject: internal progress relative path
---
javascript/progressbar.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 2514d2e2..ff6d757b 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -160,7 +160,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
}
var fun = function(id_task, id_live_preview){
- request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
+ request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
if(res.completed){
removeProgressBar()
return
--
cgit v1.2.3
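The point of the one-character change above: an absolute path like /internal/progress ignores any subpath the webui is mounted under behind a reverse proxy, while ./internal/progress resolves against the current page URL. A quick illustration of the same resolution rule (Python's urljoin mirrors how browsers resolve relative URLs; the host and mount path are made up):

    from urllib.parse import urljoin

    base = "https://example.com/sdwebui/"          # page served under a subpath

    # absolute path: the mount prefix is lost
    print(urljoin(base, "/internal/progress"))     # https://example.com/internal/progress

    # relative path: resolves under the mount prefix
    print(urljoin(base, "./internal/progress"))    # https://example.com/sdwebui/internal/progress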
From d1ea518dea3d7584be2927cc486d15ec3e18ddb0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 18:07:37 +0300
Subject: remember the list of checkpoints after you press the refresh button and
reload the page
---
modules/ui.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index af416d5f..0c5ba358 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1771,8 +1771,17 @@ def create_ui():
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+ def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
def get_settings_values():
- return [getattr(opts, key) for key in component_keys]
+ return [get_value_for_setting(key) for key in component_keys]
demo.load(
fn=get_settings_values,
--
cgit v1.2.3
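get_value_for_setting returns a Gradio update object rather than a bare value, so the page-load handler can push back not only each setting's value but also any refreshed component properties, such as the checkpoint dropdown's choices after pressing refresh. A condensed, self-contained sketch of the helper (opts and data_labels are passed in explicitly here; the real code reads module-level state):

    import gradio as gr

    def get_value_for_setting(opts, data_labels, key):
        # build an update carrying the current value plus refreshed
        # component properties (e.g. dropdown choices)
        value = getattr(opts, key)
        info = data_labels[key]
        args = info.component_args() if callable(info.component_args) else info.component_args or {}
        args = {k: v for k, v in args.items() if k != 'precision'}  # not settable via gr.update
        return gr.update(value=value, **args)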
From f2ae2529877072874ebaac0257fe4af48c5855a4 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Thu, 19 Jan 2023 10:24:17 -0500
Subject: fixes minor typos around run_modelmerger
---
modules/extras.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index d03f976e..1218f88f 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -275,7 +275,7 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
def to_half(tensor, enable):
@@ -303,7 +303,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- def filename_weighed_sum():
+ def filename_weighted_sum():
a = primary_model_info.model_name
b = secondary_model_info.model_name
Ma = round(1 - multiplier, 2)
@@ -311,7 +311,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
return f"{Ma}({a}) + {Mb}({b})"
- def filename_add_differnece():
+ def filename_add_difference():
a = primary_model_info.model_name
b = secondary_model_info.model_name
c = tertiary_model_info.model_name
@@ -323,8 +323,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (filename_weighed_sum, None, weighted_sum),
- "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
+ "Add difference": (filename_add_difference, get_difference, add_difference),
"No interpolation": (filename_nothing, None, None),
}
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
@@ -362,7 +362,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
if 'model' in key:
@@ -387,7 +387,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
for key in tqdm.tqdm(theta_0.keys()):
if theta_1 and 'model' in key and key in theta_1:
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
a = theta_0[key]
--
cgit v1.2.3
From c1928cdd6194928af0f53f70c51d59479b7025e2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 18:58:08 +0300
Subject: bring back short hashes to sd checkpoint selection
---
modules/sd_models.py | 15 +++++++++++----
modules/ui.py | 23 ++++++++++++-----------
2 files changed, 23 insertions(+), 15 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6a681cef..12083848 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -41,14 +41,16 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- self.title = name
+ self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -56,13 +58,15 @@ class CheckpointInfo:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()
+ self.title = f'{self.name} [{self.shorthash}]'
+
return self.shorthash
@@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ if checkpoint_info.title != title:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0
diff --git a/modules/ui.py b/modules/ui.py
index 0c5ba358..13d80ae2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -439,7 +439,7 @@ def apply_setting(key, value):
opts.data_labels[key].onchange()
opts.save(shared.config_filename)
- return value
+ return getattr(opts, key)
def update_generation_info(generation_info, html_info, img_index):
@@ -597,6 +597,16 @@ def ordered_ui_categories():
yield category
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -1600,7 +1610,7 @@ def create_ui():
opts.save(shared.config_filename)
- return gr.update(value=value), opts.dumpjson()
+ return get_value_for_setting(key), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
@@ -1771,15 +1781,6 @@ def create_ui():
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
- def get_value_for_setting(key):
- value = getattr(opts, key)
-
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
-
- return gr.update(value=value, **args)
-
def get_settings_values():
return [get_value_for_setting(key) for key in component_keys]
--
cgit v1.2.3
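The reworked CheckpointInfo keeps the plain file name as name and only appends the short hash to title once the SHA-256 is known, the short hash being the first 10 hex digits of the file's SHA-256. A minimal sketch of that naming scheme (the hashing helper is a stand-in; the webui caches hashes instead of recomputing them, and the file name is hypothetical):

    import hashlib

    def shorthash(path, n=10):
        # first n hex digits of the file's SHA-256
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        return h.hexdigest()[:n]

    name = "v1-5-pruned-emaonly.ckpt"   # hypothetical checkpoint file
    sha10 = None                        # e.g. shorthash(name) once hashed
    title = name if sha10 is None else f"{name} [{sha10}]"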
From 4599e8ad0acaae3f13bd3a7bef4db7632aac8504 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Thu, 19 Jan 2023 17:00:51 +0100
Subject: Environment variable on launch just for Navi cards
Setting HSA_OVERRIDE_GFX_VERSION=10.3.0 for all AMD cards seems to break compatibility for Polaris and Vega cards, so it should just be enabled on Navi
---
webui.sh | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index 1edf921d..a35a5f35 100755
--- a/webui.sh
+++ b/webui.sh
@@ -172,7 +172,12 @@ else
then
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
fi
- HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ if echo "$gpu_info" | grep -q "Navi"
+ then
+ HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ else
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ fi
else
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
--
cgit v1.2.3
From 6073456c8348d15716b9bc5276d994fe8554e4ca Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 20:39:03 +0300
Subject: write a comment for fix_checkpoint function
---
modules/sd_hijack.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 870eba88..f9652d21 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -69,6 +69,13 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def fix_checkpoint():
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
+ checkpoints to be added when not training (there's a warning)"""
+
+ pass
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
--
cgit v1.2.3
From c09fb3d8f1f71bc66d7c4cea603885619d6a1cd4 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Thu, 19 Jan 2023 19:21:02 +0100
Subject: Simplify GPU check
---
webui.sh | 25 ++++++++++---------------
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/webui.sh b/webui.sh
index a35a5f35..aa4f875c 100755
--- a/webui.sh
+++ b/webui.sh
@@ -104,6 +104,12 @@ then
fi
# Check prerequisites
+gpu_info=$(lspci 2>/dev/null | grep VGA)
+if echo "$gpu_info" | grep -q "Navi"
+then
+ export HSA_OVERRIDE_GFX_VERSION=10.3.0
+fi
+
for preq in "${GIT}" "${python_cmd}"
do
if ! hash "${preq}" &>/dev/null
@@ -165,20 +171,9 @@ else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
- gpu_info=$(lspci 2>/dev/null | grep VGA)
- if echo "$gpu_info" | grep -q "AMD"
+ if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
- if [[ -z "${TORCH_COMMAND}" ]]
- then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
- fi
- if echo "$gpu_info" | grep -q "Navi"
- then
- HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
- else
- exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
- fi
- else
- exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
- fi
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+ fi
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
--
cgit v1.2.3
From 48045545d9a3f174621a62086812d9bbfb3ce1c2 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Thu, 19 Jan 2023 19:23:40 +0100
Subject: Small reformat of the GPU check
---
webui.sh | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/webui.sh b/webui.sh
index aa4f875c..ff410e15 100755
--- a/webui.sh
+++ b/webui.sh
@@ -109,6 +109,10 @@ if echo "$gpu_info" | grep -q "Navi"
then
export HSA_OVERRIDE_GFX_VERSION=10.3.0
fi
+if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+fi
for preq in "${GIT}" "${python_cmd}"
do
@@ -170,10 +174,6 @@ then
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
- printf "\n%s\n" "${delimiter}"
- if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
- then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
- fi
+ printf "\n%s\n" "${delimiter}"
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
--
cgit v1.2.3
From 36364bd76c4634820e08070a287f0a5ad27c35f6 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Thu, 19 Jan 2023 20:05:49 +0100
Subject: GFX env just for RDNA 1 and 2
This commit specifies which GPUs should use the GFX variable; RDNA 3 is excluded since it uses a newer GFX version
---
webui.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index ff410e15..91c95e47 100755
--- a/webui.sh
+++ b/webui.sh
@@ -105,7 +105,7 @@ fi
# Check prerequisites
gpu_info=$(lspci 2>/dev/null | grep VGA)
-if echo "$gpu_info" | grep -q "Navi"
+if echo "$gpu_info" | grep -qE "Navi (1|2)"
then
export HSA_OVERRIDE_GFX_VERSION=10.3.0
fi
--
cgit v1.2.3
From 912285ae64e4e1186feb54caf82b4a0b11c6cb7f Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Thu, 19 Jan 2023 23:42:12 +0100
Subject: Experimental support for Renoir
This adds GFX version 9.0.0 in order to use Renoir GPUs with at least 4 GB of VRAM (it's possible to increase the virtual VRAM from the BIOS settings of some vendors). This will only work if the remaining RAM is at least 12 GB, to avoid the system becoming unresponsive on launch.
This change also converts the GPU check to a case statement so more GPUs can be added efficiently.
---
webui.sh | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/webui.sh b/webui.sh
index 91c95e47..27933c04 100755
--- a/webui.sh
+++ b/webui.sh
@@ -105,10 +105,14 @@ fi
# Check prerequisites
gpu_info=$(lspci 2>/dev/null | grep VGA)
-if echo "$gpu_info" | grep -qE "Navi (1|2)"
-then
- export HSA_OVERRIDE_GFX_VERSION=10.3.0
-fi
+case "$gpu_info" in
+ *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ ;;
+ *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ ;;
+ *)
+ ;;
+esac
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
--
cgit v1.2.3
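The case statement above is a dispatch table from lspci output to ROCm's GFX override. For illustration only, the same mapping in Python (values taken from the script; detection is just substring matching on the VGA line):

    import os, subprocess

    def set_rocm_override():
        try:
            out = subprocess.run(["lspci"], capture_output=True, text=True).stdout
        except FileNotFoundError:
            return  # no lspci, nothing to detect
        gpu_info = "\n".join(l for l in out.splitlines() if "VGA" in l)
        if "Navi 1" in gpu_info or "Navi 2" in gpu_info:    # RDNA 1/2
            os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
        elif "Renoir" in gpu_info:                          # experimental
            os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.0.0"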
From 0684a6819dfaec40732271ca5ef32392c36f17ba Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 20 Jan 2023 00:21:05 +0100
Subject: Usage explanation for Renoir users
---
webui.sh | 3 +++
1 file changed, 3 insertions(+)
diff --git a/webui.sh b/webui.sh
index 27933c04..4da51880 100755
--- a/webui.sh
+++ b/webui.sh
@@ -109,6 +109,9 @@ case "$gpu_info" in
*"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ printf "\n%s\n" "${delimiter}"
+ printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
;;
*)
;;
--
cgit v1.2.3
From fd651bd0bcceb4c746c86b202702bca029cbd6db Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 20 Jan 2023 00:21:51 +0100
Subject: Update webui.sh
---
webui.sh | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/webui.sh b/webui.sh
index 4da51880..d5e7b3c5 100755
--- a/webui.sh
+++ b/webui.sh
@@ -110,8 +110,8 @@ case "$gpu_info" in
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
- printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
- printf "\n%s\n" "${delimiter}"
+ printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
;;
*)
;;
--
cgit v1.2.3
From 6c7a50d783c4e406d8597f9cf354bb8128026f6c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 20 Jan 2023 08:36:30 +0300
Subject: remove some unnecessary logging to javascript console
---
javascript/hires_fix.js | 3 ---
modules/ui.py | 2 +-
2 files changed, 1 insertion(+), 4 deletions(-)
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
index 07fba549..0629475f 100644
--- a/javascript/hires_fix.js
+++ b/javascript/hires_fix.js
@@ -1,6 +1,5 @@
function setInactive(elem, inactive){
- console.log(elem)
if(inactive){
elem.classList.add('inactive')
} else{
@@ -9,8 +8,6 @@ function setInactive(elem, inactive){
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)
-
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
diff --git a/modules/ui.py b/modules/ui.py
index 13d80ae2..eb45a128 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -532,7 +532,7 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info],
)
--
cgit v1.2.3
From 98466da4bc312c0fa9c8cea4c825afc64194cb58 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Fri, 20 Jan 2023 00:48:15 -0500
Subject: adds descriptions for merging methods in ui
---
modules/ui.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index eb45a128..ee434bde 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1190,10 +1190,19 @@ def create_ui():
outputs=[html, generation_info, html2],
)
+ def update_interp_description(value):
+ interp_description_css = "
{}
"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact'):
- gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
")
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
@@ -1208,6 +1217,7 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
@@ -1903,6 +1913,9 @@ def create_ui():
with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
return demo
--
cgit v1.2.3
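The two formulas quoted in the descriptions are what the merger applies per state-dict tensor. A worked toy example (torch assumed; M is the multiplier slider, and small constant tensors stand in for whole checkpoints):

    import torch

    M = 0.3
    A, B, C = (torch.full((2, 2), v) for v in (1.0, 2.0, 3.0))

    weighted_sum   = A * (1 - M) + B * M   # "Weighted sum": models A and B
    add_difference = A + (B - C) * M       # "Add difference": models A, B and C

    print(weighted_sum[0, 0].item())       # 1.3
    print(add_difference[0, 0].item())     # 0.7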
From 20a59ab3b171f398abd09087108c1ed087dbea9b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 20 Jan 2023 10:18:41 +0300
Subject: move token counter to the location of the prompt, add token counting
for the negative prompt
---
.../javascript/prompt-bracket-checker.js | 45 +++++++++++----------
javascript/ui.js | 29 +++++++++----
modules/ui.py | 25 ++++++------
style.css | 47 ++++++++++++++--------
4 files changed, 87 insertions(+), 59 deletions(-)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index 251a1f57..4a85c8eb 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,16 +4,10 @@
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
-function checkBrackets(evt) {
- textArea = evt.target;
- tabName = evt.target.parentElement.parentElement.id.split("_")[0];
- counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
-
- promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
-
- errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
- errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
- errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
+function checkBrackets(evt, textArea, counterElt) {
+ errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
+ errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
+ errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
openBracketRegExp = /\(/g;
closeBracketRegExp = /\)/g;
@@ -86,24 +80,31 @@ function checkBrackets(evt) {
}
if(counterElt.title != '') {
- counterElt.style = 'color: #FF5555;';
+ counterElt.classList.add('error');
} else {
- counterElt.style = '';
+ counterElt.classList.remove('error');
}
}
-var shadowRootLoaded = setInterval(function() {
- var sahdowRoot = document.querySelector('gradio-app').shadowRoot;
- if(! sahdowRoot) return false;
+function setupBracketChecking(id_prompt, id_counter){
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter)
+ textarea.addEventListener("input", function(evt){
+ checkBrackets(evt, textarea, counter)
+ });
+}
- var shadowTextArea = sahdowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
- if(shadowTextArea.length < 1) return false;
+var shadowRootLoaded = setInterval(function() {
+ var shadowRoot = document.querySelector('gradio-app').shadowRoot;
+ if(! shadowRoot) return false;
+ var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+ if(shadowTextArea.length < 1) return false;
- clearInterval(shadowRootLoaded);
+ clearInterval(shadowRootLoaded);
- document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
+ setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
}, 1000);
diff --git a/javascript/ui.js b/javascript/ui.js
index 37788a3e..3ba90ca8 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -230,14 +230,26 @@ onUiUpdate(function(){
json_elem.parentElement.style.display="none"
- if (!txt2img_textarea) {
- txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
- txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
- }
- if (!img2img_textarea) {
- img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
- img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
- }
+ function registerTextarea(id, id_counter, id_button){
+ var prompt = gradioApp().getElementById(id)
+ var counter = gradioApp().getElementById(id_counter)
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+ if(counter.parentElement == prompt.parentElement){
+ return
+ }
+
+ prompt.parentElement.insertBefore(counter, prompt)
+ counter.classList.add("token-counter")
+ prompt.parentElement.style.position = "relative"
+
+ textarea.addEventListener("input", () => update_token_counter(id_button));
+ }
+
+ registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
+ registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
+ registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
+ registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
@@ -249,6 +261,7 @@ onUiUpdate(function(){
})
}
}
+
})
diff --git a/modules/ui.py b/modules/ui.py
index eb45a128..06c11848 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -335,28 +335,23 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"{token_count}/{max_length}"
+ return f"{token_count}/{max_length}"
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Column(scale=1, elem_id="roll_col"):
paste = gr.Button(value=paste_symbol, elem_id="paste")
@@ -365,6 +360,8 @@ def create_toprow(is_img2img):
clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
clear_prompt_button.click(
fn=lambda *x: x,
@@ -402,7 +399,7 @@ def create_toprow(is_img2img):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs):
@@ -619,7 +616,7 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
@@ -795,12 +792,13 @@ def create_ui():
]
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
@@ -1064,6 +1062,7 @@ def create_ui():
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
diff --git a/style.css b/style.css
index c10e32a1..994932de 100644
--- a/style.css
+++ b/style.css
@@ -2,12 +2,26 @@
max-width: 100%;
}
-#txt2img_token_counter {
- height: 0px;
+.token-counter{
+ position: absolute;
+ display: inline-block;
+ right: 2em;
+ min-width: 0 !important;
+ width: auto;
+ z-index: 100;
+}
+
+.token-counter.error span{
+ box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075);
+ border: 2px solid rgba(255,0,0,0.4) !important;
}
-#img2img_token_counter {
- height: 0px;
+.token-counter div{
+ display: inline;
+}
+
+.token-counter span{
+ padding: 0.1em 0.75em;
}
#sh{
@@ -113,7 +127,7 @@
#roll_col{
min-width: unset !important;
flex-grow: 0 !important;
- padding: 0.4em 0;
+ padding: 0 1em 0 0;
gap: 0;
}
@@ -160,16 +174,6 @@
margin-bottom: 0;
}
-#toprow div.gr-box, #toprow div.gr-form{
- border: none;
- gap: 0;
- background: transparent;
- box-shadow: none;
-}
-#toprow div{
- gap: 0;
-}
-
#resize_mode{
flex: 1.5;
}
@@ -706,6 +710,14 @@ footer {
opacity: 0.5;
}
+[id*='_prompt_container']{
+ gap: 0;
+}
+
+[id*='_prompt_container'] > div{
+ margin: -0.4em 0 0 0;
+}
+
.gr-compact {
border: none;
}
@@ -715,8 +727,11 @@ footer {
margin-left: 0.8em;
}
+.gr-compact{
+ overflow: visible;
+}
+
.gr-compact > *{
- margin-top: 0.5em !important;
}
.gr-compact .gr-block, .gr-compact .gr-form{
--
cgit v1.2.3
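checkBrackets simply counts each opening/closing bracket kind and flags any mismatch via the counter's tooltip and an error class. The counting logic, restated compactly in Python (message text paraphrased):

    import re

    def bracket_errors(prompt):
        errors = []
        for open_pat, close_pat, label in ((r"\(", r"\)", "(...)"),
                                           (r"\[", r"\]", "[...]"),
                                           (r"\{", r"\}", "{...}")):
            if len(re.findall(open_pat, prompt)) != len(re.findall(close_pat, prompt)):
                errors.append(f"{label} - different number of opening and closing brackets")
        return errors

    print(bracket_errors("a (masterpiece [painting of a cat"))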
From 7d3fb5cb03cc8520a32b4f56509f2e13e36911bd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 20 Jan 2023 12:12:02 +0300
Subject: add margin to interrogate column in img2img UI
---
style.css | 1 +
1 file changed, 1 insertion(+)
diff --git a/style.css b/style.css
index 994932de..3a515ebd 100644
--- a/style.css
+++ b/style.css
@@ -145,6 +145,7 @@
#interrogate_col{
min-width: 0 !important;
max-width: 8em !important;
+ margin-right: 1em;
}
#interrogate, #deepbooru{
margin: 0em 0.25em 0.9em 0.25em;
--
cgit v1.2.3
From e33cace2c2074ef342d027c1f31ffc4b3c3e877e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 20 Jan 2023 12:19:30 +0300
Subject: fix ctrl+up/down that stopped working
---
javascript/edit-attention.js | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index ccc8344a..cec6a530 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,6 +1,6 @@
addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0];
- if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
+ if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
if (! (event.metaKey || event.ctrlKey)) return;
--
cgit v1.2.3
From e0b6092bc99efe311261a51289dec67cbf4845bc Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Fri, 20 Jan 2023 15:31:27 +0100
Subject: Update webui.sh
---
webui.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui.sh b/webui.sh
index d5e7b3c5..8cdad22d 100755
--- a/webui.sh
+++ b/webui.sh
@@ -110,7 +110,7 @@ case "$gpu_info" in
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
- printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
printf "\n%s\n" "${delimiter}"
;;
*)
--
cgit v1.2.3
From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 08:36:07 +0300
Subject: extra networks UI rework of hypernets: rather than via settings,
hypernets are added directly to prompt as <hypernet:name:weight>
---
html/card-no-preview.png | Bin 0 -> 84440 bytes
html/extra-networks-card.html | 11 ++
html/extra-networks-no-cards.html | 8 ++
javascript/extraNetworks.js | 60 ++++++++
javascript/hints.js | 2 +
javascript/ui.js | 9 +-
modules/api/api.py | 7 +-
modules/extra_networks.py | 147 +++++++++++++++++++
modules/extra_networks_hypernet.py | 21 +++
modules/generation_parameters_copypaste.py | 12 +-
modules/hypernetworks/hypernetwork.py | 107 +++++++++-----
modules/hypernetworks/ui.py | 5 +-
modules/processing.py | 24 ++--
modules/sd_hijack_optimizations.py | 10 +-
modules/shared.py | 21 ++-
modules/textual_inversion/textual_inversion.py | 2 +
modules/ui.py | 50 ++++---
modules/ui_components.py | 10 ++
modules/ui_extra_networks.py | 149 +++++++++++++++++++
modules/ui_extra_networks_hypernets.py | 34 +++++
modules/ui_extra_networks_textual_inversion.py | 32 +++++
script.js | 13 +-
scripts/xy_grid.py | 29 ----
style.css | 190 +++++++++++++------------
webui.py | 26 +++-
25 files changed, 765 insertions(+), 214 deletions(-)
create mode 100644 html/card-no-preview.png
create mode 100644 html/extra-networks-card.html
create mode 100644 html/extra-networks-no-cards.html
create mode 100644 javascript/extraNetworks.js
create mode 100644 modules/extra_networks.py
create mode 100644 modules/extra_networks_hypernet.py
create mode 100644 modules/ui_extra_networks.py
create mode 100644 modules/ui_extra_networks_hypernets.py
create mode 100644 modules/ui_extra_networks_textual_inversion.py
diff --git a/html/card-no-preview.png b/html/card-no-preview.png
new file mode 100644
index 00000000..e2beb269
Binary files /dev/null and b/html/card-no-preview.png differ
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
new file mode 100644
index 00000000..7314b063
--- /dev/null
+++ b/html/extra-networks-card.html
@@ -0,0 +1,11 @@
[extra-networks-card.html's markup was stripped by the HTML renderer; the 11-line template renders {name} and uses the {preview_html}, {card_clicked} and {save_card_preview} fields built in ui_extra_networks.py below]
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 2ddac3d8..8b4f97f8 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -3,6 +3,7 @@ import os.path
from modules import shared
import gradio as gr
import json
+import html
from modules.generation_parameters_copypaste import image_from_url_text
@@ -54,12 +55,13 @@ class ExtraNetworksPage:
preview = item.get("preview", None)
args = {
- "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+ "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item["prompt"],
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
- "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
}
return self.card_page.format(**args)
--
cgit v1.2.3
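The escaping change matters because the card template places generated JavaScript inside inline onclick attributes; quotes in a prompt or preview path could otherwise terminate the attribute early. A minimal sketch of the pattern (cardClicked and the field names come from the diff; the values are made up):

    import html, json

    tabname = "txt2img"
    prompt = json.dumps("a cat")  # the item's prompt is already JS-quoted
    handler = f"return cardClicked({json.dumps(tabname)}, {prompt}, false)"
    attr = '"' + html.escape(handler) + '"'
    # inner quotes become &quot;, so the attribute survives any quoting
    print(f"onclick={attr}")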
From 5c1cb9263f980641007088a37360fcab01761d37 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 00:24:17 +0300
Subject: fix BLIP failing to import depending on configuration
---
modules/interrogate.py | 3 ++-
modules/paths.py | 14 ++++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index c252b148..236e6983 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -83,7 +83,8 @@ class InterrogateModels:
return self.loaded_categories
def load_blip_model(self):
- import models.blip
+ with paths.Prioritize("BLIP"):
+ import models.blip
files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "BLIP"),
diff --git a/modules/paths.py b/modules/paths.py
index 4dd03a35..20b3e4d8 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
+
+
+class Prioritize:
+ def __init__(self, name):
+ self.name = name
+ self.path = None
+
+ def __enter__(self):
+ self.path = sys.path.copy()
+ sys.path = [paths[self.name]] + sys.path
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.path = self.path
+ self.path = None
--
cgit v1.2.3
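Prioritize addresses a name clash: depending on sys.path order, import models.blip can resolve against another repository's top-level models package, which is the configuration-dependent failure the commit message refers to. A usage sketch of the same context-manager pattern (the path is illustrative; the real code looks it up in the paths dict):

    import sys

    class Prioritize:
        def __init__(self, path):
            self.prepend = path
            self.saved = None
        def __enter__(self):
            self.saved = sys.path.copy()
            sys.path = [self.prepend] + sys.path
        def __exit__(self, exc_type, exc_val, exc_tb):
            sys.path = self.saved

    with Prioritize("/repos/BLIP"):
        pass  # `import models.blip` here resolves against BLIP first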
From 82a28bfe35928e244d4b51d72ce424aff5619b75 Mon Sep 17 00:00:00 2001
From: Mykeehu
Date: Mon, 23 Jan 2023 22:36:27 +0100
Subject: Fix extra network thumbs label color
Added white color for labels.
---
style.css | 1 +
1 file changed, 1 insertion(+)
diff --git a/style.css b/style.css
index b2677fa1..ec046f78 100644
--- a/style.css
+++ b/style.css
@@ -857,6 +857,7 @@ footer {
white-space: nowrap;
text-overflow: ellipsis;
background: rgba(0,0,0,.5);
+ color: white;
}
.extra-network-thumbs .card:hover .actions .name {
--
cgit v1.2.3
From 45e270dfc853216b2c413f915946f0f2842e57a4 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 23 Jan 2023 17:11:22 -0500
Subject: add image decode exception handling
---
modules/api/api.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index b1dd14cc..e6e31e41 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -53,7 +53,11 @@ def setUpscalers(req: dict):
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
- return Image.open(BytesIO(base64.b64decode(encoding)))
+ try:
+ image = Image.open(BytesIO(base64.b64decode(encoding)))
+ return image
+ except Exception as err:
+ raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
--
cgit v1.2.3
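The handler accepts either a bare base64 payload or a full data URI; after this commit a malformed payload surfaces as an HTTP error instead of an unhandled exception. The same function made self-contained with its imports (a sketch matching the diff's logic):

    import base64
    from io import BytesIO

    from fastapi import HTTPException
    from PIL import Image

    def decode_base64_to_image(encoding: str) -> Image.Image:
        if encoding.startswith("data:image/"):
            # "data:image/png;base64,<payload>" -> keep only the payload
            encoding = encoding.split(";")[1].split(",")[1]
        try:
            return Image.open(BytesIO(base64.b64decode(encoding)))
        except Exception:
            raise HTTPException(status_code=500, detail="Invalid encoded image")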
From f99352582084890b9167c1bf8699865bea0cef5f Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 21:50:59 -0500
Subject: Make SwinIR interruptible
---
extensions-builtin/SwinIR/scripts/swinir_model.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..3479760a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
+ if state.interrupted:
+ break
+
for w_idx in w_idx_list:
+ if state.interrupted:
+ break
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
--
cgit v1.2.3
From 3c47b050367ee220dcfed7be7883878417735614 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 22:00:27 -0500
Subject: Also make SwinIR skippable
---
extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 3479760a..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -145,11 +145,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
- if state.interrupted:
+ if state.interrupted or state.skipped:
break
for w_idx in w_idx_list:
- if state.interrupted:
+ if state.interrupted or state.skipped:
break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
--
cgit v1.2.3
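For context, SwinIR's inference walks the image in overlapping tiles, so the interrupt/skip checks sit at both loop levels to abandon work quickly. A schematic of that control flow (tile origins computed as in typical SwinIR tiling; the state object is reduced to two flags, and h, w are assumed to be at least tile):

    from dataclasses import dataclass

    @dataclass
    class State:
        interrupted: bool = False
        skipped: bool = False

    def tile_origins(h, w, tile, tile_overlap, state):
        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        for h_idx in h_idx_list:
            if state.interrupted or state.skipped:
                break
            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break
                yield h_idx, w_idx  # model runs on img[..., h_idx:h_idx+tile, w_idx:w_idx+tile]

    print(list(tile_origins(100, 100, 64, 8, State())))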
From 078e16e4d33ccbd40ff3ecfbb57ffd33a2a16c47 Mon Sep 17 00:00:00 2001
From: acncagua
Date: Tue, 24 Jan 2023 12:21:07 +0900
Subject: Set Linux xformers 0.0.16RC425
---
launch.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/launch.py b/launch.py
index e7a0b50c..82094fa0 100644
--- a/launch.py
+++ b/launch.py
@@ -245,7 +245,7 @@ def prepare_environment():
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
- run_pip("install xformers", "xformers")
+ run_pip("install xformers==0.0.16rc425", "xformers")
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
--
cgit v1.2.3
From f64af77adcd20fabe00e1e642512db9c6742ed23 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 23 Jan 2023 22:49:20 -0500
Subject: Fix different first gen with Approx NN previews
The loading of the model for approx nn live previews can change the internal state of PyTorch, resulting in a different image. This can be avoided by preloading the approx nn model in advance.
---
modules/processing.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index bc541e2f..3bd590ba 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -568,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1:
+ # preload approx nn model before sampling for a more deterministic result
+ sd_vae_approx.model()
+
if not p.disable_extra_networks:
extra_networks.activate(p, extra_network_data)
--
cgit v1.2.3
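The underlying issue is easy to reproduce: constructing a torch module consumes the global RNG (weight initialization), so doing it between seeding and sampling shifts every subsequent random draw. A toy demonstration (the Linear layer stands in for loading sd_vae_approx's model):

    import torch

    torch.manual_seed(0)
    a = torch.randn(3)          # sampling with no mid-run model load

    torch.manual_seed(0)
    torch.nn.Linear(4, 4)       # model constructed after seeding...
    b = torch.randn(3)          # ...so the draw is shifted

    print(torch.equal(a, b))    # False: hence the "different first gen"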
From 42a70d74771e8920f658e741679768ed145dd76a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 10:05:45 +0300
Subject: repair sdapi/v1/upscalers returning bogus results
---
modules/api/api.py | 16 +++++++++-------
modules/api/models.py | 2 +-
2 files changed, 10 insertions(+), 8 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index e6e31e41..da2a5daf 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -375,13 +375,15 @@ class Api:
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
- upscalers = []
-
- for upscaler in shared.sd_upscalers:
- u = upscaler.scaler
- upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
- return upscalers
+ return [
+ {
+ "name": upscaler.name,
+ "model_name": upscaler.scaler.model_name,
+ "model_path": upscaler.data_path,
+ "scale": upscaler.scale,
+ }
+ for upscaler in shared.sd_upscalers
+ ]
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
diff --git a/modules/api/models.py b/modules/api/models.py
index 1eb1fcf1..e562ab54 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -219,7 +219,7 @@ class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
- model_url: Optional[str] = Field(title="URL")
+ scale: Optional[float] = Field(title="Scale")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
--
cgit v1.2.3
From 602a1864b05075ca4283986e6f5c7d5bce864e11 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 10:09:30 +0300
Subject: also return the removed field to sdapi/v1/upscalers because someone
might have relied on it existing
---
modules/api/api.py | 1 +
modules/api/models.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/modules/api/api.py b/modules/api/api.py
index da2a5daf..25c65e57 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -380,6 +380,7 @@ class Api:
"name": upscaler.name,
"model_name": upscaler.scaler.model_name,
"model_path": upscaler.data_path,
+ "model_url": None,
"scale": upscaler.scale,
}
for upscaler in shared.sd_upscalers
diff --git a/modules/api/models.py b/modules/api/models.py
index e562ab54..805bd8f7 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -219,6 +219,7 @@ class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
+ model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
class SDModelItem(BaseModel):
--
cgit v1.2.3
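Net effect on the schema: model_url stays in UpscalerItem but is always null now, alongside the new scale field. A self-contained restatement of the resulting model (pydantic v1 style, as the module uses):

    from typing import Optional
    from pydantic import BaseModel, Field

    class UpscalerItem(BaseModel):
        name: str = Field(title="Name")
        model_name: Optional[str] = Field(title="Model Name")
        model_path: Optional[str] = Field(title="Path")
        model_url: Optional[str] = Field(title="URL")    # kept for API compatibility
        scale: Optional[float] = Field(title="Scale")

    item = UpscalerItem(name="Lanczos", model_name=None, model_path=None,
                        model_url=None, scale=4.0)
    print(item.json())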
From d30ac02f28bf5fa1ca5d4ba444180ba9e50b4860 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 24 Jan 2023 02:21:32 -0500
Subject: renamed xy to xyz grid; this is mostly just so git can detect it
properly
---
scripts/xy_grid.py | 498 ----------------------------------------------------
scripts/xyz_grid.py | 498 ++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 498 insertions(+), 498 deletions(-)
delete mode 100644 scripts/xy_grid.py
create mode 100644 scripts/xyz_grid.py
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
deleted file mode 100644
index 1a452355..00000000
--- a/scripts/xy_grid.py
+++ /dev/null
@@ -1,498 +0,0 @@
-from collections import namedtuple
-from copy import copy
-from itertools import permutations, chain
-import random
-import csv
-from io import StringIO
-from PIL import Image
-import numpy as np
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
-from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-import modules.sd_samplers
-import modules.sd_models
-import modules.sd_vae
-import glob
-import os
-import re
-
-from modules.ui_components import ToolButton
-
-fill_values_symbol = "\U0001f4d2" # 📒
-
-
-def apply_field(field):
- def fun(p, x, xs):
- setattr(p, field, x)
-
- return fun
-
-
-def apply_prompt(p, x, xs):
- if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
- raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
-
- p.prompt = p.prompt.replace(xs[0], x)
- p.negative_prompt = p.negative_prompt.replace(xs[0], x)
-
-
-def apply_order(p, x, xs):
- token_order = []
-
- # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
- for token in x:
- token_order.append((p.prompt.find(token), token))
-
- token_order.sort(key=lambda t: t[0])
-
- prompt_parts = []
-
- # Split the prompt up, taking out the tokens
- for _, token in token_order:
- n = p.prompt.find(token)
- prompt_parts.append(p.prompt[0:n])
- p.prompt = p.prompt[n + len(token):]
-
- # Rebuild the prompt with the tokens in the order we want
- prompt_tmp = ""
- for idx, part in enumerate(prompt_parts):
- prompt_tmp += part
- prompt_tmp += x[idx]
- p.prompt = prompt_tmp + p.prompt
-
-
-def apply_sampler(p, x, xs):
- sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
- if sampler_name is None:
- raise RuntimeError(f"Unknown sampler: {x}")
-
- p.sampler_name = sampler_name
-
-
-def confirm_samplers(p, xs):
- for x in xs:
- if x.lower() not in sd_samplers.samplers_map:
- raise RuntimeError(f"Unknown sampler: {x}")
-
-
-def apply_checkpoint(p, x, xs):
- info = modules.sd_models.get_closet_checkpoint_match(x)
- if info is None:
- raise RuntimeError(f"Unknown checkpoint: {x}")
- modules.sd_models.reload_model_weights(shared.sd_model, info)
-
-
-def confirm_checkpoints(p, xs):
- for x in xs:
- if modules.sd_models.get_closet_checkpoint_match(x) is None:
- raise RuntimeError(f"Unknown checkpoint: {x}")
-
-
-def apply_clip_skip(p, x, xs):
- opts.data["CLIP_stop_at_last_layers"] = x
-
-
-def apply_upscale_latent_space(p, x, xs):
- if x.lower().strip() != '0':
- opts.data["use_scale_latent_for_hires_fix"] = True
- else:
- opts.data["use_scale_latent_for_hires_fix"] = False
-
-
-def find_vae(name: str):
- if name.lower() in ['auto', 'automatic']:
- return modules.sd_vae.unspecified
- if name.lower() == 'none':
- return None
- else:
- choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
- if len(choices) == 0:
- print(f"No VAE found for {name}; using automatic")
- return modules.sd_vae.unspecified
- else:
- return modules.sd_vae.vae_dict[choices[0]]
-
-
-def apply_vae(p, x, xs):
- modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
-
-
-def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
- p.styles = x.split(',')
-
-
-def format_value_add_label(p, opt, x):
- if type(x) == float:
- x = round(x, 8)
-
- return f"{opt.label}: {x}"
-
-
-def format_value(p, opt, x):
- if type(x) == float:
- x = round(x, 8)
- return x
-
-
-def format_value_join_list(p, opt, x):
- return ", ".join(x)
-
-
-def do_nothing(p, x, xs):
- pass
-
-
-def format_nothing(p, opt, x):
- return ""
-
-
-def str_permutations(x):
- """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
- return x
-
-
-class AxisOption:
- def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
- self.label = label
- self.type = type
- self.apply = apply
- self.format_value = format_value
- self.confirm = confirm
- self.cost = cost
- self.choices = choices
-
-
-class AxisOptionImg2Img(AxisOption):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_img2img = True
-
-class AxisOptionTxt2Img(AxisOption):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_img2img = False
-
-
-axis_options = [
- AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
- AxisOption("Seed", int, apply_field("seed")),
- AxisOption("Var. seed", int, apply_field("subseed")),
- AxisOption("Var. strength", float, apply_field("subseed_strength")),
- AxisOption("Steps", int, apply_field("steps")),
- AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
- AxisOption("CFG Scale", float, apply_field("cfg_scale")),
- AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
- AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
- AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
- AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
- AxisOption("Sigma Churn", float, apply_field("s_churn")),
- AxisOption("Sigma min", float, apply_field("s_tmin")),
- AxisOption("Sigma max", float, apply_field("s_tmax")),
- AxisOption("Sigma noise", float, apply_field("s_noise")),
- AxisOption("Eta", float, apply_field("eta")),
- AxisOption("Clip skip", int, apply_clip_skip),
- AxisOption("Denoising", float, apply_field("denoising_strength")),
- AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
- AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
- AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
- AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
-]
-
-
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
- ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
- hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
-
- # Temporary list of all the images that are generated to be populated into the grid.
- # Will be filled with empty images for any individual step that fails to process properly
- image_cache = [None] * (len(xs) * len(ys))
-
- processed_result = None
- cell_mode = "P"
- cell_size = (1, 1)
-
- state.job_count = len(xs) * len(ys) * p.n_iter
-
- def process_cell(x, y, ix, iy):
- nonlocal image_cache, processed_result, cell_mode, cell_size
-
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
-
- processed: Processed = cell(x, y)
-
- try:
- # this dereference will throw an exception if the image was not processed
- # (this happens in cases such as if the user stops the process from the UI)
- processed_image = processed.images[0]
-
- if processed_result is None:
- # Use our first valid processed result as a template container to hold our full results
- processed_result = copy(processed)
- cell_mode = processed_image.mode
- cell_size = processed_image.size
- processed_result.images = [Image.new(cell_mode, cell_size)]
-
- image_cache[ix + iy * len(xs)] = processed_image
- if include_lone_images:
- processed_result.images.append(processed_image)
- processed_result.all_prompts.append(processed.prompt)
- processed_result.all_seeds.append(processed.seed)
- processed_result.infotexts.append(processed.infotexts[0])
- except:
- image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
-
- if swap_axes_processing_order:
- for ix, x in enumerate(xs):
- for iy, y in enumerate(ys):
- process_cell(x, y, ix, iy)
- else:
- for iy, y in enumerate(ys):
- for ix, x in enumerate(xs):
- process_cell(x, y, ix, iy)
-
- if not processed_result:
- print("Unexpected error: draw_xy_grid failed to return even a single processed image")
- return Processed(p, [])
-
- grid = images.image_grid(image_cache, rows=len(ys))
- if draw_legend:
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
-
- processed_result.images[0] = grid
-
- return processed_result
-
-
-class SharedSettingsStackHelper(object):
- def __enter__(self):
- self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
- self.vae = opts.sd_vae
-
- def __exit__(self, exc_type, exc_value, tb):
- opts.data["sd_vae"] = self.vae
- modules.sd_models.reload_model_weights()
- modules.sd_vae.reload_vae_weights()
-
- opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
-
-
-re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
-re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
-
-re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
-re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
-
-
-class Script(scripts.Script):
- def title(self):
- return "X/Y plot"
-
- def ui(self, is_img2img):
- self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
-
- with gr.Row():
- with gr.Column(scale=19):
- with gr.Row():
- x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
- x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
- fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
-
- with gr.Row():
- y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
- y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
- fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
-
- with gr.Row(variant="compact", elem_id="axis_options"):
- draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
- include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
- no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
- swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
-
- def swap_axes(x_type, x_values, y_type, y_values):
- return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values
-
- swap_args = [x_type, x_values, y_type, y_values]
- swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
-
- def fill(x_type):
- axis = self.current_axis_options[x_type]
- return ", ".join(axis.choices()) if axis.choices else gr.update()
-
- fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
- fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
-
- def select_axis(x_type):
- return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
-
- x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
- y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
-
- return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
-
- def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
- if not no_fixed_seeds:
- modules.processing.fix_seed(p)
-
- if not opts.return_grid:
- p.batch_size = 1
-
- def process_axis(opt, vals):
- if opt.label == 'Nothing':
- return [0]
-
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
-
- if opt.type == int:
- valslist_ext = []
-
- for val in valslist:
- m = re_range.fullmatch(val)
- mc = re_range_count.fullmatch(val)
- if m is not None:
- start = int(m.group(1))
- end = int(m.group(2))+1
- step = int(m.group(3)) if m.group(3) is not None else 1
-
- valslist_ext += list(range(start, end, step))
- elif mc is not None:
- start = int(mc.group(1))
- end = int(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
-
- valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
- else:
- valslist_ext.append(val)
-
- valslist = valslist_ext
- elif opt.type == float:
- valslist_ext = []
-
- for val in valslist:
- m = re_range_float.fullmatch(val)
- mc = re_range_count_float.fullmatch(val)
- if m is not None:
- start = float(m.group(1))
- end = float(m.group(2))
- step = float(m.group(3)) if m.group(3) is not None else 1
-
- valslist_ext += np.arange(start, end + step, step).tolist()
- elif mc is not None:
- start = float(mc.group(1))
- end = float(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
-
- valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
- else:
- valslist_ext.append(val)
-
- valslist = valslist_ext
- elif opt.type == str_permutations:
- valslist = list(permutations(valslist))
-
- valslist = [opt.type(x) for x in valslist]
-
- # Confirm options are valid before starting
- if opt.confirm:
- opt.confirm(p, valslist)
-
- return valslist
-
- x_opt = self.current_axis_options[x_type]
- xs = process_axis(x_opt, x_values)
-
- y_opt = self.current_axis_options[y_type]
- ys = process_axis(y_opt, y_values)
-
- def fix_axis_seeds(axis_opt, axis_list):
- if axis_opt.label in ['Seed', 'Var. seed']:
- return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
- else:
- return axis_list
-
- if not no_fixed_seeds:
- xs = fix_axis_seeds(x_opt, xs)
- ys = fix_axis_seeds(y_opt, ys)
-
- if x_opt.label == 'Steps':
- total_steps = sum(xs) * len(ys)
- elif y_opt.label == 'Steps':
- total_steps = sum(ys) * len(xs)
- else:
- total_steps = p.steps * len(xs) * len(ys)
-
- if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
- if x_opt.label == "Hires steps":
- total_steps += sum(xs) * len(ys)
- elif y_opt.label == "Hires steps":
- total_steps += sum(ys) * len(xs)
- elif p.hr_second_pass_steps:
- total_steps += p.hr_second_pass_steps * len(xs) * len(ys)
- else:
- total_steps *= 2
-
- total_steps *= p.n_iter
-
- image_cell_count = p.n_iter * p.batch_size
- cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
- print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})")
- shared.total_tqdm.updateTotal(total_steps)
-
- grid_infotext = [None]
-
- # If one of the axes is very slow to change between (like SD model
- # checkpoint), then make sure it is in the outer iteration of the nested
- # `for` loop.
- swap_axes_processing_order = x_opt.cost > y_opt.cost
-
- def cell(x, y):
- if shared.state.interrupted:
- return Processed(p, [], p.seed, "")
-
- pc = copy(p)
- x_opt.apply(pc, x, xs)
- y_opt.apply(pc, y, ys)
-
- res = process_images(pc)
-
- if grid_infotext[0] is None:
- pc.extra_generation_params = copy(pc.extra_generation_params)
-
- if x_opt.label != 'Nothing':
- pc.extra_generation_params["X Type"] = x_opt.label
- pc.extra_generation_params["X Values"] = x_values
- if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
- pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
-
- if y_opt.label != 'Nothing':
- pc.extra_generation_params["Y Type"] = y_opt.label
- pc.extra_generation_params["Y Values"] = y_values
- if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
- pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
-
- grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
-
- return res
-
- with SharedSettingsStackHelper():
- processed = draw_xy_grid(
- p,
- xs=xs,
- ys=ys,
- x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
- y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
- cell=cell,
- draw_legend=draw_legend,
- include_lone_images=include_lone_images,
- swap_axes_processing_order=swap_axes_processing_order
- )
-
- if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
-
- return processed
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
new file mode 100644
index 00000000..1a452355
--- /dev/null
+++ b/scripts/xyz_grid.py
@@ -0,0 +1,498 @@
+from collections import namedtuple
+from copy import copy
+from itertools import permutations, chain
+import random
+import csv
+from io import StringIO
+from PIL import Image
+import numpy as np
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
+from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
+import modules.sd_samplers
+import modules.sd_models
+import modules.sd_vae
+import glob
+import os
+import re
+
+from modules.ui_components import ToolButton
+
+fill_values_symbol = "\U0001f4d2" # 📒
+
+
+def apply_field(field):
+ def fun(p, x, xs):
+ setattr(p, field, x)
+
+ return fun
+
+
+def apply_prompt(p, x, xs):
+ if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
+ raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
+
+ p.prompt = p.prompt.replace(xs[0], x)
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+
+
+def apply_order(p, x, xs):
+ token_order = []
+
+    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
+def apply_sampler(p, x, xs):
+ sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
+ if sampler_name is None:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+ p.sampler_name = sampler_name
+
+
+def confirm_samplers(p, xs):
+ for x in xs:
+ if x.lower() not in sd_samplers.samplers_map:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+
+def apply_checkpoint(p, x, xs):
+ info = modules.sd_models.get_closet_checkpoint_match(x)
+ if info is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
+def confirm_checkpoints(p, xs):
+ for x in xs:
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+
+
+def apply_clip_skip(p, x, xs):
+ opts.data["CLIP_stop_at_last_layers"] = x
+
+
+def apply_upscale_latent_space(p, x, xs):
+ if x.lower().strip() != '0':
+ opts.data["use_scale_latent_for_hires_fix"] = True
+ else:
+ opts.data["use_scale_latent_for_hires_fix"] = False
+
+
+def find_vae(name: str):
+ if name.lower() in ['auto', 'automatic']:
+ return modules.sd_vae.unspecified
+ if name.lower() == 'none':
+ return None
+ else:
+ choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+ if len(choices) == 0:
+ print(f"No VAE found for {name}; using automatic")
+ return modules.sd_vae.unspecified
+ else:
+ return modules.sd_vae.vae_dict[choices[0]]
+
+
+def apply_vae(p, x, xs):
+ modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
+
+
+def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
+ p.styles = x.split(',')
+
+
+def format_value_add_label(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+
+ return f"{opt.label}: {x}"
+
+
+def format_value(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+ return x
+
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
+def do_nothing(p, x, xs):
+ pass
+
+
+def format_nothing(p, opt, x):
+ return ""
+
+
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
+class AxisOption:
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ self.label = label
+ self.type = type
+ self.apply = apply
+ self.format_value = format_value
+ self.confirm = confirm
+ self.cost = cost
+ self.choices = choices
+
+
+class AxisOptionImg2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = True
+
+class AxisOptionTxt2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = False
+
+
+axis_options = [
+ AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
+ AxisOption("Seed", int, apply_field("seed")),
+ AxisOption("Var. seed", int, apply_field("subseed")),
+ AxisOption("Var. strength", float, apply_field("subseed_strength")),
+ AxisOption("Steps", int, apply_field("steps")),
+ AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale")),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
+ AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
+ AxisOption("Sigma Churn", float, apply_field("s_churn")),
+ AxisOption("Sigma min", float, apply_field("s_tmin")),
+ AxisOption("Sigma max", float, apply_field("s_tmax")),
+ AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Eta", float, apply_field("eta")),
+ AxisOption("Clip skip", int, apply_clip_skip),
+ AxisOption("Denoising", float, apply_field("denoising_strength")),
+ AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
+ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
+ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
+ AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
+]
+
+
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
+
+ # Temporary list of all the images that are generated to be populated into the grid.
+ # Will be filled with empty images for any individual step that fails to process properly
+ image_cache = [None] * (len(xs) * len(ys))
+
+ processed_result = None
+ cell_mode = "P"
+ cell_size = (1, 1)
+
+ state.job_count = len(xs) * len(ys) * p.n_iter
+
+ def process_cell(x, y, ix, iy):
+ nonlocal image_cache, processed_result, cell_mode, cell_size
+
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed: Processed = cell(x, y)
+
+ try:
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+
+ image_cache[ix + iy * len(xs)] = processed_image
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
+ except:
+ image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
+
+ if swap_axes_processing_order:
+ for ix, x in enumerate(xs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, ix, iy)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, ix, iy)
+
+ if not processed_result:
+ print("Unexpected error: draw_xy_grid failed to return even a single processed image")
+ return Processed(p, [])
+
+ grid = images.image_grid(image_cache, rows=len(ys))
+ if draw_legend:
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
+
+ processed_result.images[0] = grid
+
+ return processed_result
+
+
+class SharedSettingsStackHelper(object):
+ def __enter__(self):
+ self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
+ self.vae = opts.sd_vae
+
+ def __exit__(self, exc_type, exc_value, tb):
+ opts.data["sd_vae"] = self.vae
+ modules.sd_models.reload_model_weights()
+ modules.sd_vae.reload_vae_weights()
+
+ opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
+
+
+re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
+re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
+
+re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
+re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "X/Y plot"
+
+ def ui(self, is_img2img):
+ self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
+
+ with gr.Row():
+ with gr.Column(scale=19):
+ with gr.Row():
+ x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
+ x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+ fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
+
+ with gr.Row():
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
+ y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
+
+ with gr.Row(variant="compact", elem_id="axis_options"):
+ draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+ include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
+
+ def swap_axes(x_type, x_values, y_type, y_values):
+ return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values
+
+ swap_args = [x_type, x_values, y_type, y_values]
+ swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
+
+ def fill(x_type):
+ axis = self.current_axis_options[x_type]
+ return ", ".join(axis.choices()) if axis.choices else gr.update()
+
+ fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
+ fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
+
+ def select_axis(x_type):
+ return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
+
+ x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
+ y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
+
+ return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
+
+ def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
+ if not no_fixed_seeds:
+ modules.processing.fix_seed(p)
+
+ if not opts.return_grid:
+ p.batch_size = 1
+
+ def process_axis(opt, vals):
+ if opt.label == 'Nothing':
+ return [0]
+
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
+
+ if opt.type == int:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range.fullmatch(val)
+ mc = re_range_count.fullmatch(val)
+ if m is not None:
+ start = int(m.group(1))
+ end = int(m.group(2))+1
+ step = int(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += list(range(start, end, step))
+ elif mc is not None:
+ start = int(mc.group(1))
+ end = int(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == float:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range_float.fullmatch(val)
+ mc = re_range_count_float.fullmatch(val)
+ if m is not None:
+ start = float(m.group(1))
+ end = float(m.group(2))
+ step = float(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += np.arange(start, end + step, step).tolist()
+ elif mc is not None:
+ start = float(mc.group(1))
+ end = float(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
+
+ valslist = [opt.type(x) for x in valslist]
+
+ # Confirm options are valid before starting
+ if opt.confirm:
+ opt.confirm(p, valslist)
+
+ return valslist
+
+ x_opt = self.current_axis_options[x_type]
+ xs = process_axis(x_opt, x_values)
+
+ y_opt = self.current_axis_options[y_type]
+ ys = process_axis(y_opt, y_values)
+
+ def fix_axis_seeds(axis_opt, axis_list):
+ if axis_opt.label in ['Seed', 'Var. seed']:
+ return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
+ else:
+ return axis_list
+
+ if not no_fixed_seeds:
+ xs = fix_axis_seeds(x_opt, xs)
+ ys = fix_axis_seeds(y_opt, ys)
+
+ if x_opt.label == 'Steps':
+ total_steps = sum(xs) * len(ys)
+ elif y_opt.label == 'Steps':
+ total_steps = sum(ys) * len(xs)
+ else:
+ total_steps = p.steps * len(xs) * len(ys)
+
+ if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
+ if x_opt.label == "Hires steps":
+ total_steps += sum(xs) * len(ys)
+ elif y_opt.label == "Hires steps":
+ total_steps += sum(ys) * len(xs)
+ elif p.hr_second_pass_steps:
+ total_steps += p.hr_second_pass_steps * len(xs) * len(ys)
+ else:
+ total_steps *= 2
+
+ total_steps *= p.n_iter
+
+ image_cell_count = p.n_iter * p.batch_size
+ cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
+ print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})")
+ shared.total_tqdm.updateTotal(total_steps)
+
+ grid_infotext = [None]
+
+ # If one of the axes is very slow to change between (like SD model
+ # checkpoint), then make sure it is in the outer iteration of the nested
+ # `for` loop.
+ swap_axes_processing_order = x_opt.cost > y_opt.cost
+
+ def cell(x, y):
+ if shared.state.interrupted:
+ return Processed(p, [], p.seed, "")
+
+ pc = copy(p)
+ x_opt.apply(pc, x, xs)
+ y_opt.apply(pc, y, ys)
+
+ res = process_images(pc)
+
+ if grid_infotext[0] is None:
+ pc.extra_generation_params = copy(pc.extra_generation_params)
+
+ if x_opt.label != 'Nothing':
+ pc.extra_generation_params["X Type"] = x_opt.label
+ pc.extra_generation_params["X Values"] = x_values
+ if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
+
+ if y_opt.label != 'Nothing':
+ pc.extra_generation_params["Y Type"] = y_opt.label
+ pc.extra_generation_params["Y Values"] = y_values
+ if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
+
+ grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
+
+ return res
+
+ with SharedSettingsStackHelper():
+ processed = draw_xy_grid(
+ p,
+ xs=xs,
+ ys=ys,
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
+ cell=cell,
+ draw_legend=draw_legend,
+ include_lone_images=include_lone_images,
+ swap_axes_processing_order=swap_axes_processing_order
+ )
+
+ if opts.grid_save:
+ images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
+ return processed
--
cgit v1.2.3
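A note on the four regexes defined in the script above: they implement the range syntax for numeric axis values. Below is a minimal sketch (not part of any commit here; expand_int is a hypothetical helper) mirroring the int branch of process_axis; the float branch is analogous but keeps fractional values via np.arange and np.linspace:

    import re
    import numpy as np

    re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
    re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")

    def expand_int(val):
        m = re_range.fullmatch(val)
        if m is not None:  # "a-b (+s)": every s-th integer from a through b
            step = int(m.group(3)) if m.group(3) is not None else 1
            return list(range(int(m.group(1)), int(m.group(2)) + 1, step))
        mc = re_range_count.fullmatch(val)
        if mc is not None:  # "a-b [n]": n integers evenly spaced from a to b
            num = int(mc.group(3)) if mc.group(3) is not None else 1
            return [int(x) for x in np.linspace(int(mc.group(1)), int(mc.group(2)), num)]
        return [int(val)]  # plain single value

    print(expand_int("1-5 (+2)"))  # [1, 3, 5]
    print(expand_int("1-10 [5]"))  # [1, 3, 5, 7, 10]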
From 9fc354e1303453bbd865cbede86da4c3273ed14f Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 24 Jan 2023 02:22:40 -0500
Subject: implements most of xyz grid script
---
scripts/xyz_grid.py | 114 +++++++++++++++++++++++++++++++++++-----------------
1 file changed, 78 insertions(+), 36 deletions(-)
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 1a452355..494e8417 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -205,26 +205,30 @@ axis_options = [
]
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
+ title_texts = [[images.GridAnnotation(z)] for z in z_labels]
# Temporary list of all the images that are generated to be populated into the grid.
# Will be filled with empty images for any individual step that fails to process properly
- image_cache = [None] * (len(xs) * len(ys))
+ image_cache = [None] * (len(xs) * len(ys) * len(zs))
processed_result = None
cell_mode = "P"
cell_size = (1, 1)
- state.job_count = len(xs) * len(ys) * p.n_iter
+ state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
- def process_cell(x, y, ix, iy):
+ def process_cell(x, y, z, ix, iy, iz):
nonlocal image_cache, processed_result, cell_mode, cell_size
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+ def index(ix, iy, iz):
+ return ix + iy*len(xs) + iz*len(xs)*len(ys)
- processed: Processed = cell(x, y)
+ state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
+
+ processed: Processed = cell(x, y, z)
try:
# this dereference will throw an exception if the image was not processed
@@ -238,33 +242,40 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
cell_size = processed_image.size
processed_result.images = [Image.new(cell_mode, cell_size)]
- image_cache[ix + iy * len(xs)] = processed_image
+ image_cache[index(ix, iy, iz)] = processed_image
if include_lone_images:
processed_result.images.append(processed_image)
processed_result.all_prompts.append(processed.prompt)
processed_result.all_seeds.append(processed.seed)
processed_result.infotexts.append(processed.infotexts[0])
except:
- image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
+ image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
if swap_axes_processing_order:
for ix, x in enumerate(xs):
for iy, y in enumerate(ys):
- process_cell(x, y, ix, iy)
+ for iy, y in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
else:
for iy, y in enumerate(ys):
for ix, x in enumerate(xs):
- process_cell(x, y, ix, iy)
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
if not processed_result:
- print("Unexpected error: draw_xy_grid failed to return even a single processed image")
+ print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
return Processed(p, [])
- grid = images.image_grid(image_cache, rows=len(ys))
- if draw_legend:
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
-
- processed_result.images[0] = grid
+ for i, title_text in enumerate(title_texts):
+ start_index = i * len(xs) * len(ys)
+ end_index = start_index + len(xs) * len(ys)
+ grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
+ if draw_legend:
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
+ if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced
+ processed_result.images[0] = grid
+ else:
+ processed_result.images.insert(i, grid)
return processed_result
@@ -291,7 +302,7 @@ re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+
class Script(scripts.Script):
def title(self):
- return "X/Y plot"
+ return "X/Y/Z plot"
def ui(self, is_img2img):
self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
@@ -301,24 +312,35 @@ class Script(scripts.Script):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
- fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
+ fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
- fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False)
+
+ with gr.Row():
+ z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
+ z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
+ fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
- swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
+ swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
+ swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
+ swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
- def swap_axes(x_type, x_values, y_type, y_values):
- return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values
+ def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
+ return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
- swap_args = [x_type, x_values, y_type, y_values]
- swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
+ xy_swap_args = [x_type, x_values, y_type, y_values]
+ swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
+ yz_swap_args = [y_type, y_values, z_type, z_values]
+ swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
+ xz_swap_args = [x_type, x_values, z_type, z_values]
+ swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type):
axis = self.current_axis_options[x_type]
@@ -326,16 +348,18 @@ class Script(scripts.Script):
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
+ fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
def select_axis(x_type):
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
+ z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
- return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
+ return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds]
- def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
+ def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -409,6 +433,9 @@ class Script(scripts.Script):
y_opt = self.current_axis_options[y_type]
ys = process_axis(y_opt, y_values)
+ z_opt = self.current_axis_options[z_type]
+ zs = process_axis(z_opt, z_values)
+
def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label in ['Seed', 'Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
@@ -418,21 +445,26 @@ class Script(scripts.Script):
if not no_fixed_seeds:
xs = fix_axis_seeds(x_opt, xs)
ys = fix_axis_seeds(y_opt, ys)
+ zs = fix_axis_seeds(z_opt, zs)
if x_opt.label == 'Steps':
- total_steps = sum(xs) * len(ys)
+ total_steps = sum(xs) * len(ys) * len(zs)
elif y_opt.label == 'Steps':
- total_steps = sum(ys) * len(xs)
+ total_steps = sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == 'Steps':
+ total_steps = sum(zs) * len(xs) * len(ys)
else:
- total_steps = p.steps * len(xs) * len(ys)
+ total_steps = p.steps * len(xs) * len(ys) * len(zs)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
if x_opt.label == "Hires steps":
- total_steps += sum(xs) * len(ys)
+ total_steps += sum(xs) * len(ys) * len(zs)
elif y_opt.label == "Hires steps":
- total_steps += sum(ys) * len(xs)
+ total_steps += sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == "Hires steps":
+ total_steps += sum(zs) * len(xs) * len(ys)
elif p.hr_second_pass_steps:
- total_steps += p.hr_second_pass_steps * len(xs) * len(ys)
+ total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
else:
total_steps *= 2
@@ -440,7 +472,8 @@ class Script(scripts.Script):
image_cell_count = p.n_iter * p.batch_size
cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
- print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})")
+ plural_s = 's' if len(zs) > 1 else ''
+ print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
shared.total_tqdm.updateTotal(total_steps)
grid_infotext = [None]
@@ -450,13 +483,14 @@ class Script(scripts.Script):
# `for` loop.
swap_axes_processing_order = x_opt.cost > y_opt.cost
- def cell(x, y):
+ def cell(x, y, z):
if shared.state.interrupted:
return Processed(p, [], p.seed, "")
pc = copy(p)
x_opt.apply(pc, x, xs)
y_opt.apply(pc, y, ys)
+ z_opt.apply(pc, z, zs)
res = process_images(pc)
@@ -475,17 +509,25 @@ class Script(scripts.Script):
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
+ if z_opt.label != 'Nothing':
+ pc.extra_generation_params["Z Type"] = z_opt.label
+ pc.extra_generation_params["Z Values"] = z_values
+ if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
+
grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
return res
with SharedSettingsStackHelper():
- processed = draw_xy_grid(
+ processed = draw_xyz_grid(
p,
xs=xs,
ys=ys,
+ zs=zs,
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
+ z_labels=[y_opt.format_value(p, z_opt, z) for z in zs],
cell=cell,
draw_legend=draw_legend,
include_lone_images=include_lone_images,
@@ -493,6 +535,6 @@ class Script(scripts.Script):
)
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
return processed
--
cgit v1.2.3
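The index() helper added in this commit flattens the three loop indices into a single row-major offset (x fastest, then y, then z), which makes each z value a contiguous block of the image cache and lets the grid-assembly loop slice out one sub-grid per z. Note also that this commit's swap branch reuses iy/y to iterate zs and references the undefined z and iz; the "swaps xyz axes internally" commit further down rewrites those loops. A self-contained sketch of the layout (the sizes are hypothetical):

    def index(ix, iy, iz, nx, ny):
        return ix + iy * nx + iz * nx * ny  # row-major: x fastest, z slowest

    nx, ny, nz = 3, 2, 2
    flat = [index(ix, iy, iz, nx, ny)
            for iz in range(nz) for iy in range(ny) for ix in range(nx)]
    print(flat)  # [0, 1, 2, ..., 11]: each z slice is nx*ny consecutive cells
    # so the cells for z slice i live at cache[i*nx*ny : (i+1)*nx*ny]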
From e46bfa5a9e9b489ae925a9c23880e34fe8d9fffa Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 24 Jan 2023 02:24:32 -0500
Subject: handling sub grids and merging into one
---
modules/images.py | 2 +-
scripts/xyz_grid.py | 29 ++++++++++++++++++-----------
2 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/modules/images.py b/modules/images.py
index 3b1c5f34..0bc3d524 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -195,7 +195,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
ver_texts]
- pad_top = max(hor_text_heights) + line_spacing * 2
+ pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
result.paste(im, (pad_left, pad_top))
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 494e8417..a16653da 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -205,9 +205,9 @@ axis_options = [
]
-def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
- ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order):
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
# Temporary list of all the images that are generated to be populated into the grid.
@@ -266,16 +266,21 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
return Processed(p, [])
- for i, title_text in enumerate(title_texts):
+ grids = [None] * len(zs)
+ for i in range(len(zs)):
start_index = i * len(xs) * len(ys)
end_index = start_index + len(xs) * len(ys)
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
if draw_legend:
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
- if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced
- processed_result.images[0] = grid
- else:
- processed_result.images.insert(i, grid)
+
+ grids[i] = grid
+ if include_sub_grids and len(zs) > 1:
+ processed_result.images.insert(i+1, grid)
+
+ original_grid_size = grids[0].size
+ grids = images.image_grid(grids, rows=1)
+ processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]])
return processed_result
@@ -326,7 +331,8 @@ class Script(scripts.Script):
with gr.Row(variant="compact", elem_id="axis_options"):
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
- include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
@@ -357,9 +363,9 @@ class Script(scripts.Script):
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
- return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds]
+ return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds]
- def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds):
+ def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -527,10 +533,11 @@ class Script(scripts.Script):
zs=zs,
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
- z_labels=[y_opt.format_value(p, z_opt, z) for z in zs],
+ z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
cell=cell,
draw_legend=draw_legend,
include_lone_images=include_lone_images,
+ include_sub_grids=include_sub_grids,
swap_axes_processing_order=swap_axes_processing_order
)
--
cgit v1.2.3
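After this commit the first returned image is always the merged row of sub-grids, annotated with the z labels, and the individual sub-grids follow it only when 'Include Sub Grids' is checked. A sketch of the resulting ordering of processed_result.images, with placeholder strings standing in for PIL images:

    # Hypothetical case: len(zs) == 3, include_sub_grids enabled.
    images_list = ["placeholder"]              # processed_result.images[0]
    grids = ["z0-grid", "z1-grid", "z2-grid"]  # one annotated grid per z value
    for i, grid in enumerate(grids):
        images_list.insert(i + 1, grid)        # sub-grids slot in after index 0
    images_list[0] = "merged-grid"             # single row of all sub-grids
    print(images_list)  # ['merged-grid', 'z0-grid', 'z1-grid', 'z2-grid']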
From ec8774729e17f87a8ffa5a3c5328d12834cbb02a Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 24 Jan 2023 02:53:35 -0500
Subject: swaps xyz axes internally if one costs more
---
scripts/xyz_grid.py | 64 +++++++++++++++++++++++++++++++++++++++++++----------
1 file changed, 52 insertions(+), 12 deletions(-)
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index a16653da..828c2d12 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -205,7 +205,7 @@ axis_options = [
]
-def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order):
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
@@ -224,7 +224,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
nonlocal image_cache, processed_result, cell_mode, cell_size
def index(ix, iy, iz):
- return ix + iy*len(xs) + iz*len(xs)*len(ys)
+ return ix + iy * len(xs) + iz * len(xs) * len(ys)
state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
@@ -251,16 +251,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
except:
image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
- if swap_axes_processing_order:
+ if first_axes_processed == 'x':
for ix, x in enumerate(xs):
- for iy, y in enumerate(ys):
- for iy, y in enumerate(zs):
- process_cell(x, y, z, ix, iy, iz)
- else:
+ if second_axes_processed == 'y':
+ for iy, y in enumerate(ys):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iz, z in enumerate(zs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'y':
for iy, y in enumerate(ys):
- for ix, x in enumerate(xs):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
for iz, z in enumerate(zs):
- process_cell(x, y, z, ix, iy, iz)
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'z':
+ for iz, z in enumerate(zs):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
if not processed_result:
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
@@ -322,7 +342,7 @@ class Script(scripts.Script):
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
- fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False)
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
@@ -487,7 +507,26 @@ class Script(scripts.Script):
# If one of the axes is very slow to change between (like SD model
# checkpoint), then make sure it is in the outer iteration of the nested
# `for` loop.
- swap_axes_processing_order = x_opt.cost > y_opt.cost
+ first_axes_processed = 'x'
+ second_axes_processed = 'y'
+ if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
+ first_axes_processed = 'x'
+ if y_opt.cost > z_opt.cost:
+ second_axes_processed = 'y'
+ else:
+ second_axes_processed = 'z'
+ elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
+ first_axes_processed = 'y'
+ if x_opt.cost > z_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'z'
+ elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
+ first_axes_processed = 'z'
+ if x_opt.cost > y_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'y'
def cell(x, y, z):
if shared.state.interrupted:
@@ -538,7 +577,8 @@ class Script(scripts.Script):
draw_legend=draw_legend,
include_lone_images=include_lone_images,
include_sub_grids=include_sub_grids,
- swap_axes_processing_order=swap_axes_processing_order
+ first_axes_processed=first_axes_processed,
+ second_axes_processed=second_axes_processed
)
if opts.grid_save:
--
cgit v1.2.3
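The if/elif chain above keeps the most expensive axis in the outermost loop and the second most expensive in the middle one, so costly switches such as checkpoint reloads (cost 1.0) happen as rarely as possible. A compact equivalent sketch (not the committed code; ties can resolve differently than the chain's defaults, which is harmless because equal-cost axes may be iterated in either order):

    costs = {'x': 1.0, 'y': 0.0, 'z': 0.7}  # e.g. Checkpoint on X, VAE on Z
    order = sorted(costs, key=costs.get, reverse=True)
    first_axes_processed, second_axes_processed = order[0], order[1]
    print(first_axes_processed, second_axes_processed)  # x z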
From dac45299dd57c6cb240424b93fd28a085605bd90 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 20:22:19 +0300
Subject: make git commands not fail for extensions when you have spaces in
webui directory
---
launch.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/launch.py b/launch.py
index 82094fa0..6d523a34 100644
--- a/launch.py
+++ b/launch.py
@@ -108,18 +108,18 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
- current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
if current_hash == commithash:
return
- run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
if commithash is not None:
- run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
def version_check(commit):
--
cgit v1.2.3
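The fix simply quotes {dir} in each command so a webui path containing spaces reaches git as one argument. (The final checkout's error message was already missing its f prefix before this change, so its {name} and {commithash} placeholders print literally; the patch leaves that as it was.) An alternative sketch, not the committed approach: passing an argument list with the shell disabled sidesteps quoting entirely:

    import subprocess

    git = "git"  # assumed path to the git binary
    repo_dir = "C:/program files/webui/extensions/example"  # hypothetical path
    # A list plus shell=False means spaces can never split repo_dir apart.
    result = subprocess.run([git, "-C", repo_dir, "rev-parse", "HEAD"],
                            capture_output=True, text=True)
    current_hash = result.stdout.strip()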
From 28189985e6f56dc725938a3f0e4d2462dad74bc5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 20:24:27 +0300
Subject: remove fairscale requirement, add fake fairscale to make BLIP not
complain about it
---
modules/interrogate.py | 11 +++++++++--
requirements.txt | 1 -
requirements_versions.txt | 1 -
3 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 236e6983..9f063197 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -82,9 +82,16 @@ class InterrogateModels:
return self.loaded_categories
+ def create_fake_fairscale(self):
+ class FakeFairscale:
+ def checkpoint_wrapper(self):
+ pass
+
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
+
def load_blip_model(self):
- with paths.Prioritize("BLIP"):
- import models.blip
+ create_fake_fairscale()
+ import models.blip
files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "BLIP"),
diff --git a/requirements.txt b/requirements.txt
index ef5e3472..a4be1ec3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
blendmodes
accelerate
basicsr
-fairscale==0.4.4
fonts
font-roboto
gfpgan
diff --git a/requirements_versions.txt b/requirements_versions.txt
index f97ad765..135908be 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -14,7 +14,6 @@ scikit-image==0.19.2
fonts
font-roboto
timm==0.6.7
-fairscale==0.4.9
piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
--
cgit v1.2.3
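The stub works because Python's import machinery checks sys.modules for the full dotted name before importing anything, and the from-import then just reads an attribute off whatever object is stored there; that is also why registering the FakeFairscale class itself, rather than a real module, is enough. A sketch of the same idea with an explicit module object; the no-op wrapper is an assumption about how little of fairscale BLIP actually exercises:

    import sys
    import types

    name = "fairscale.nn.checkpoint.checkpoint_activations"
    stub = types.ModuleType(name)
    stub.checkpoint_wrapper = lambda module, *args, **kwargs: module  # no-op
    sys.modules[name] = stub

    # Succeeds without fairscale installed: the stub is found in sys.modules.
    from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper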
From 5228ec8bdada50a8d614573e980193ca89192361 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 20:30:43 +0300
Subject: remove fairscale requirement, add fake fairscale to make BLIP not
complain about it mk2
---
modules/interrogate.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9f063197..c72ff694 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -90,7 +90,7 @@ class InterrogateModels:
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
def load_blip_model(self):
- create_fake_fairscale()
+ self.create_fake_fairscale()
import models.blip
files = modelloader.load_models(
--
cgit v1.2.3
From 93fad28a979727f9b1331dbdc447598824057cdc Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 21:13:05 +0300
Subject: print progress when installing torch; add PIP_INSTALLER_LOCATION env
 var to install pip if it's not installed; remove accidental call to accelerate
 when venv is disabled; add another env var to skip venv - SKIP_VENV
---
launch.py | 20 +++++++++++++++++---
webui.bat | 8 +++++---
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/launch.py b/launch.py
index 6d523a34..f578c1c7 100644
--- a/launch.py
+++ b/launch.py
@@ -48,10 +48,19 @@ def extract_opt(args, name):
return args, is_present, opt
-def run(command, desc=None, errdesc=None, custom_env=None):
+def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
+ if live:
+ result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
+ if result.returncode != 0:
+ raise RuntimeError(f"""{errdesc or 'Error running command'}.
+Command: {command}
+Error code: {result.returncode}""")
+
+ return ""
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
@@ -179,6 +188,8 @@ def run_extensions_installers(settings_file):
def prepare_environment():
global skip_install
+ pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None)
+
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -219,9 +230,12 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")
-
+
+ if pip_installer_location is not None and not is_installed("pip"):
+ run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip")
+
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
diff --git a/webui.bat b/webui.bat
index 3165b94d..0d6865c9 100644
--- a/webui.bat
+++ b/webui.bat
@@ -3,6 +3,7 @@
if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+
set ERROR_REPORTING=FALSE
mkdir tmp 2>NUL
@@ -14,6 +15,7 @@ goto :show_stdout_stderr
:start_venv
if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
@@ -28,13 +30,13 @@ goto :show_stdout_stderr
:activate_venv
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
-if [%ACCELERATE%] == ["True"] goto :accelerate
-goto :launch
:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
:accelerate
-echo "Checking for accelerate"
+echo Checking for accelerate
set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
--
cgit v1.2.3
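The live=True path matters for the torch install: a captured subprocess buffers pip's output until the child exits, so a multi-gigabyte download looks like a hang. A minimal self-contained sketch of the split (the command is illustrative only):

    import os
    import subprocess

    def run(command, live=False):
        if live:
            # Child inherits stdout/stderr: progress bars render as they happen.
            return subprocess.run(command, shell=True, env=os.environ)
        # Captured: nothing appears on the console until the command finishes.
        return subprocess.run(command, shell=True, env=os.environ,
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    run("echo streaming output", live=True)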
From bef193189500884c2b20605290ac8bef8251a788 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 23:50:04 +0300
Subject: add fastapi to requirements
---
requirements_versions.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 135908be..1c328d44 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -4,6 +4,7 @@ accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.16.2
+fastapi==0.82.0
numpy==1.23.3
Pillow==9.4.0
realesrgan==0.3.0
--
cgit v1.2.3
From 48a15821de768fea76e66f26df83df3fddf18f4b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 00:49:16 +0300
Subject: remove the pip install stuff because it does not work as I hoped it
 would
---
launch.py | 5 -----
requirements_versions.txt | 1 -
webui.bat | 13 +++++++++++--
3 files changed, 11 insertions(+), 8 deletions(-)
diff --git a/launch.py b/launch.py
index f578c1c7..9d6f4a8c 100644
--- a/launch.py
+++ b/launch.py
@@ -188,8 +188,6 @@ def run_extensions_installers(settings_file):
def prepare_environment():
global skip_install
- pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None)
-
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -231,9 +229,6 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")
- if pip_installer_location is not None and not is_installed("pip"):
- run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip")
-
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 1c328d44..135908be 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -4,7 +4,6 @@ accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.16.2
-fastapi==0.82.0
numpy==1.23.3
Pillow==9.4.0
realesrgan==0.3.0
diff --git a/webui.bat b/webui.bat
index 0d6865c9..209d972b 100644
--- a/webui.bat
+++ b/webui.bat
@@ -9,10 +9,19 @@ set ERROR_REPORTING=FALSE
mkdir tmp 2>NUL
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
+if %ERRORLEVEL% == 0 goto :check_pip
echo Couldn't launch python
goto :show_stdout_stderr
+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't install pip
+goto :show_stdout_stderr
+
:start_venv
if ["%VENV_DIR%"] == ["-"] goto :skip_venv
if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
@@ -46,7 +55,7 @@ pause
exit /b
:accelerate_launch
-echo "Accelerating"
+echo Accelerating
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
pause
exit /b
--
cgit v1.2.3
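
The :check_pip block above probes whether pip responds before setting up the venv, and falls back to an installer script when PIP_INSTALLER_LOCATION is set. A rough Python equivalent of the same flow, assuming the variable points at a get-pip.py style script:

    import os
    import subprocess
    import sys

    def ensure_pip(python=sys.executable):
        # Mirrors the webui.bat flow: probe pip first, then try the optional installer.
        if subprocess.run([python, "-mpip", "--help"], capture_output=True).returncode == 0:
            return True
        installer = os.environ.get("PIP_INSTALLER_LOCATION", "")
        if not installer:
            return False
        return subprocess.run([python, installer], capture_output=True).returncode == 0
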
From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Tue, 24 Jan 2023 23:51:45 -0500
Subject: Add option for float32 sampling with float16 UNet
This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, and some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS), so in sd_models.py model.depth_model is detached before model.half() and restored afterwards.
---
README.md | 1 +
modules/deepbooru_model.py | 4 +++-
modules/devices.py | 2 ++
modules/processing.py | 15 ++++++++-------
modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++
modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++
modules/sd_models.py | 10 ++++++++++
modules/shared.py | 1 +
8 files changed, 82 insertions(+), 8 deletions(-)
create mode 100644 modules/sd_hijack_utils.py
diff --git a/README.md b/README.md
index 9c0cd1ef..a5611671 100644
--- a/README.md
+++ b/README.md
@@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- (You)
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
index edd40c81..83d2ff09 100644
--- a/modules/deepbooru_model.py
+++ b/modules/deepbooru_model.py
@@ -2,6 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
+from modules import devices
+
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
- t_360 = self.n_Conv_0(t_359_padded)
+ t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)
diff --git a/modules/devices.py b/modules/devices.py
index 524ec7af..0981ef80 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -79,6 +79,8 @@ cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
def randn(seed, shape):
diff --git a/modules/processing.py b/modules/processing.py
index bc541e2f..2d186ba0 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -172,7 +172,8 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+ conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -203,7 +204,7 @@ class StableDiffusionProcessing:
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+ conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@@ -211,7 +212,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -225,10 +226,10 @@ class StableDiffusionProcessing:
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
- return self.depth2img_image_conditioning(source_image)
+ return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- with devices.autocast():
+ with devices.autocast(disable=devices.unet_needs_upcast):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
@@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(shared.device)
+ image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 18daf8c1..88c94e54 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -1,4 +1,8 @@
import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
class TorchHijackForUnet:
@@ -28,3 +32,28 @@ class TorchHijackForUnet:
th = TorchHijackForUnet()
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ with devices.autocast():
+ return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
+class GELUHijack(torch.nn.GELU, torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ torch.nn.GELU.__init__(self, *args, **kwargs)
+ def forward(self, x):
+ if devices.unet_needs_upcast:
+ return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
+ else:
+ return torch.nn.GELU.forward(self, x)
+
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+if version.parse(torch.__version__) <= version.parse("1.13.1"):
+ CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+ CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
new file mode 100644
index 00000000..f81b169a
--- /dev/null
+++ b/modules/sd_hijack_utils.py
@@ -0,0 +1,28 @@
+import importlib
+
+class CondFunc:
+ def __new__(cls, orig_func, sub_func, cond_func):
+ self = super(CondFunc, cls).__new__(cls)
+ if isinstance(orig_func, str):
+ func_path = orig_func.split('.')
+ for i in range(len(func_path)-2, -1, -1):
+ try:
+ resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+ break
+ except ImportError:
+ pass
+ for attr_name in func_path[i:-1]:
+ resolved_obj = getattr(resolved_obj, attr_name)
+ orig_func = getattr(resolved_obj, func_path[-1])
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ self.__init__(orig_func, sub_func, cond_func)
+ return lambda *args, **kwargs: self(*args, **kwargs)
+ def __init__(self, orig_func, sub_func, cond_func):
+ self.__orig_func = orig_func
+ self.__sub_func = sub_func
+ self.__cond_func = cond_func
+ def __call__(self, *args, **kwargs):
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
+ else:
+ return self.__orig_func(*args, **kwargs)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 12083848..7c98991a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
model.half()
model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -372,6 +380,8 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer()
diff --git a/modules/shared.py b/modules/shared.py
index 5f713bee..4ce1209b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
--
cgit v1.2.3
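
The CondFunc helper introduced in sd_hijack_utils.py resolves a dotted import path, replaces the attribute with a wrapper, and then routes each call to sub_func(orig_func, ...) only when cond_func(orig_func, ...) is truthy, falling through to the original otherwise. A minimal sketch of that dispatch rule, leaving out the import-path resolution:

    # Same dispatch as CondFunc.__call__, on plain functions.
    def cond_func(orig, sub, cond):
        def wrapped(*args, **kwargs):
            if cond(orig, *args, **kwargs):
                return sub(orig, *args, **kwargs)
            return orig(*args, **kwargs)
        return wrapped

    double = lambda x: x * 2
    patched = cond_func(double,
                        lambda orig, x: orig(x) + 1,  # substitute behaviour
                        lambda orig, x: x > 10)       # gate, like unet_needs_upcast above
    assert patched(5) == 10    # gate false: original function runs
    assert patched(20) == 41   # gate true: substitute runs

This is why all of the monkey patches above stay inert until --upcast-sampling actually sets devices.unet_needs_upcast.
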
From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Wed, 25 Jan 2023 00:23:10 -0500
Subject: Add UI setting for upcasting attention to float32
Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers.
In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also.
---
modules/devices.py | 6 +-
modules/processing.py | 2 +-
modules/sd_hijack_optimizations.py | 159 +++++++++++++++++++++++--------------
modules/shared.py | 1 +
modules/sub_quadratic_attention.py | 4 +-
5 files changed, 108 insertions(+), 64 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index 0981ef80..6b36622c 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -108,6 +108,10 @@ def autocast(disable=False):
return torch.autocast("cuda")
+def without_autocast(disable=False):
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
class NansException(Exception):
pass
@@ -125,7 +129,7 @@ def test_for_nans(x, where):
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
- message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
diff --git a/modules/processing.py b/modules/processing.py
index 2d186ba0..a850082d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- with devices.autocast(disable=devices.unet_needs_upcast):
+ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 74452709..c02d954c 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors
+from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
- s2 = s1.softmax(dim=-1)
- del s1
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ del q, k, v
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- k_in *= self.scale
-
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- mem_free_total = get_available_vram()
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+ dtype = q_in.dtype
+ if shared.opts.upcast_attn:
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k_in = k_in * self.scale
+
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ mem_free_total = get_available_vram()
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k) * self.scale
+ k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r = einsum_op(q, k, v)
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k = k * self.scale
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ x = x.to(dtype)
+
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
- return efficient_dot_product_attention(
- q,
- k,
- v,
- query_chunk_size=q_chunk_size,
- kv_chunk_size=kv_chunk_size,
- kv_chunk_size_min = kv_chunk_size_min,
- use_checkpoint=use_checkpoint,
- )
+ with devices.without_autocast(disable=q.dtype == v.dtype):
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
def get_xformers_flash_attention_op(q, k, v):
@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
diff --git a/modules/shared.py b/modules/shared.py
index 4ce1209b..6a0b96cb 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 55052815..05595323 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -67,7 +67,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
- exp_values = torch.bmm(exp_weights, value)
+ exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
- hidden_states_slice = torch.bmm(attn_probs, value)
+ hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice
--
cgit v1.2.3
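
Every attention variant patched above follows the same recipe: remember the incoming dtype, upcast q and k (and, except on MPS, v) to float32, run the matmul/softmax with autocast disabled, then cast the result back. A self-contained sketch of that recipe in plain PyTorch:

    import torch

    def attention_upcast(q, k, v, scale, upcast=True):
        # softmax(q @ k^T * scale) @ v computed in float32 for numerical
        # stability, returned in the caller's original dtype. The patch keeps
        # v in half precision on MPS devices; that special case is omitted here.
        dtype = q.dtype
        if upcast:
            q, k, v = q.float(), k.float(), v.float()
        attn = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1)
        return torch.bmm(attn, v).to(dtype)

    q = k = v = torch.randn(2, 16, 8, dtype=torch.float16)
    assert attention_upcast(q, k, v, scale=8 ** -0.5).dtype == torch.float16
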
From 1bfec873fa13d803f3d4ac2a12bf6983838233fe Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 11:29:46 +0300
Subject: add an experimental option to apply loras to outputs rather than
inputs
---
extensions-builtin/Lora/lora.py | 5 ++++-
extensions-builtin/Lora/scripts/lora_script.py | 7 ++++++-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 60b9eb64..544b228d 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -3,7 +3,7 @@ import torch
import lora
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
@@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
--
cgit v1.2.3
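
Both branches of lora_forward above add the same low-rank update; the experimental option only changes whether the update is computed from the layer's input or from its output. A small sketch of that update, assuming Linear down/up projections:

    import torch
    import torch.nn as nn

    def lora_update(x, down: nn.Linear, up: nn.Linear, multiplier=1.0, alpha=None):
        # up(down(x)) scaled by multiplier * (alpha / rank), as in lora_forward;
        # rank is the input width of the up projection (up.weight.shape[1]).
        rank = up.weight.shape[1]
        scale = multiplier * (alpha / rank if alpha else 1.0)
        return up(down(x)) * scale

    down, up = nn.Linear(64, 8, bias=False), nn.Linear(8, 64, bias=False)
    res = torch.randn(1, 64)
    res = res + lora_update(res, down, up, multiplier=0.8, alpha=8)  # "outputs" mode
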
From ee0a0da3244123cb6d2ba4097a54a1e9caccb687 Mon Sep 17 00:00:00 2001
From: Kyle
Date: Wed, 25 Jan 2023 08:53:23 -0500
Subject: Add instruct-pix2pix hijack
Allows loading instruct-pix2pix models via the same method as inpainting models, in sd_models.py and sd_hijack_ip2p.py.
Adds ddpm_edit.py, which is necessary for instruct-pix2pix.
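
The get_input override in ddpm_edit.py below trains with classifier-free guidance by randomly dropping conditioning: with uncond=0.05, only the text conditioning is nulled for 5% of the batch, only the image conditioning for 5%, and both for 5%. A distilled sketch of the two masks it builds:

    import torch

    def cfg_dropout_masks(batch_size, uncond=0.05):
        # r in [0, uncond): drop text only; [uncond, 2*uncond): drop both;
        # [2*uncond, 3*uncond): drop image only; otherwise keep everything.
        r = torch.rand(batch_size)
        drop_prompt = r < 2 * uncond                      # null out c_crossattn
        keep_input = ~((r >= uncond) & (r < 3 * uncond))  # zero c_concat where False
        return drop_prompt, keep_input
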
---
modules/models/diffusion/ddpm_edit.py | 1459 +++++++++++++++++++++++++++++++++
modules/sd_hijack_ip2p.py | 13 +
modules/sd_models.py | 12 +-
3 files changed, 1483 insertions(+), 1 deletion(-)
create mode 100644 modules/models/diffusion/ddpm_edit.py
create mode 100644 modules/sd_hijack_ip2p.py
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
new file mode 100644
index 00000000..f3d49c44
--- /dev/null
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -0,0 +1,1459 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ load_ema=True,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema and load_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ # If initializing from an EMA-only checkpoint, create the EMA model after loading.
+ if self.use_ema and not load_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+
+ # Our model adds additional channels to the first layer to condition on an input image.
+ # For the first layer, copy existing channel weights and initialize new channel weights to zero.
+ input_keys = [
+ "model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight",
+ ]
+
+ self_sd = self.state_dict()
+ for input_key in input_keys:
+ if input_key not in sd or input_key not in self_sd:
+ continue
+
+ input_weight = self_sd[input_key]
+
+ if input_weight.size() != sd[input_key].size():
+ print(f"Manual init: {input_key}")
+ input_weight.zero_()
+ input_weight[:, :4, :, :].copy_(sd[input_key])
+ ignore_keys.append(input_key)
+
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ else:
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ return batch[k]
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ load_ema=True,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ if self.use_ema and not load_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ with min distance = 0 at the border and max dist = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ cond_key = cond_key or self.cond_stage_key
+ xc = super().get_input(batch, cond_key)
+ if bs is not None:
+ xc["c_crossattn"] = xc["c_crossattn"][:bs]
+ xc["c_concat"] = xc["c_concat"][:bs]
+ cond = {}
+
+ # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
+ random = torch.rand(x.size(0), device=x.device)
+ prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
+ input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
+
+ null_prompt = self.get_learned_conditioning([""])
+ cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
+ cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
+
+ out = [z, cond]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+ # get top-left positions of patches, conforming to the bbox tokenizer; therefore we
+ # need to rescale the tl patch coordinates to be in between (0, 1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo: can't deal with multiple model outputs; check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+
+ @torch.no_grad()
+ def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
+ plot_diffusion_rows=False, **kwargs):
+
+ use_ddim = False
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N, uncond=0)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reals"] = xc["c_concat"]
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
new file mode 100644
index 00000000..635f015f
--- /dev/null
+++ b/modules/sd_hijack_ip2p.py
@@ -0,0 +1,13 @@
+import collections
+import os.path
+import sys
+import gc
+import time
+
+def should_hijack_ip2p(checkpoint_info):
+ from modules import sd_models
+
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
+ return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 12083848..cddc2343 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
+from modules.sd_hijack_ip2p import should_hijack_ip2p
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -365,6 +366,15 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
+ if should_hijack_ip2p(checkpoint_info):
+ sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.first_stage_key = "edited"
+ sd_config.model.params.cond_stage_key = "edit"
+ sd_config.model.params.image_size = 16
+ sd_config.model.params.unet_config.params.in_channels = 8
+ sd_config.model.params.unet_config.params.out_channels = 4
+
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
- if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
--
cgit v1.2.3
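A note on the classifier-free guidance dropout in the get_input() hunk above: with the default uncond=0.05, the two masks partition each batch into three disjoint ~5% slices (text dropped, both dropped, image dropped). A minimal sketch of how the intervals fall out, assuming only torch:

    import torch

    uncond = 0.05
    random = torch.rand(100_000)
    drop_text = random < 2 * uncond                          # random in [0.00, 0.10)
    drop_image = (random >= uncond) & (random < 3 * uncond)  # random in [0.05, 0.15)
    print((drop_text & ~drop_image).float().mean())  # ~0.05: only text conditioning dropped
    print((drop_text & drop_image).float().mean())   # ~0.05: both dropped
    print((~drop_text & drop_image).float().mean())  # ~0.05: only image conditioning dropped

In the patch, drop_text becomes prompt_mask (the null prompt is substituted where it is true) and drop_image becomes the zeroing input_mask over the encoded image conditioning.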
From bd9b55ee908c43fb1b654b3a3a1320545023ce1c Mon Sep 17 00:00:00 2001
From: Kyle
Date: Wed, 25 Jan 2023 09:41:41 -0500
Subject: Update requirements transformers==4.25.1
Update requirement for transformers to version 4.25.1 to allow instruct-pix2pix demo code to work
---
requirements.txt | 2 +-
requirements_versions.txt | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index a4be1ec3..6d53f089 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,7 @@ pytorch_lightning==1.7.7
realesrgan
scikit-image>=0.19
timm==0.4.12
-transformers==4.19.2
+transformers==4.25.1
torch
einops
jsonmerge
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 135908be..eaa08806 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,5 +1,5 @@
blendmodes==2022
-transformers==4.19.2
+transformers==4.25.1
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
--
cgit v1.2.3
From 57c1baa774d07060af0abbd2974c5f36c8cb63ac Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 18:56:23 +0300
Subject: change the code for the live preview fix on OSX to be a bit more obvious
---
modules/processing.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index 3bd590ba..57c3db1b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -568,8 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
- if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1:
- # preload approx nn model before sampling for a more deterministic result
+ # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
+ if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
if not p.disable_extra_networks:
--
cgit v1.2.3
From 635499e8329dfd8c4c5ccca180881867f34a9f36 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 19:42:26 +0300
Subject: add pix2pix credits
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index a5611671..c6bd6f27 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
-- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- (You)
--
cgit v1.2.3
From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Wed, 25 Jan 2023 08:48:40 -0800
Subject: allow symlinks in the textual inversion embeddings folder
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4e90f690..6cf00e65 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -194,7 +194,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path):
+ for root, dirs, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
--
cgit v1.2.3
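Background for the one-line change above: os.walk() skips directories that are symbolic links unless followlinks=True is passed, so linked embedding folders were silently ignored. A small sketch with a hypothetical layout (the directory names here are illustrative):

    import os

    # embeddings/
    #   real_dir/a.pt
    #   linked_dir -> /somewhere/else   (contains b.pt)
    for root, dirs, fns in os.walk("embeddings", followlinks=True):
        for fn in fns:
            print(os.path.join(root, fn))  # with followlinks=True, b.pt is listed too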
From 789d47f832a5c921dbbdd0a657dff9bca7f78d94 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 19:55:31 +0300
Subject: make clicking the extra networks button one more time close the extra
networks UI
---
modules/ui_extra_networks.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8b4f97f8..c6ff889a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -117,8 +117,13 @@ def create_ui(container, button, tabname):
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
- button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
- button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+ def toggle_visibility(is_visible):
+ is_visible = not is_visible
+ return is_visible, gr.update(visible=is_visible)
+
+ state_visible = gr.State(value=False)
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+ button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []
--
cgit v1.2.3
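The change above replaces two stateless lambdas with a per-session gr.State that tracks whether the panel is open, so the same button can both open and close it. A self-contained sketch of the pattern, assuming Gradio 3.x (the component names are illustrative, not the webui's):

    import gradio as gr

    def toggle_visibility(is_visible):
        is_visible = not is_visible
        return is_visible, gr.update(visible=is_visible)

    with gr.Blocks() as demo:
        state_visible = gr.State(value=False)
        button = gr.Button("Extra networks")
        with gr.Column(visible=False) as container:
            gr.Markdown("extra networks cards go here")
        button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])

    demo.launch()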
From e425b9812b067073eb6edfafac689735f5391b45 Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:07:48 +0500
Subject: Added Python version check
---
launch.py | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/launch.py b/launch.py
index 9d6f4a8c..86b4a32b 100644
--- a/launch.py
+++ b/launch.py
@@ -17,6 +17,17 @@ stored_commit_hash = None
skip_install = False
+def check_python_version():
+ version = sys.version_info
+ version_range = None
+ if os.name == "nt":
+ version_range = range(7, 11)
+ else:
+ version_range = range(7, 12)
+
+ assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/"
+
+
def commit_hash():
global stored_commit_hash
@@ -321,5 +332,6 @@ def start():
if __name__ == "__main__":
+ check_python_version()
prepare_environment()
start()
--
cgit v1.2.3
From 15e89ef0f6f22f823c19592a401b9e4ee477258c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 20:11:01 +0300
Subject: fix for unet hijack breaking the train tab
---
modules/sd_hijack_unet.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 88c94e54..a6ee577c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
--
cgit v1.2.3
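Context for the isinstance guard above: at inference time cond arrives as a dict like {'c_crossattn': [...], 'c_concat': [...]}, but some code paths (presumably the train tab that this commit fixes) pass a plain tensor, so iterating cond.keys() unconditionally raised. A minimal sketch of the guarded cast, with torch.float16 standing in for devices.dtype_unet:

    import torch

    dtype_unet = torch.float16  # stand-in for devices.dtype_unet

    def cast_cond(cond):
        if isinstance(cond, dict):
            for y in cond.keys():
                cond[y] = [x.to(dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
        return cond

    cast_cond({'c_crossattn': [torch.randn(1, 77, 768)]})  # dict: tensors cast to fp16
    cast_cond(torch.randn(1, 77, 768))                     # plain tensor: passed through untouched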
From 57096823fadbc18b33d9b89d2d3a02d5ebba29f4 Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:33:35 +0500
Subject: Remove a stacktrace from an assertion to not scare people
---
launch.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/launch.py b/launch.py
index 86b4a32b..cf747e72 100644
--- a/launch.py
+++ b/launch.py
@@ -25,7 +25,10 @@ def check_python_version():
else:
version_range = range(7, 12)
- assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/"
+ try:
+ assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/"
+ except AssertionError as e:
+ print(e)
def commit_hash():
--
cgit v1.2.3
From 2de99d62dd80123bf2d7dcbb2c4970fad5d92d42 Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:38:28 +0500
Subject: some clarification
---
launch.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/launch.py b/launch.py
index cf747e72..4608bc81 100644
--- a/launch.py
+++ b/launch.py
@@ -26,9 +26,10 @@ def check_python_version():
version_range = range(7, 12)
try:
- assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/"
+ assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first."
except AssertionError as e:
print(e)
+ sys.exit(-1)
def commit_hash():
--
cgit v1.2.3
From 0cc5f380d5a21625413554a6a64b97172b36d64a Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:41:51 +0500
Subject: even more clarifications(?) i have no idea what commit message should
be
---
launch.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/launch.py b/launch.py
index 4608bc81..801e0371 100644
--- a/launch.py
+++ b/launch.py
@@ -26,7 +26,7 @@ def check_python_version():
version_range = range(7, 12)
try:
- assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first."
+ assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too."
except AssertionError as e:
print(e)
sys.exit(-1)
--
cgit v1.2.3
From f5d73b6a6646b51027c8e6f6c6154f21b58d6af2 Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:56:09 +0500
Subject: Fixed typo
---
launch.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/launch.py b/launch.py
index 801e0371..e39c68e7 100644
--- a/launch.py
+++ b/launch.py
@@ -21,9 +21,9 @@ def check_python_version():
version = sys.version_info
version_range = None
if os.name == "nt":
- version_range = range(7, 11)
+ version_range = range(7 + 1, 10 + 1)
else:
- version_range = range(7, 12)
+ version_range = range(7 + 1, 11 + 1)
try:
assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too."
--
cgit v1.2.3
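Spelled out, the rewritten bounds make the intent of the check readable at a glance:

    list(range(7 + 1, 10 + 1))  # [8, 9, 10]     -> accepts Python 3.8-3.10
    list(range(7 + 1, 11 + 1))  # [8, 9, 10, 11] -> accepts Python 3.8-3.11

whereas the earlier range(7, 11) and range(7, 12) wrongly accepted the unsupported Python 3.7 as well.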
From e0df864b8c1f99d7d65d56c6ac8a1e5e314dddba Mon Sep 17 00:00:00 2001
From: brkirch
Date: Wed, 25 Jan 2023 13:19:06 -0500
Subject: Update arguments to use --upcast-sampling
---
webui-macos-env.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 95ca9c55..fa187dd1 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -10,7 +10,7 @@ then
fi
export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
--
cgit v1.2.3
From d1d6ce29831d1b067801c3206f314258de88f683 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 23:25:25 +0300
Subject: add edit_image_conditioning from my earlier edits in case there's an
attempt to integrate pix2pix properly; this allows using the pix2pix model in
img2img, though it won't work well this way
---
modules/processing.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/modules/processing.py b/modules/processing.py
index 9e5a2f38..cb41288a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -185,7 +185,12 @@ class StableDiffusionProcessing:
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+ def edit_image_conditioning(self, source_image):
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+ return conditioning_image
+
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -228,6 +233,9 @@ class StableDiffusionProcessing:
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+ if self.sd_model.cond_stage_key == "edit":
+ return self.edit_image_conditioning(source_image)
+
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
--
cgit v1.2.3
From 6cff4401824299a983c8e13424018efc347b4a2b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 23:25:40 +0300
Subject: fix prompt editing break after first batch in img2img
---
modules/sd_samplers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6261d1f7..a7910b56 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -454,7 +454,7 @@ class KDiffusionSampler:
def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap.step = 0
+ self.model_wrap_cfg.step = 0
self.eta = p.eta or opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
--
cgit v1.2.3
From d82d471bf7797fe09dbc6f3a6ac1c2a76142c8f1 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Thu, 26 Jan 2023 02:09:14 +0300
Subject: Ask user to clarify conditions
---
.github/ISSUE_TEMPLATE/bug_report.yml | 29 +++++++++++++++++++++++------
1 file changed, 23 insertions(+), 6 deletions(-)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index ed372f22..7d435297 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -37,20 +37,20 @@ body:
id: what-should
attributes:
label: What should have happened?
- description: tell what you think the normal behavior should be
+ description: Tell what you think the normal behavior should be
validations:
required: true
- type: input
id: commit
attributes:
label: Commit where the problem happens
- description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+ description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
validations:
required: true
- type: dropdown
id: platforms
attributes:
- label: What platforms do you use to access UI ?
+ label: What platforms do you use to access the UI?
multiple: true
options:
- Windows
@@ -74,10 +74,27 @@ body:
id: cmdargs
attributes:
label: Command Line Arguments
- description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+ description: Are you using any launching parameters/command line arguments (modified webui-user.bat/.sh)? If yes, please write them below. Write "No" otherwise.
render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: extensions
+ attributes:
+ label: List of extensions
+ description: Are you using any extensions other than built-ins? If yes, provide a list; you can copy it from the "Extensions" tab. Write "No" otherwise.
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Console logs
+ description: Please provide **full** cmd/terminal logs from the moment you started the UI until after your bug happened. If it's very long, provide a link to pastebin or a similar service.
+ render: Shell
+ validations:
+ required: true
- type: textarea
id: misc
attributes:
- label: Additional information, context and logs
- description: Please provide us with any relevant additional info, context or log output.
+ label: Additional information
+ description: Please provide us with any relevant additional info or context.
--
cgit v1.2.3
From 10421f93c3f7f7ce88cb40391b46d4e6664eff74 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Thu, 26 Jan 2023 00:34:38 -0500
Subject: Fix full previews, --no-half-vae
---
modules/processing.py | 8 ++++----
modules/sd_hijack_utils.py | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index cb41288a..92894d67 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -172,7 +172,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
@@ -217,7 +217,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
return x
@@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
+ image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index f81b169a..f8684475 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -5,7 +5,7 @@ class CondFunc:
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
- for i in range(len(func_path)-2, -1, -1):
+ for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
--
cgit v1.2.3
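The one-character change in sd_hijack_utils.py above fixes how CondFunc resolves a dotted path: starting at len(func_path) - 1 means the loop first tries the full parent module of the target attribute, instead of skipping it. A sketch of that resolution loop, assuming the getattr walk that follows it in the full file:

    import importlib

    func_path = 'torch.nn.functional.interpolate'.split('.')
    for i in range(len(func_path) - 1, -1, -1):
        try:
            resolved_obj = importlib.import_module('.'.join(func_path[:i]))
            break
        except ImportError:
            pass
    for attr_name in func_path[i:]:
        resolved_obj = getattr(resolved_obj, attr_name)
    print(resolved_obj)  # <function interpolate at 0x...>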
From 1619233a747830887831cfea2f05fe826fce1bed Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Thu, 26 Jan 2023 12:52:44 +0500
Subject: Only Linux will have max 3.11
---
launch.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/launch.py b/launch.py
index e39c68e7..52f3bd52 100644
--- a/launch.py
+++ b/launch.py
@@ -20,10 +20,10 @@ skip_install = False
def check_python_version():
version = sys.version_info
version_range = None
- if os.name == "nt":
- version_range = range(7 + 1, 10 + 1)
- else:
+ if platform.system() == "Linux":
version_range = range(7 + 1, 11 + 1)
+ else:
+ version_range = range(7 + 1, 10 + 1)
try:
assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too."
--
cgit v1.2.3
From f4ec411f2c9d6bc6817a2eca8a2c00f255ffb386 Mon Sep 17 00:00:00 2001
From: "ULTRANOX\\Chris"
Date: Thu, 26 Jan 2023 03:45:16 -0500
Subject: Allow checkpoint merger to merge pix2pix models in the same way that
it currently supports inpainting models.
---
modules/extras.py | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 36123aa5..67ffdee3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
+ result_is_pix2pix_model = False
if theta_func2:
shared.state.textinfo = f"Loading B"
@@ -186,13 +187,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
- assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
-
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
- result_is_inpainting_model = True
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model...
+ print("Detected possible merge of instruct model with non-instruct model.")
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
+ result_is_pix2pix_model = True
+ else:
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+ result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
-
+
theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1
@@ -226,6 +231,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
filename = filename_generator() if custom_name == '' else custom_name
filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += ".pix2pix" if result_is_pix2pix_model else ""
filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
--
cgit v1.2.3
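The merge trick in this patch rests on one fact: inpainting (9-channel) and instruct-pix2pix (8-channel) UNets differ from a standard model only in the input channels of the first convolution, so interpolation is applied to the four channels both models share while A keeps its extra channels unchanged. A sketch under that assumption, with weighted_sum standing in for theta_func2:

import torch

def weighted_sum(theta0, theta1, alpha):
    return (1 - alpha) * theta0 + alpha * theta1

def merge_conv_in(a, b, multiplier):
    # a: conv_in weight of the extended model (8 or 9 input channels),
    # b: 4-channel conv_in weight of a standard model. Only the shared
    # channels are interpolated; the rest keep A's weights.
    merged = a.clone()
    merged[:, 0:4, :, :] = weighted_sum(a[:, 0:4, :, :], b, multiplier)
    return merged

a = torch.randn(320, 8, 3, 3)   # instruct-pix2pix conv_in
b = torch.randn(320, 4, 3, 3)   # standard SD conv_in
print(merge_conv_in(a, b, 0.5).shape)  # torch.Size([320, 8, 3, 3])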
From f90798c6b6cc48e514acb08ce02bdb5874bf74d8 Mon Sep 17 00:00:00 2001
From: "ULTRANOX\\Chris"
Date: Thu, 26 Jan 2023 04:38:04 -0500
Subject: Added error check for the rare case where a user merges a pix2pix
 model with a normal model using weighted sum. Also removed a bad print
 message that interfered with the merging progress bar.
---
modules/extras.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/extras.py b/modules/extras.py
index 67ffdee3..badd13c7 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -186,9 +186,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
+ if a.shape[1] == 4 and b.shape[1] == 8:
+ raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.")
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model...
- print("Detected possible merge of instruct model with non-instruct model.")
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
result_is_pix2pix_model = True
else:
--
cgit v1.2.3
From 9e72dc743480c8b1ca6aeb8ced3af03f3e3243a3 Mon Sep 17 00:00:00 2001
From: "ULTRANOX\\Chris"
Date: Thu, 26 Jan 2023 06:05:40 -0500
Subject: Changed all references to "pix2pix" to the more precise name
"instruct pix2pix". Also changed extension to instrpix2pix at least for now.
---
modules/extras.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index badd13c7..2bf0d17e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -132,7 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
- result_is_pix2pix_model = False
+ result_is_instruct_pix2pix_model = False
if theta_func2:
shared.state.textinfo = f"Loading B"
@@ -187,11 +187,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
if a.shape[1] == 4 and b.shape[1] == 8:
- raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.")
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
- if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model...
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
- result_is_pix2pix_model = True
+ result_is_instruct_pix2pix_model = True
else:
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
@@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
filename = filename_generator() if custom_name == '' else custom_name
filename += ".inpainting" if result_is_inpainting_model else ""
- filename += ".pix2pix" if result_is_pix2pix_model else ""
+ filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else ""
filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
--
cgit v1.2.3
From cdc2fa209a3efdc71a90643a5e7a1df49869cd5f Mon Sep 17 00:00:00 2001
From: "ULTRANOX\\Chris"
Date: Thu, 26 Jan 2023 11:27:07 -0500
Subject: Changed filename addition from "instrpix2pix" to the more readable
".instruct-pix2pix" for newly generated instruct pix2pix models.
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/extras.py b/modules/extras.py
index 2bf0d17e..466ecc15 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
filename = filename_generator() if custom_name == '' else custom_name
filename += ".inpainting" if result_is_inpainting_model else ""
- filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else ""
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
--
cgit v1.2.3
From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 26 Jan 2023 23:29:27 +0300
Subject: add an option to enable sections from extras tab in txt2img/img2img
 fix some style inconsistencies
---
modules/processing.py | 7 +++++-
modules/scripts.py | 32 ++++++++++++++++++++++----
modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++++++++++
modules/scripts_postprocessing.py | 11 ++++++---
modules/shared.py | 15 ++++--------
modules/shared_items.py | 10 ++++++++
modules/ui_components.py | 8 +++++++
scripts/postprocessing_upscale.py | 25 ++++++++++++++++++++
style.css | 6 +----
9 files changed, 133 insertions(+), 23 deletions(-)
create mode 100644 modules/scripts_auto_postprocessing.py
create mode 100644 modules/shared_items.py
diff --git a/modules/processing.py b/modules/processing.py
index 92894d67..262806a1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = Image.fromarray(x_sample)
+ if p.scripts is not None:
+ pp = scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image(p, pp)
+ image = pp.image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
diff --git a/modules/scripts.py b/modules/scripts.py
index 03907a63..6e9dc0c0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -6,12 +6,16 @@ from collections import namedtuple
import gradio as gr
-from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
+class PostprocessImageArgs:
+ def __init__(self, image):
+ self.image = image
+
+
class Script:
filename = None
args_from = None
@@ -65,7 +69,7 @@ class Script:
args contains all values returned by components from ui()
"""
- raise NotImplementedError()
+ pass
def process(self, p, *args):
"""
@@ -100,6 +104,13 @@ class Script:
pass
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -247,11 +258,15 @@ class ScriptRunner:
self.infotext_fields = []
def initialize_scripts(self, is_img2img):
+ from modules import scripts_auto_postprocessing
+
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- for script_class, path, basedir, script_module in scripts_data:
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+
+ for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
@@ -332,7 +347,7 @@ class ScriptRunner:
return inputs
- def run(self, p: StableDiffusionProcessing, *args):
+ def run(self, p, *args):
script_index = args[0]
if script_index == 0:
@@ -386,6 +401,15 @@ class ScriptRunner:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_image(p, pp, *script_args)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
new file mode 100644
index 00000000..30d6d658
--- /dev/null
+++ b/modules/scripts_auto_postprocessing.py
@@ -0,0 +1,42 @@
+from modules import scripts, scripts_postprocessing, shared
+
+
+class ScriptPostprocessingForMainUI(scripts.Script):
+ def __init__(self, script_postproc):
+ self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
+ self.postprocessing_controls = None
+
+ def title(self):
+ return self.script.name
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ self.postprocessing_controls = self.script.ui()
+ return self.postprocessing_controls.values()
+
+ def postprocess_image(self, p, script_pp, *args):
+ args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+
+ pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
+ pp.info = {}
+ self.script.process(pp, **args_dict)
+ p.extra_generation_params.update(pp.info)
+ script_pp.image = pp.image
+
+
+def create_auto_preprocessing_script_data():
+ from modules import scripts
+
+ res = []
+
+ for name in shared.opts.postprocessing_enable_in_main_ui:
+ script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
+ if script is None:
+ continue
+
+ constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
+ res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
+
+ return res
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index 25de02d0..ce0ebb61 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -46,6 +46,8 @@ class ScriptPostprocessing:
pass
+
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
script: ScriptPostprocessing = script_class()
script.filename = path
+ if script.name == "Simple Upscale":
+ continue
+
self.scripts.append(script)
def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
import modules.scripts
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
- scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+ scripts_order = shared.opts.postprocessing_operation_order
def script_score(name):
- name = name.lower()
for i, possible_match in enumerate(scripts_order):
- if possible_match in name:
+ if possible_match == name:
return i
return len(self.scripts)
@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
def image_changed(self):
for script in self.scripts_in_preferred_order():
script.image_changed()
+
diff --git a/modules/shared.py b/modules/shared.py
index 6a0b96cb..cdeed55d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,8 +13,8 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
-from modules.paths import models_path, script_path, sd_path
+from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules.paths import models_path, script_path
demo = None
@@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-
-def realesrgan_models_names():
- import modules.realesrgan_model
- return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
-
-
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
@@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
@@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
- 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+ 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+ 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))
diff --git a/modules/shared_items.py b/modules/shared_items.py
new file mode 100644
index 00000000..b5d480c9
--- /dev/null
+++ b/modules/shared_items.py
@@ -0,0 +1,10 @@
+
+
+def realesrgan_models_names():
+ import modules.realesrgan_model
+ return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+
+def postprocessing_scripts():
+ import modules.scripts
+
+ return modules.scripts.scripts_postproc.scripts
\ No newline at end of file
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 9aec3097..284ca0cf 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
def get_block_name(self):
return "colorpicker"
+
+class DropdownMulti(gr.Dropdown):
+ """Same as gr.Dropdown but always multiselect"""
+ def __init__(self, **kwargs):
+ super().__init__(multiselect=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index 095d29b2..8842bd91 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def image_changed(self):
upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+ name = "Simple Upscale"
+ order = 900
+
+ def ui(self):
+ with FormRow():
+ upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+ return {
+ "upscale_by": upscale_by,
+ "upscaler_name": upscaler_name,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+ if upscaler_name is None or upscaler_name == "None":
+ return
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+ assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+ pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
diff --git a/style.css b/style.css
index ec046f78..dd914104 100644
--- a/style.css
+++ b/style.css
@@ -164,7 +164,7 @@
min-height: 3.2em;
}
-#txt2img_styles ul, #img2img_styles ul{
+ul.list-none{
max-height: 35em;
z-index: 2000;
}
@@ -714,9 +714,6 @@ footer {
white-space: nowrap;
min-width: auto;
}
-#txt2img_hires_fix{
- margin-left: -0.8em;
-}
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
margin-left: 0em;
@@ -744,7 +741,6 @@ footer {
.dark .gr-compact{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
- margin-left: 0.8em;
}
.gr-compact{
--
cgit v1.2.3
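The new postprocess_image hook lets any AlwaysVisible script replace each image immediately after generation, before color correction and saving. A hypothetical extension script using it (the class and hook names come from the patch above; the inversion effect is only an illustration):

from PIL import ImageOps

from modules import scripts

class InvertExample(scripts.Script):
    def title(self):
        return "Invert colors (example)"

    def show(self, is_img2img):
        # AlwaysVisible scripts run on every generation.
        return scripts.AlwaysVisible

    def postprocess_image(self, p, pp: scripts.PostprocessImageArgs, *args):
        # pp.image is a PIL image; whatever is left in pp.image is what
        # process_images_inner keeps.
        pp.image = ImageOps.invert(pp.image.convert("RGB"))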
From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:28:12 +0300
Subject: remove the need to place configs near models
---
configs/instruct-pix2pix.yaml | 99 +++++++++++++++
configs/v1-inpainting-inference.yaml | 70 +++++++++++
modules/api/api.py | 5 +-
modules/devices.py | 12 +-
modules/sd_hijack_inpainting.py | 9 --
modules/sd_models.py | 228 +++++++++++++++++------------------
modules/sd_models_config.py | 65 ++++++++++
modules/shared.py | 7 +-
modules/shared_items.py | 15 ++-
modules/timer.py | 35 ++++++
v2-inference-v.yaml | 68 -----------
11 files changed, 411 insertions(+), 202 deletions(-)
create mode 100644 configs/instruct-pix2pix.yaml
create mode 100644 configs/v1-inpainting-inference.yaml
create mode 100644 modules/sd_models_config.py
create mode 100644 modules/timer.py
delete mode 100644 v2-inference-v.yaml
diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml
new file mode 100644
index 00000000..437ddcef
--- /dev/null
+++ b/configs/instruct-pix2pix.yaml
@@ -0,0 +1,99 @@
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+
+model:
+ base_learning_rate: 1.0e-04
+ target: modules.models.diffusion.ddpm_edit.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: edited
+ cond_stage_key: edit
+ # image_size: 64
+ # image_size: 32
+ image_size: 16
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: hybrid
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: true
+ load_ema: true
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 0 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 128
+ num_workers: 1
+ wrap: false
+ validation:
+ target: edit_dataset.EditDataset
+ params:
+ path: data/clip-filtered-dataset
+ cache_dir: data/
+ cache_name: data_10k
+ split: val
+ min_text_sim: 0.2
+ min_image_sim: 0.75
+ min_direction_sim: 0.2
+ max_samples_per_prompt: 1
+ min_resize_res: 512
+ max_resize_res: 512
+ crop_res: 512
+ output_as_edit: False
+ real_input: True
diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml
new file mode 100644
index 00000000..f9eec37d
--- /dev/null
+++ b/configs/v1-inpainting-inference.yaml
@@ -0,0 +1,70 @@
+model:
+ base_learning_rate: 7.5e-05
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: hybrid # important
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ finetune_keys: null
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/modules/api/api.py b/modules/api/api.py
index 25c65e57..eb7b1da5 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -387,7 +388,7 @@ class Api:
]
def get_sd_models(self):
- return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/devices.py b/modules/devices.py
index 6b36622c..2d5f797a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,14 +34,18 @@ def get_cuda_device_string():
return "cuda"
-def get_optimal_device():
+def get_optimal_device_name():
if torch.cuda.is_available():
- return torch.device(get_cuda_device_string())
+ return get_cuda_device_string()
if has_mps():
- return torch.device("mps")
+ return "mps"
+
+ return "cpu"
- return cpu
+
+def get_optimal_device():
+ return torch.device(get_optimal_device_name())
def get_device_for(task):
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 31d2c898..478cd499 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t
-def should_hijack_inpainting(checkpoint_info):
- from modules import sd_models
-
- ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
- return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
-
-
def do_inpainting_hijack():
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7072eb2e..fa208728 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,8 +2,6 @@ import collections
import os.path
import sys
import gc
-import time
-from collections import namedtuple
import torch
import re
import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
-def find_checkpoint_config(info):
- if info is None:
- return shared.cmd_opts.config
-
- config = os.path.splitext(info.filename)[0] + ".yaml"
- if os.path.exists(config):
- return config
-
- return shared.cmd_opts.config
-
-
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- device = map_location or shared.weight_load_location
- if device is None:
- device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
+ timer.record("calculate hash")
+
+ if checkpoint_info in checkpoints_loaded:
+ # use checkpoint cache
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ return checkpoints_loaded[checkpoint_info]
+
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+ res = read_state_dict(checkpoint_info.filename)
+ timer.record("load weights from disk")
+
+ return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ timer.record("calculate hash")
+
if checkpoint_info.title != title:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
- cache_enabled = shared.opts.sd_checkpoint_cache > 0
+ if state_dict is None:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- if cache_enabled and checkpoint_info in checkpoints_loaded:
- # use checkpoint cache
- print(f"Loading weights [{sd_model_hash}] from cache")
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
- else:
- # load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+ model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ timer.record("apply weights to model")
- sd = read_state_dict(checkpoint_info.filename)
- model.load_state_dict(sd, strict=False)
- del sd
-
- if cache_enabled:
- # cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ if shared.opts.sd_checkpoint_cache > 0:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
+ if shared.cmd_opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
+ timer.record("apply channels_last")
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
+ if not shared.cmd_opts.no_half:
+ vae = model.first_stage_model
+ depth_model = getattr(model, 'depth_model', None)
- if not shared.cmd_opts.no_half:
- vae = model.first_stage_model
- depth_model = getattr(model, 'depth_model', None)
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.cmd_opts.no_half_vae:
+ model.first_stage_model = None
+ # with --upcast-sampling, don't convert the depth model weights to float16
+ if shared.cmd_opts.upcast_sampling and depth_model:
+ model.depth_model = None
- # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
- if shared.cmd_opts.no_half_vae:
- model.first_stage_model = None
- # with --upcast-sampling, don't convert the depth model weights to float16
- if shared.cmd_opts.upcast_sampling and depth_model:
- model.depth_model = None
+ model.half()
+ model.first_stage_model = vae
+ if depth_model:
+ model.depth_model = depth_model
- model.half()
- model.first_stage_model = vae
- if depth_model:
- model.depth_model = depth_model
+ timer.record("apply half()")
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- devices.dtype_unet = model.model.diffusion_model.dtype
- devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
- model.first_stage_model.to(devices.dtype_vae)
+ model.first_stage_model.to(devices.dtype_vae)
+ timer.record("apply dtype to VAE")
# clean up cache if limit is reached
- if cache_enabled:
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
- checkpoints_loaded.popitem(last=False) # LRU
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False)
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
+ timer.record("load VAE")
def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
-class Timer:
- def __init__(self):
- self.start = time.time()
+def repair_config(sd_config):
- def elapsed(self):
- end = time.time()
- res = end - self.start
- self.start = end
- return res
+ if not hasattr(sd_config.model.params, "use_ema"):
+ sd_config.model.params.use_ema = False
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
-def load_model(checkpoint_info=None):
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- checkpoint_config = find_checkpoint_config(checkpoint_info)
-
- if checkpoint_config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_config)
-
- if should_hijack_inpainting(checkpoint_info):
- # Hardcoded config for now...
- sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
- sd_config.model.params.conditioning_key = "hybrid"
- sd_config.model.params.unet_config.params.in_channels = 9
- sd_config.model.params.finetune_keys = None
-
- if should_hijack_ip2p(checkpoint_info):
- sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
- sd_config.model.params.conditioning_key = "hybrid"
- sd_config.model.params.first_stage_key = "edited"
- sd_config.model.params.cond_stage_key = "edit"
- sd_config.model.params.image_size = 16
- sd_config.model.params.unet_config.params.in_channels = 8
- sd_config.model.params.unet_config.params.out_channels = 4
+ do_inpainting_hijack()
- if not hasattr(sd_config.model.params, "use_ema"):
- sd_config.model.params.use_ema = False
+ timer = Timer()
- do_inpainting_hijack()
+ if already_loaded_state_dict is not None:
+ state_dict = already_loaded_state_dict
+ else:
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- timer = Timer()
+ timer.record("find config")
- sd_model = None
+ sd_config = OmegaConf.load(checkpoint_config)
+ repair_config(sd_config)
+
+ timer.record("load config")
+
+ print(f"Creating model from config: {checkpoint_config}")
+ sd_model = None
try:
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
- elapsed_create = timer.elapsed()
+ sd_model.used_config = checkpoint_config
- load_model_weights(sd_model, checkpoint_info)
+ timer.record("create model")
- elapsed_load_weights = timer.elapsed()
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
sd_model.to(shared.device)
+ timer.record("move model to device")
+
sd_hijack.model_hijack.hijack(sd_model)
+ timer.record("hijack")
+
sd_model.eval()
shared.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+ timer.record("load textual inversion embeddings")
+
script_callbacks.model_loaded_callback(sd_model)
- elapsed_the_rest = timer.elapsed()
+ timer.record("scripts callbacks")
- print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+ print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model:
sd_model = shared.sd_model
+
if sd_model is None: # previous model load failed
current_checkpoint_info = None
else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
- if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
- del sd_model
- checkpoints_loaded.clear()
- load_model(checkpoint_info)
- return shared.sd_model
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
timer = Timer()
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+ timer.record("find config")
+
+ if sd_model is None or checkpoint_config != sd_model.used_config:
+ del sd_model
+ checkpoints_loaded.clear()
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+ return shared.sd_model
+
try:
- load_model_weights(sd_model, checkpoint_info)
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception as e:
print("Failed to load checkpoint, restoring previous")
- load_model_weights(sd_model, current_checkpoint_info)
+ load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
finally:
sd_hijack.model_hijack.hijack(sd_model)
+ timer.record("hijack")
+
script_callbacks.model_loaded_callback(sd_model)
+ timer.record("script callbacks")
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
+ timer.record("move model to device")
- elapsed = timer.elapsed()
-
- print(f"Weights loaded in {elapsed:.1f}s.")
+ print(f"Weights loaded in {timer.summary()}.")
return sd_model
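In the reworked load path, checkpoints_loaded behaves as an LRU cache keyed by CheckpointInfo: new state dicts are inserted at the end and popitem(last=False) evicts oldest-first once the configured limit is exceeded. A tiny sketch of the idiom, assuming an OrderedDict-backed cache and a fixed limit:

import collections

checkpoints_loaded = collections.OrderedDict()
CACHE_LIMIT = 2

def cache_state_dict(info, state_dict):
    # Insert (or refresh) an entry, then evict oldest-first past the limit.
    checkpoints_loaded[info] = state_dict
    checkpoints_loaded.move_to_end(info)
    while len(checkpoints_loaded) > CACHE_LIMIT:
        checkpoints_loaded.popitem(last=False)

for name in ["a", "b", "c"]:
    cache_state_dict(name, {})
print(list(checkpoints_loaded))  # ['b', 'c']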
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
new file mode 100644
index 00000000..ea773a10
--- /dev/null
+++ b/modules/sd_models_config.py
@@ -0,0 +1,65 @@
+import re
+import os
+
+from modules import shared, paths
+
+sd_configs_path = shared.sd_configs_path
+sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+
+
+config_default = shared.sd_default_config
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
+config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
+config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
+
+re_parametrization_v = re.compile(r'-v\b')
+
+
+def guess_model_config_from_state_dict(sd, filename):
+ fn = os.path.basename(filename)
+
+ sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
+ diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+ roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
+
+ if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
+ if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
+ return config_sd2v
+ else:
+ return config_sd2
+
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ return config_inpainting
+ if diffusion_model_input.shape[1] == 8:
+ return config_instruct_pix2pix
+
+ if roberta_weight is not None:
+ return config_alt_diffusion
+
+ return config_default
+
+
+def find_checkpoint_config(state_dict, info):
+ if info is None:
+ return guess_model_config_from_state_dict(state_dict, "")
+
+ config = find_checkpoint_config_near_filename(info)
+ if config is not None:
+ return config
+
+ return guess_model_config_from_state_dict(state_dict, info.filename)
+
+
+def find_checkpoint_config_near_filename(info):
+ if info is None:
+ return None
+
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return None
+
diff --git a/modules/shared.py b/modules/shared.py
index cdeed55d..14be993d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,13 +13,14 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
+from modules import localization, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path
demo = None
-sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
@@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index b5d480c9..8b5ec96d 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -4,7 +4,20 @@ def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+
def postprocessing_scripts():
import modules.scripts
- return modules.scripts.scripts_postproc.scripts
\ No newline at end of file
+ return modules.scripts.scripts_postproc.scripts
+
+
+def sd_vae_items():
+ import modules.sd_vae
+
+ return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
+
+
+def refresh_vae_list():
+ import modules.sd_vae
+
+ return modules.sd_vae.refresh_vae_list()
diff --git a/modules/timer.py b/modules/timer.py
new file mode 100644
index 00000000..57a4f17a
--- /dev/null
+++ b/modules/timer.py
@@ -0,0 +1,35 @@
+import time
+
+
+class Timer:
+ def __init__(self):
+ self.start = time.time()
+ self.records = {}
+ self.total = 0
+
+ def elapsed(self):
+ end = time.time()
+ res = end - self.start
+ self.start = end
+ return res
+
+ def record(self, category, extra_time=0):
+ e = self.elapsed()
+ if category not in self.records:
+ self.records[category] = 0
+
+ self.records[category] += e + extra_time
+ self.total += e + extra_time
+
+ def summary(self):
+ res = f"{self.total:.1f}s"
+
+ additions = [x for x in self.records.items() if x[1] >= 0.1]
+ if not additions:
+ return res
+
+ res += " ("
+ res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
+ res += ")"
+
+ return res
diff --git a/v2-inference-v.yaml b/v2-inference-v.yaml
deleted file mode 100644
index 513cd635..00000000
--- a/v2-inference-v.yaml
+++ /dev/null
@@ -1,68 +0,0 @@
-model:
- base_learning_rate: 1.0e-4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- parameterization: "v"
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: "jpg"
- cond_stage_key: "txt"
- image_size: 64
- channels: 4
- cond_stage_trainable: false
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False # we set this to false because this is an inference only config
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- use_checkpoint: True
- use_fp16: True
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
\ No newline at end of file
--
cgit v1.2.3
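The new Timer replaces the earlier ad-hoc elapsed() bookkeeping with named, accumulating records; summary() reports the total and omits any category under 0.1s. A standalone usage sketch matching the call pattern in load_model above:

import time

from modules.timer import Timer

timer = Timer()
time.sleep(0.2)
timer.record("load weights from disk")
time.sleep(0.05)   # below the 0.1s threshold, so omitted from summary()
timer.record("hijack")
print(f"Model loaded in {timer.summary()}.")
# prints something like: Model loaded in 0.3s (load weights from disk: 0.2s).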
From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:54:19 +0300
Subject: support detecting midas model; fix broken api for checkpoint list
---
modules/api/models.py | 2 +-
modules/sd_models.py | 10 +++++-----
modules/sd_models_config.py | 7 +++++--
3 files changed, 11 insertions(+), 8 deletions(-)
diff --git a/modules/api/models.py b/modules/api/models.py
index 805bd8f7..cba43d3b 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
hash: Optional[str] = Field(title="Short hash")
sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename")
- config: str = Field(title="Config file")
+ config: Optional[str] = Field(title="Config file")
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fa208728..37dad18d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(sd_model)
+ sd_hijack.model_hijack.undo_hijack(sd_model)
timer = Timer()
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index ea773a10..4d1e92e1 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs",
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@@ -22,7 +23,9 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
- roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None)
+
+ if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+ return config_depth_model
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
@@ -36,7 +39,7 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
- if roberta_weight is not None:
+ if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
return config_alt_diffusion
return config_default
--
cgit v1.2.3
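With this commit the config guesser distinguishes five checkpoint families purely from state-dict contents, and the order of the checks matters: the depth-model key must come before the SD2 projection-width test, because a depth checkpoint would otherwise match the plain SD2 branch. A condensed sketch of the heuristics (config names abbreviated, and the filename test simplified from the re_parametrization_v regex):

def guess_config(sd, filename=""):
    # sd: raw state dict mapping parameter names to tensors.
    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias') is not None:
        return "v2-midas-inference.yaml"

    proj = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight')
    if proj is not None and proj.shape[1] == 1024:
        # a '-v' marker in the filename indicates v-parameterization
        return "v2-inference-v.yaml" if "-v" in filename else "v2-inference.yaml"

    conv_in = sd.get('model.diffusion_model.input_blocks.0.0.weight')
    if conv_in is not None:
        if conv_in.shape[1] == 9:
            return "v1-inpainting-inference.yaml"
        if conv_in.shape[1] == 8:
            return "instruct-pix2pix.yaml"

    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight') is not None:
        return "alt-diffusion-inference.yaml"

    return "v1-inference.yaml"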
From 9beb794e0b0dc1a0f9e89d8e38bd789a8c608397 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 13:08:00 +0300
Subject: clarify the option to disable NaN check.
---
modules/devices.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/modules/devices.py b/modules/devices.py
index 2d5f797a..4687944e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -143,6 +143,8 @@ def test_for_nans(x, where):
else:
message = "A tensor with all NaNs was produced."
+ message += " Use --disable-nan-check commandline argument to disable this check."
+
raise NansException(message)
--
cgit v1.2.3
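For context, the check this message is appended to looks roughly like the following (a simplified sketch: the real test_for_nans in modules/devices.py also special-cases the unet and vae wording and returns early when the check is disabled):

import torch

class NansException(Exception):
    pass

def test_for_nans(x, where):
    # Fire only when the tensor is NaN everywhere, which signals a
    # precision or model problem rather than a stray bad value.
    if not torch.all(torch.isnan(x)).item():
        return

    message = "A tensor with all NaNs was produced."
    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)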
From 9ecf1e827c5966e11495a0c066a127defbba9bcc Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Fri, 27 Jan 2023 17:35:24 +0500
Subject: Made it only a warning
---
.gitignore | 1 +
launch.py | 39 ++++++++++++++++++++++++++++-----------
2 files changed, 29 insertions(+), 11 deletions(-)
diff --git a/.gitignore b/.gitignore
index 0b1d17ca..c8be9688 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,4 @@ notification.mp3
/test/stdout.txt
/test/stderr.txt
/cache.json
+no_py_ver_warning
diff --git a/launch.py b/launch.py
index 52f3bd52..4f5a4bc4 100644
--- a/launch.py
+++ b/launch.py
@@ -18,18 +18,35 @@ skip_install = False
def check_python_version():
- version = sys.version_info
- version_range = None
- if platform.system() == "Linux":
- version_range = range(7 + 1, 11 + 1)
- else:
- version_range = range(7 + 1, 10 + 1)
+ if not os.path.isfile("no_py_ver_warning"):
+ version = sys.version_info
+ version_range = None
+ if platform.system() == "Linux":
+ version_range = range(7 + 1, 11 + 1)
+ else:
+ version_range = range(7 + 1, 10 + 1)
- try:
- assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download the latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please make sure to delete the current version of Python first, and delete the `venv` folder inside the WebUI's folder, too."
- except AssertionError as e:
- print(e)
- sys.exit(-1)
+ try:
+ assert version.major == 3 and version.minor in version_range, f"""
+=== Warning ===
+This program was tested only with Python 3.10, but you are running Python {version.major}.{version.minor}.
+If you encounter an error such as "RuntimeError: Couldn't install torch.",
+or any other error about an unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of Python 3.10,
+and delete your current Python installation and the "venv" folder in the WebUI's directory.
+
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+
+You will see this warning only once; delete the "no_py_ver_warning" file to show this warning again.
+=== Warning ===
+
+Press ENTER to continue...\
+"""
+ except AssertionError as e:
+ print(e)
+ with open("no_py_ver_warning", "w"):
+ pass
+ input()
def commit_hash():
--
cgit v1.2.3
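The sentinel-file pattern used here deserves a name: an empty marker file persists the "warning already shown" state across runs, so the interactive pause happens only once. A generic sketch, with the marker name and message as placeholders:

import os

def warn_once(marker, message):
    # Show an interactive warning a single time; the empty marker file
    # acts as the persistent "already shown" flag.
    if os.path.isfile(marker):
        return
    print(message)
    with open(marker, "w"):
        pass
    input("Press ENTER to continue...")

warn_once("no_py_ver_warning", "=== Warning ===\nUnsupported Python version.\n=== Warning ===")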
From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001
From: Max Audron
Date: Wed, 25 Jan 2023 17:15:42 +0100
Subject: add data-dir flag and set all user data directories based on it
---
modules/extensions.py | 2 +-
modules/generation_parameters_copypaste.py | 4 ++--
modules/gfpgan_model.py | 5 ++---
modules/hashes.py | 4 +++-
modules/interrogate.py | 2 +-
modules/paths.py | 10 +++++++++-
modules/processing.py | 3 ++-
modules/sd_models.py | 6 +++---
modules/sd_vae.py | 5 ++---
modules/shared.py | 11 ++++++-----
modules/textual_inversion/preprocess.py | 5 ++---
modules/ui.py | 6 +++---
modules/ui_extensions.py | 2 +-
modules/upscaler.py | 5 ++---
14 files changed, 39 insertions(+), 31 deletions(-)
diff --git a/modules/extensions.py b/modules/extensions.py
index b522125c..92ee8144 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -7,7 +7,7 @@ import git
from modules import paths, shared
extensions = []
-extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 46e12dc6..35f72808 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import re
from pathlib import Path
import gradio as gr
-from modules.shared import script_path
+from modules.paths import data_path, script_path
from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
@@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
- filename = os.path.join(script_path, "params.txt")
+ filename = os.path.join(data_path, "params.txt")
if os.path.exists(filename):
with open(filename, "r", encoding="utf8") as file:
prompt = file.read()
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 1e2dbc32..fbe6215a 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -6,12 +6,11 @@ import facexlib
import gfpgan
import modules.face_restoration
-from modules import shared, devices, modelloader
-from modules.paths import models_path
+from modules import paths, shared, devices, modelloader
model_dir = "GFPGAN"
user_path = None
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None
diff --git a/modules/hashes.py b/modules/hashes.py
index b85a7580..819362a3 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -4,8 +4,10 @@ import os.path
import filelock
+from modules.paths import data_path
-cache_filename = "cache.json"
+
+cache_filename = os.path.join(data_path, "cache.json")
cache_data = None
diff --git a/modules/interrogate.py b/modules/interrogate.py
index c72ff694..cbb80683 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -12,7 +12,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
diff --git a/modules/paths.py b/modules/paths.py
index 20b3e4d8..08e6f9b9 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -4,7 +4,15 @@ import sys
import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-models_path = os.path.join(script_path, "models")
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser = argparse.ArgumentParser()
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+cmd_opts_pre = parser.parse_known_args()[0]
+data_path = cmd_opts_pre.data_dir
+models_path = os.path.join(data_path, "models")
+
+# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places
diff --git a/modules/processing.py b/modules/processing.py
index 262806a1..5072fc40 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
+import modules.paths as paths
import modules.face_restoration
import modules.images as images
import modules.styles
@@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if not p.disable_extra_networks:
extra_networks.activate(p, extra_network_data)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 37dad18d..b2d48a51 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -12,13 +12,13 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
+model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {}
checkpoint_alisases = {}
@@ -307,7 +307,7 @@ def enable_midas_autodownload():
location automatically.
"""
- midas_path = os.path.join(models_path, 'midas')
+ midas_path = os.path.join(paths.models_path, 'midas')
# stable-diffusion-stability-ai hard-codes the midas model path to
# a location that differs from where other scripts using this model look.
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 4ce238b8..9b00f76e 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -3,13 +3,12 @@ import safetensors.torch
import os
import collections
from collections import namedtuple
-from modules import shared, devices, script_callbacks, sd_models
-from modules.paths import models_path
+from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
-vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
+vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
vae_dict = {}
diff --git a/modules/shared.py b/modules/shared.py
index 14be993d..474fcc42 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
-from modules.paths import models_path, script_path
+from modules.paths import models_path, script_path, data_path
demo = None
@@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
@@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
-parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
@@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
-parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
+parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
-parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
+parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
+parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index c0ac11d3..2239cb84 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -6,8 +6,7 @@ import sys
import tqdm
import time
-from modules import shared, images, deepbooru
-from modules.paths import models_path
+from modules import paths, shared, images, deepbooru
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
@@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = None
try:
- dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
diff --git a/modules/ui.py b/modules/ui.py
index 85ae62c7..0117df3e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path
+from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -1497,8 +1497,8 @@ def create_ui():
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
- if os.path.exists(os.path.join(script_path, "user.css")):
- with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
+ if os.path.exists(os.path.join(data_path, "user.css")):
+ with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 742e745e..66a41865 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url):
normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
- tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+ tmpdir = os.path.join(paths.data_path, "tmp", dirname)
try:
shutil.rmtree(tmpdir, True)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index a5bf5acb..e2eaa730 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -11,7 +11,6 @@ from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
-from modules.paths import models_path
class Upscaler:
@@ -39,7 +38,7 @@ class Upscaler:
self.mod_scale = None
if self.model_path is None and self.name:
- self.model_path = os.path.join(models_path, self.name)
+ self.model_path = os.path.join(shared.models_path, self.name)
if self.model_path and create_dirs:
os.makedirs(self.model_path, exist_ok=True)
@@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler):
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Nearest"
- self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
+ self.scalers = [UpscalerData("Nearest", None, self)]
--
cgit v1.2.3
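The core of this commit is a two-pass argparse pattern in modules/paths.py: a throwaway parser reads --data-dir up front with parse_known_args(), which tolerates all the flags it does not know about, and the resulting path then seeds the defaults of the full parser in modules/shared.py. A minimal sketch of the pattern, with illustrative flag names (note that a later commit in this series adds add_help=False to the pre-parser so it does not swallow --help):

    import argparse
    import os

    # Pass 1: a minimal parser that only knows --data-dir. parse_known_args()
    # returns (namespace, leftovers) instead of erroring on unknown flags.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--data-dir", type=str, default=os.getcwd())
    pre_opts, _ = pre_parser.parse_known_args()
    data_path = pre_opts.data_dir

    # Pass 2: the full parser can now base its defaults on data_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=str, default=data_path)  # re-declared so it appears in --help
    parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, "embeddings"))
    cmd_opts = parser.parse_args()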
From 14c0884fd0948c478db165989cca7aaffc9a0504 Mon Sep 17 00:00:00 2001
From: Max Audron
Date: Wed, 25 Jan 2023 17:55:59 +0100
Subject: use python importlib to load and execute extension modules
previously, module attributes like __file__ were not set correctly,
leading to scripts getting the directory of the stable-diffusion repo
instead of the directory of their own script.
This caused problems when loading user data from an external location
using the --data-dir flag, as extensions would look for their own code
in the stable-diffusion repo location instead of the data dir location.
Using Python's importlib functions sets the module specs correctly and
executes the modules. This will, however, break extensions that build
paths based on the previously incorrect __file__ attribute.
---
modules/script_loading.py | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index f93f0951..a7d2203f 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,16 +1,14 @@
import os
import sys
import traceback
+import importlib.util
from types import ModuleType
def load_module(path):
- with open(path, "r", encoding="utf8") as file:
- text = file.read()
-
- compiled = compile(text, path, 'exec')
- module = ModuleType(os.path.basename(path))
- exec(compiled, module.__dict__)
+ module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
+ module = importlib.util.module_from_spec(module_spec)
+ module_spec.loader.exec_module(module)
return module
--
cgit v1.2.3
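To see why the importlib switch matters, compare what each loader leaves on the module object. A small sketch with a hypothetical script path: the old compile/exec route records the filename only in code objects (so tracebacks look right) but never sets __file__ on the module itself, while a spec-based load populates the module's metadata before execution.

    import importlib.util
    import os
    from types import ModuleType

    SCRIPT = "/data/extensions/example/scripts/example.py"  # hypothetical path

    # Old approach: compile + exec into a bare ModuleType. The filename only
    # ends up in code objects, so scripts doing os.path.dirname(__file__)
    # resolve against the wrong place (or fail with AttributeError).
    with open(SCRIPT, "r", encoding="utf8") as f:
        text = f.read()
    old_module = ModuleType(os.path.basename(SCRIPT))
    exec(compile(text, SCRIPT, "exec"), old_module.__dict__)
    print(getattr(old_module, "__file__", None))  # -> None

    # New approach: a spec built from the file location gives the module
    # proper __file__ / __spec__ / __loader__ attributes before it runs.
    spec = importlib.util.spec_from_file_location(os.path.basename(SCRIPT), SCRIPT)
    new_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(new_module)
    print(new_module.__file__)  # -> /data/extensions/example/scripts/example.py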
From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001
From: Max Audron
Date: Wed, 25 Jan 2023 18:00:09 +0100
Subject: clean up unused script_path imports
---
modules/codeformer_model.py | 2 +-
modules/generation_parameters_copypaste.py | 2 +-
webui.py | 1 -
3 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ab40d842..01fb7bd8 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -8,7 +8,7 @@ import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
-from modules.paths import script_path, models_path
+from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 35f72808..773c5c0e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import re
from pathlib import Path
import gradio as gr
-from modules.paths import data_path, script_path
+from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
diff --git a/webui.py b/webui.py
index e1565a8d..41f32f5c 100644
--- a/webui.py
+++ b/webui.py
@@ -15,7 +15,6 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not
from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
-from modules.paths import script_path
import torch
--
cgit v1.2.3
From 23a9d5e27390846dea0895a02c04aec9583a4d38 Mon Sep 17 00:00:00 2001
From: Max Audron
Date: Wed, 25 Jan 2023 18:18:55 +0100
Subject: create user extensions directory if not exists
---
modules/extensions.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/modules/extensions.py b/modules/extensions.py
index 92ee8144..5e12b1aa 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -10,6 +10,8 @@ extensions = []
extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
+if not os.path.exists(extensions_dir):
+ os.makedirs(extensions_dir)
def active():
return [x for x in extensions if x.enabled]
--
cgit v1.2.3
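The check-then-create idiom above works, but os.makedirs has a flag for exactly this case; an equivalent, race-free one-liner (exist_ok closes the window between the exists() check and the mkdir in which another process could create the directory):

    os.makedirs(extensions_dir, exist_ok=True)  # no-op if the directory already exists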
From eafaf14167cf574ad0f918c10f60ef86aea9cd20 Mon Sep 17 00:00:00 2001
From: Gazzoo-byte <73721238+Gazzoo-byte@users.noreply.github.com>
Date: Fri, 27 Jan 2023 18:34:41 +0000
Subject: Add button to switch width and height
Adds a button to switch width and height, allowing quick and easy switching between landscape and portrait.
---
modules/ui.py | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/modules/ui.py b/modules/ui.py
index 85ae62c7..fb0e4d5c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -91,6 +91,13 @@ save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
+switch_values_symbol = '\U000021C5' # ⇅
+
+def switch_width_and_height(width, height):
+ width_temp = width
+ width = height
+ height = width_temp
+ return width, height
def plaintext_to_html(text):
@@ -466,6 +473,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
if opts.dimensions_and_batch_together:
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
@@ -566,6 +574,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
+ res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height])
txt_prompt_img.change(
fn=modules.images.image_data,
@@ -728,6 +737,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
if opts.dimensions_and_batch_together:
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
@@ -865,6 +875,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
+ res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height])
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
--
cgit v1.2.3
From a6a5bfb15531b19ce0319593d67d05a356f49a65 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Fri, 27 Jan 2023 13:48:39 -0500
Subject: deepcopy pc.styles, allows for multiple style axes
---
scripts/xyz_grid.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 828c2d12..f2fe506c 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -123,7 +123,7 @@ def apply_vae(p, x, xs):
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
- p.styles = x.split(',')
+ p.styles.extend(x.split(','))
def format_value_add_label(p, opt, x):
@@ -533,6 +533,7 @@ class Script(scripts.Script):
return Processed(p, [], p.seed, "")
pc = copy(p)
+ pc.styles = pc.styles[:]
x_opt.apply(pc, x, xs)
y_opt.apply(pc, y, ys)
z_opt.apply(pc, z, zs)
--
cgit v1.2.3
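The two hunks belong together: once apply_styles extends p.styles instead of assigning it, every grid cell would otherwise append onto the same list object, because copy() is shallow and pc.styles aliases p.styles. The pc.styles = pc.styles[:] line gives each cell its own list. A toy illustration of the aliasing:

    from copy import copy

    class Processing:
        def __init__(self):
            self.styles = ["base"]

    p = Processing()
    pc = copy(p)                 # shallow copy: pc.styles is the same list object
    pc.styles.extend(["axis-1"])
    print(p.styles)              # ['base', 'axis-1'] -- leaked into the original

    p = Processing()
    pc = copy(p)
    pc.styles = pc.styles[:]     # rebind a fresh copy of the list first
    pc.styles.extend(["axis-1"])
    print(p.styles)              # ['base'] -- original untouched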
From 32d389ef0f7c75dd85fc7aebe7bca279f36fed86 Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Fri, 27 Jan 2023 14:04:23 -0500
Subject: changes remaining text from X/Y -> X/Y/Z
---
README.md | 2 +-
javascript/hints.js | 2 +-
scripts/xyz_grid.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index c6bd6f27..2149dcc5 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- a man in a (tuxedo:1.21) - alternative syntax
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
-- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
- have as many embeddings as you want and use any names you like for them
- use multiple embeddings with different numbers of vectors per token
diff --git a/javascript/hints.js b/javascript/hints.js
index 3cf10e20..7b60b25e 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -50,7 +50,7 @@ titles = {
"None": "Do not do anything special",
"Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
- "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
+ "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index f2fe506c..f0116055 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -499,7 +499,7 @@ class Script(scripts.Script):
image_cell_count = p.n_iter * p.batch_size
cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
plural_s = 's' if len(zs) > 1 else ''
- print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
+ print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
shared.total_tqdm.updateTotal(total_steps)
grid_infotext = [None]
--
cgit v1.2.3
From cc8c9b7474d917888a0bd069fcd59a458c67ae4b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 22:43:08 +0300
Subject: fix broken calls to find_checkpoint_config
---
modules/extras.py | 4 ++--
modules/sd_hijack_ip2p.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/modules/extras.py b/modules/extras.py
index 36123aa5..4f842be9 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -6,7 +6,7 @@ import shutil
import torch
import tqdm
-from modules import shared, images, sd_models, sd_vae
+from modules import shared, images, sd_models, sd_vae, sd_models_config
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -37,7 +37,7 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- res = sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
return res if res != shared.sd_default_config else None
if config_source == 0:
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
index 635f015f..3c727d3b 100644
--- a/modules/sd_hijack_ip2p.py
+++ b/modules/sd_hijack_ip2p.py
@@ -5,9 +5,9 @@ import gc
import time
def should_hijack_ip2p(checkpoint_info):
- from modules import sd_models
+ from modules import sd_models_config
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+ cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
--
cgit v1.2.3
From 6b82efd737827bbeef202f04ff5a8faec9b64ef8 Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Fri, 27 Jan 2023 20:06:19 -0500
Subject: add v2-inpainting model detection, and broaden v-model detection to
include anything with 768 in the name
---
modules/sd_models_config.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 4d1e92e1..73854a45 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs",
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
@@ -28,7 +29,9 @@ def guess_model_config_from_state_dict(sd, filename):
return config_depth_model
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
- if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
+ if diffusion_model_input.shape[1] == 9:
+ return config_sd2_inpainting
+ elif re.search(re_parametrization_v, fn) or "768" in fn:
return config_sd2v
else:
return config_sd2
--
cgit v1.2.3
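The detection keys off tensor shapes in the checkpoint's state dict: SD 2.x is recognized by its 1024-wide text-encoder projection, and an inpainting UNet by its 9-channel first convolution (4 latent + 4 masked-image latent + 1 mask channels). A rough sketch of the logic, with key names as they appear in SD 1.x/2.x checkpoints and illustrative return values:

    def guess_config(sd, filename):
        # cond_proj is only present in SD 2.x checkpoints; its second
        # dimension is the text encoder width (1024 for SD 2.x).
        cond_proj = sd.get("cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight")
        unet_in = sd.get("model.diffusion_model.input_blocks.0.0.weight")

        if cond_proj is not None and cond_proj.shape[1] == 1024:
            if unet_in is not None and unet_in.shape[1] == 9:  # inpainting UNet input channels
                return "v2-inpainting-inference.yaml"
            # v-prediction vs. epsilon is decided by the filename heuristic
            return "v2-inference.yaml"
        return "v1-inference.yaml"  # SD 1.x fallback (illustrative)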
From 2aac1d97782b486f3a4a5209cf399dcdcb7bbb4d Mon Sep 17 00:00:00 2001
From: Andrii Skaliuk
Date: Fri, 27 Jan 2023 17:32:31 -0800
Subject: Basic inpainting batch support
Modifies batch UI to add optional inpainting support
---
modules/img2img.py | 20 +++++++++++++++++---
modules/ui.py | 9 ++++++++-
2 files changed, 25 insertions(+), 4 deletions(-)
diff --git a/modules/img2img.py b/modules/img2img.py
index 2168c8e2..fe9447c7 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -16,11 +16,16 @@ import modules.images as images
import modules.scripts
-def process_batch(p, input_dir, output_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processing.fix_seed(p)
images = shared.listfiles(input_dir)
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
+ is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0
+ if is_inpaint_batch:
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
save_normally = output_dir == ''
@@ -43,6 +48,15 @@ def process_batch(p, input_dir, output_dir, args):
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
+ if is_inpaint_batch:
+ # try to find corresponding mask for an image using simple filename matching
+ mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
+ # if not found use first one ("same mask for all images" use-case)
+ if not mask_image_path in inpaint_masks:
+ mask_image_path = inpaint_masks[0]
+ mask_image = Image.open(mask_image_path)
+ p.image_mask = mask_image
+
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
@@ -59,7 +73,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -139,7 +153,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
processed = Processed(p, [], p.seed, "")
else:
diff --git a/modules/ui.py b/modules/ui.py
index 85ae62c7..fddb9177 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -691,9 +691,15 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"
Process images in a directory on the same machine where the server is running. Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
")
+ gr.HTML(
+ f"
Process images in a directory on the same machine where the server is running." +
+ f" Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f" Add inpaint batch mask directory to enable inpaint batch processing."
+ f"{hidden}
"
+ )
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
def copy_image(img):
if isinstance(img, dict) and 'image' in img:
@@ -838,6 +844,7 @@ def create_ui():
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
+ img2img_batch_inpaint_mask_dir
] + custom_inputs,
outputs=[
img2img_gallery,
--
cgit v1.2.3
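The mask lookup rule is worth restating on its own: a mask is matched to an image by identical filename in the mask directory, and if no match exists the first mask is reused for every image (the "same mask for all images" case). A standalone sketch of that rule, with hypothetical names:

    import os
    from PIL import Image

    def find_batch_mask(image_path, inpaint_mask_dir, inpaint_masks):
        """inpaint_masks is the listing of inpaint_mask_dir (full paths)."""
        candidate = os.path.join(inpaint_mask_dir, os.path.basename(image_path))
        if candidate not in inpaint_masks:
            candidate = inpaint_masks[0]  # fall back: one mask for the whole batch
        return Image.open(candidate)

So inputs/img_001.png pairs with masks/img_001.png when it exists, and otherwise with whichever mask sorts first in the directory listing.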
From 4c52dfe4ac98c53431ecd267d59f27391d3a63e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 28 Jan 2023 08:30:17 +0300
Subject: make the detection for -v models less broad
---
modules/sd_models_config.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 73854a45..00217990 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -31,7 +31,7 @@ def guess_model_config_from_state_dict(sd, filename):
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9:
return config_sd2_inpainting
- elif re.search(re_parametrization_v, fn) or "768" in fn:
+ elif re.search(re_parametrization_v, fn):
return config_sd2v
else:
return config_sd2
--
cgit v1.2.3
From 0834d4ce374225131e025540220c727e352a3e43 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 28 Jan 2023 08:41:15 +0300
Subject: simplify #7284
---
modules/ui.py | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py
index 3c0a4050..ca2c1eb6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -93,12 +93,6 @@ clear_prompt_symbol = '\U0001F5D1' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
-def switch_width_and_height(width, height):
- width_temp = width
- width = height
- height = width_temp
- return width, height
-
def plaintext_to_html(text):
return ui_common.plaintext_to_html(text)
@@ -574,7 +568,8 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height])
+
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
txt_prompt_img.change(
fn=modules.images.image_data,
@@ -882,7 +877,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
- res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height])
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
--
cgit v1.2.3
From 7d1f2a3a495327341ef1b3238347864845799bb6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 28 Jan 2023 10:21:31 +0300
Subject: remove waiting for input on version mismatch warning, change
supported versions
---
.gitignore | 1 -
launch.py | 35 ++++++++++++-----------------------
2 files changed, 12 insertions(+), 24 deletions(-)
diff --git a/.gitignore b/.gitignore
index c8be9688..0b1d17ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,3 @@ notification.mp3
/test/stdout.txt
/test/stderr.txt
/cache.json
-no_py_ver_warning
diff --git a/launch.py b/launch.py
index 4f5a4bc4..7614f9c9 100644
--- a/launch.py
+++ b/launch.py
@@ -18,35 +18,24 @@ skip_install = False
def check_python_version():
- if not os.path.isfile("no_py_ver_warning"):
- version = sys.version_info
- version_range = None
- if platform.system() == "Linux":
- version_range = range(7 + 1, 11 + 1)
- else:
- version_range = range(7 + 1, 10 + 1)
+ version = sys.version_info
+ if platform.system() == "Windows":
+ supported_minors = [10]
+ else:
+ supported_minors = [7, 8, 9, 10, 11]
+
+ if not (version.major == 3 and version.minor in supported_minors):
+ import modules.errors
- try:
- assert version.major == 3 and version.minor in version_range, f"""
-=== Warning ===
-This program was tested only with 3.10 Python, but you have {version.major}.{version.minor} Python.
+ modules.errors.print_error_explanation(f"""
+This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
-
-You will see this warning only once, delete file "no_py_ver_warning" file to show this warning again.
-=== Warning ===
-
-Press ENTER to continue...\
-"""
- except AssertionError as e:
- print(e)
- with open("no_py_ver_warning", "w"):
- pass
- input()
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\
+""")
def commit_hash():
--
cgit v1.2.3
From 3752aad23d4be4522f9edf3fe79c1122fa5ad509 Mon Sep 17 00:00:00 2001
From: Mackerel
Date: Sat, 28 Jan 2023 02:44:12 -0500
Subject: don't replace regular --help with new paths.py parser help
---
modules/paths.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/paths.py b/modules/paths.py
index 08e6f9b9..d991cc71 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -6,7 +6,7 @@ import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser = argparse.ArgumentParser()
+parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
--
cgit v1.2.3
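The reason for add_help=False: argparse adds -h/--help to every parser by default, and even parse_known_args() acts on flags it recognizes, so the one-flag pre-parser in paths.py would print its own nearly empty help text and exit before shared.py's full parser ever ran. With help disabled, --help simply lands in the leftovers for the real parser to handle. A small demonstration:

    import argparse

    pre = argparse.ArgumentParser(add_help=False)  # do not claim -h/--help
    pre.add_argument("--data-dir", type=str, default=".")

    # --help is unknown to this parser now, so it survives into the leftovers
    opts, leftover = pre.parse_known_args(["--data-dir", "/data", "--help"])
    print(opts.data_dir)  # /data
    print(leftover)       # ['--help'] -- available for the full parser to handle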
From bd52a6d89970cca4f0f8b4275db895c99e173b3f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 28 Jan 2023 10:48:08 +0300
Subject: some more changes for python version warning; add a commandline flag
to disable
---
launch.py | 25 +++++++++++++++++++------
1 file changed, 19 insertions(+), 6 deletions(-)
diff --git a/launch.py b/launch.py
index 7614f9c9..370920de 100644
--- a/launch.py
+++ b/launch.py
@@ -18,23 +18,33 @@ skip_install = False
def check_python_version():
- version = sys.version_info
- if platform.system() == "Windows":
+ is_windows = platform.system() == "Windows"
+ major = sys.version_info.major
+ minor = sys.version_info.minor
+ micro = sys.version_info.micro
+
+ if is_windows:
supported_minors = [10]
else:
supported_minors = [7, 8, 9, 10, 11]
- if not (version.major == 3 and version.minor in supported_minors):
+ if not (major == 3 and minor in supported_minors):
import modules.errors
modules.errors.print_error_explanation(f"""
-This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}.
+INCOMPATIBLE PYTHON VERSION
+
+This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+
+{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+
+Use --skip-python-version-check to suppress this warning.
""")
@@ -237,6 +247,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
+ sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
@@ -245,6 +256,9 @@ def prepare_environment():
xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv
+ if not skip_python_version_check:
+ check_python_version()
+
commit = commit_hash()
print(f"Python {sys.version}")
@@ -342,6 +356,5 @@ def start():
if __name__ == "__main__":
- check_python_version()
prepare_environment()
start()
--
cgit v1.2.3
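The --skip-python-version-check flag is consumed with extract_arg before argparse is involved, because the check has to run while launch.py is still bootstrapping. A sketch of the pattern, assuming a simple list-filter helper like the one launch.py uses (not the verbatim code):

    def extract_arg(args, name):
        """Strip a flag from an argv-style list; report whether it was present."""
        return [x for x in args if x != name], name in args

    argv, skip_check = extract_arg(
        ["launch.py", "--skip-python-version-check", "--xformers"],
        "--skip-python-version-check",
    )
    print(argv)        # ['launch.py', '--xformers']
    print(skip_check)  # True -- caller can now skip check_python_version()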