From b46b97fa297b3a4a654da77cf98a775a2bcab4c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 11:38:17 +0300 Subject: more fixes for gradio update --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fbd91300..54b3372d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -38,7 +38,7 @@ def quote(text): def image_from_url_text(filedata): if type(filedata) == dict and filedata["is_file"]: filename = filedata["name"] - is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs) + is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets]) assert is_in_right_dir, 'trying to open image file outside of allowed directories' return Image.open(filename) -- cgit v1.2.3 From a005fccddd5a37c57f1afe5234660b59b9a41508 Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Sun, 1 Jan 2023 14:51:12 +0100 Subject: Add a lot more elem_id/HTML id, modified some that were duplicates for seed section --- modules/generation_parameters_copypaste.py | 2 +- modules/ui.py | 254 ++++++++++++++--------------- style.css | 12 +- 3 files changed, 134 insertions(+), 134 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 54b3372d..8e7f0df0 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -93,7 +93,7 @@ def integrate_settings_paste_fields(component_dict): def create_buttons(tabs_list): buttons = {} for tab in tabs_list: - buttons[tab] = gr.Button(f"Send to {tab}") + buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") return buttons diff --git a/modules/ui.py b/modules/ui.py index 27da2c2c..7070ea15 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -272,17 +272,17 @@ def interrogate_deepbooru(image): return gr_show(True) if prompt is None else prompt -def create_seed_inputs(): +def create_seed_inputs(target_interface): with gr.Row(): with gr.Box(): - with gr.Row(elem_id='seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) + with gr.Row(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id='random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - with gr.Box(elem_id='subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) + with gr.Box(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) # Components to show/hide based on the 'Extra' checkbox seed_extras = [] @@ -290,17 +290,17 @@ def create_seed_inputs(): with gr.Row(visible=False) as seed_extra_row_1: seed_extras.append(seed_extra_row_1) with gr.Box(): - with gr.Row(elem_id='subseed_row'): - subseed = gr.Number(label='Variation seed', value=-1) + with gr.Row(elem_id=target_interface + '_subseed_row'): + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id='random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') with gr.Row(visible=False) as seed_extra_row_2: seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0) - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) @@ -678,28 +678,28 @@ def create_ui(): steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - enable_hr = gr.Checkbox(label='Highres. fix', value=False) + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0) - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - with gr.Group(): + with gr.Group(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) @@ -821,10 +821,10 @@ def create_ui(): with gr.Column(variant='panel', elem_id="img2img_settings"): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img'): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint'): + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) init_img_with_mask_orig = gr.State(None) @@ -843,24 +843,24 @@ def create_ui(): init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") with gr.Row(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") with gr.Row(): mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") + inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, elem_id="img2img_inpaint_full_res") + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - with gr.TabItem('Batch img2img', id='batch'): + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -872,20 +872,20 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") with gr.Row(): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") with gr.Group(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - with gr.Group(): + with gr.Group(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1032,45 +1032,45 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image'): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - with gr.TabItem('Batch Process'): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - with gr.TabItem('Batch from Directory'): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") - show_extras_results = gr.Checkbox(label='Show result images', value=True) + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) - with gr.TabItem('Scale to'): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): with gr.Group(): with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") with gr.Group(): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) @@ -1117,7 +1117,7 @@ def create_ui(): with gr.Column(variant='panel'): html = gr.HTML() - generation_info = gr.Textbox(visible=False) + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") html2 = gr.HTML() with gr.Row(): buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) @@ -1144,13 +1144,13 @@ def create_ui(): tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format") - save_as_half = gr.Checkbox(value=False, label="Save as float16") + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1165,58 +1165,58 @@ def create_ui(): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name") - initialization_text = gr.Textbox(label="Initialization text", value="*") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory') - process_dst = gr.Textbox(label='Destination directory') - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") with gr.Row(): with gr.Column(scale=3): @@ -1224,8 +1224,8 @@ def create_ui(): with gr.Column(): with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") process_split.change( fn=lambda show: gr_show(show), @@ -1248,31 +1248,31 @@ def create_ui(): train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - - batch_size = gr.Number(label='Batch size', value=1, precision=0) - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") with gr.Row(): - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") with gr.Row(): - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') - train_embedding = gr.Button(value="Train Embedding", variant='primary') + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") params = script_callbacks.UiTrainTabParams(txt2img_preview_params) @@ -1490,7 +1490,7 @@ def create_ui(): return gr.update(value=value), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary') + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") result = gr.HTML() settings_cols = 3 @@ -1541,8 +1541,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary', elem_id="settings_restart_gradio") request_notifications.click( fn=lambda: None, diff --git a/style.css b/style.css index f168571e..924d4ae7 100644 --- a/style.css +++ b/style.css @@ -73,7 +73,7 @@ margin-right: auto; } -#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{ +[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{ min-width: auto; flex-grow: 0; padding-left: 0.25em; @@ -84,27 +84,27 @@ display: none; } -#seed_row, #subseed_row{ +[id$=_seed_row], [id$=_subseed_row]{ gap: 0.5rem; } -#subseed_show_box{ +[id$=_subseed_show_box]{ min-width: auto; flex-grow: 0; } -#subseed_show_box > div{ +[id$=_subseed_show_box] > div{ border: 0; height: 100%; } -#subseed_show{ +[id$=_subseed_show]{ min-width: auto; flex-grow: 0; padding: 0; } -#subseed_show label{ +[id$=_subseed_show] label{ height: 100%; } -- cgit v1.2.3 From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 19:42:10 +0300 Subject: Hires fix rework --- modules/generation_parameters_copypaste.py | 32 ++++++++++++++ modules/images.py | 24 +++++++++-- modules/processing.py | 68 ++++++++++++------------------ modules/shared.py | 7 ++- modules/txt2img.py | 6 +-- modules/ui.py | 15 +++---- scripts/xy_grid.py | 4 +- 7 files changed, 96 insertions(+), 60 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8e7f0df0..d6fa822b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,6 @@ import base64 import io +import math import os import re from pathlib import Path @@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): return None +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + # old algorithm for auto-calculating first pass size + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + firstpass_width = math.ceil(scale * width / 64) * 64 + firstpass_height = math.ceil(scale * height / 64) * 64 + + hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires upscale'] = hr_scale + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + restore_old_hires_fix_params(res) + return res diff --git a/modules/images.py b/modules/images.py index f84fd485..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): return draw_grid_annotations(im, width, height, hor_texts, ver_texts) -def resize_image(resize_mode, im, width, height): +def resize_image(resize_mode, im, width, height, upscaler_name=None): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. + """ + + upscaler_name = upscaler_name or opts.upscaler_for_img2img + def resize(im, w, h): - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) if scale > 1.0: - upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] - assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" + upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] + assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" upscaler = upscalers[0] im = upscaler.scaler.upscale(im, scale, upscaler.data_path) diff --git a/modules/processing.py b/modules/processing.py index 42dc19ea..4654570c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength - self.firstphase_width = firstphase_width - self.firstphase_height = firstphase_height - self.truncate_x = 0 - self.truncate_y = 0 + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + + if firstphase_width != 0 or firstphase_height != 0: + print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) + self.hr_scale = self.width / firstphase_width + self.width = firstphase_width + self.height = firstphase_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - - if self.firstphase_width == 0 or self.firstphase_height == 0: - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - firstphase_width_truncated = int(scale * self.width) - firstphase_height_truncated = int(scale * self.height) - - else: - - width_ratio = self.width / self.firstphase_width - height_ratio = self.height / self.firstphase_height - - if width_ratio > height_ratio: - firstphase_width_truncated = self.firstphase_width - firstphase_height_truncated = self.firstphase_width * self.height / self.width - else: - firstphase_width_truncated = self.firstphase_height * self.width / self.height - firstphase_height_truncated = self.firstphase_height - - self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f - self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_upscaler is not None: + self.extra_generation_params["Hires upscaler"] = self.hr_upscaler def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + if self.enable_hr and latent_scale_mode is None: + assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) - - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + target_width = int(self.width * self.hr_scale) + target_height = int(self.height * self.hr_scale) - """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") - if opts.use_scale_latent_for_hires_fix: + if latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): save_intermediate(image, i) - image = images.resize_image(0, image, self.width, self.height) + image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) batch_images.append(image) @@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None diff --git a/modules/shared.py b/modules/shared.py index 7f430b93..b65559ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -545,6 +544,12 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": "bilinear", + "Latent (nearest)": "nearest", +} + sd_upscalers = [] sd_model = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 7f61e19a..e189a899 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: tiling=tiling, enable_hr=enable_hr, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 7070ea15..27cd9ddd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -684,11 +684,11 @@ def create_ui(): with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): @@ -729,8 +729,8 @@ def create_ui(): width, enable_hr, denoising_strength, - firstphase_width, - firstphase_height, + hr_scale, + hr_upscaler, ] + custom_inputs, outputs=[ @@ -762,7 +762,6 @@ def create_ui(): outputs=[hr_options], ) - txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -781,8 +780,8 @@ def create_ui(): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3e0b2805..f92f9776 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -202,7 +202,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None), @@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork self.model = shared.sd_model - self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): hypernetwork.apply_strength() opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 251ecee6949c36e9df1d99a950b3e1af2b5fa2b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 22:44:46 +0300 Subject: make "send to" buttons send actual dimension of the sent image rather than fields --- javascript/ui.js | 4 +-- modules/generation_parameters_copypaste.py | 58 ++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/javascript/ui.js b/javascript/ui.js index 587dd782..d0c054d9 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -19,7 +19,7 @@ function selected_gallery_index(){ function extract_image_from_gallery(gallery){ if(gallery.length == 1){ - return gallery[0] + return [gallery[0]] } index = selected_gallery_index() @@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){ return [null] } - return gallery[index]; + return [gallery[index]]; } function args_to_array(args){ diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d6fa822b..ec60319a 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -103,35 +103,57 @@ def bind_buttons(buttons, send_image, send_generate_info): bind_list.append([buttons, send_image, send_generate_info]) +def send_image_and_dimensions(x): + if isinstance(x, Image.Image): + img = x + else: + img = image_from_url_text(x) + + if shared.opts.send_size and isinstance(img, Image.Image): + w = img.width + h = img.height + else: + w = gr.update() + h = gr.update() + + return img, w, h + + def run_bind(): - for buttons, send_image, send_generate_info in bind_list: + for buttons, source_image_component, send_generate_info in bind_list: for tab in buttons: button = buttons[tab] - if send_image and paste_fields[tab]["init_img"]: - if type(send_image) == gr.Gallery: - button.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery", - inputs=[send_image], - outputs=[paste_fields[tab]["init_img"]], - ) + destination_image_component = paste_fields[tab]["init_img"] + fields = paste_fields[tab]["fields"] + + destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) + destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) + + if source_image_component and destination_image_component: + if isinstance(source_image_component, gr.Gallery): + func = send_image_and_dimensions if destination_width_component else image_from_url_text + jsfunc = "extract_image_from_gallery" else: - button.click( - fn=lambda x: x, - inputs=[send_image], - outputs=[paste_fields[tab]["init_img"]], - ) + func = send_image_and_dimensions if destination_width_component else lambda x: x + jsfunc = None + + button.click( + fn=func, + _js=jsfunc, + inputs=[source_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + ) - if send_generate_info and paste_fields[tab]["fields"] is not None: + if send_generate_info and fields is not None: if send_generate_info in paste_fields: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else []) + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) button.click( fn=lambda *x: x, inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], - outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], + outputs=[field for field, name in fields if name in paste_field_names], ) else: - connect_paste(button, paste_fields[tab]["fields"], send_generate_info) + connect_paste(button, fields, send_generate_info) button.click( fn=None, -- cgit v1.2.3 From c0ee1488702d5a6ae35fbf7e0422f9f685394920 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 14:18:48 +0300 Subject: add support for running with gradio 3.9 installed --- modules/generation_parameters_copypaste.py | 4 ++-- modules/ui_tempdir.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ec60319a..d94f11a3 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared +from modules import shared, ui_tempdir import tempfile from PIL import Image @@ -39,7 +39,7 @@ def quote(text): def image_from_url_text(filedata): if type(filedata) == dict and filedata["is_file"]: filename = filedata["name"] - is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets]) + is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories' return Image.open(filename) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 363d449d..21945235 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -1,6 +1,7 @@ import os import tempfile from collections import namedtuple +from pathlib import Path import gradio as gr @@ -12,10 +13,28 @@ from modules import shared Savedfile = namedtuple("Savedfile", ["name"]) +def register_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 + gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} + + if hasattr(gradio, 'temp_dirs'): # gradio 3.9 + gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} + + +def check_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): + return any([filename in fileset for fileset in gradio.temp_file_sets]) + + if hasattr(gradio, 'temp_dirs'): + return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) + + return False + + def save_pil_to_file(pil_image, dir=None): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): - shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)} + register_tmp_file(shared.demo, already_saved_as) file_obj = Savedfile(already_saved_as) return file_obj @@ -45,7 +64,7 @@ def on_tmpdir_changed(): os.makedirs(shared.opts.temp_dir, exist_ok=True) - shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)} + register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) def cleanup_tmpdr(): -- cgit v1.2.3 From 3e22e294135ed0327ce9d9738655ff03c53df3c0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 21:49:24 +0300 Subject: fix broken send to extras button --- modules/generation_parameters_copypaste.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d94f11a3..4baf4d9a 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -37,7 +37,10 @@ def quote(text): def image_from_url_text(filedata): - if type(filedata) == dict and filedata["is_file"]: + if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + filedata = filedata[0] + + if type(filedata) == dict and filedata.get("is_file", False): filename = filedata["name"] is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories' -- cgit v1.2.3 From 81490780949fffed77493b4bd741e96ec737fe27 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 22:04:40 +0300 Subject: added the option to specify target resolution with possibility of truncating for hires fix; also sampling steps --- javascript/hints.js | 11 ++++--- modules/generation_parameters_copypaste.py | 9 ++++-- modules/processing.py | 51 +++++++++++++++++++++++++++--- modules/txt2img.py | 5 ++- modules/ui.py | 24 ++++++++++---- 5 files changed, 81 insertions(+), 19 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/javascript/hints.js b/javascript/hints.js index 63e17e05..dda66e09 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,9 +81,6 @@ titles = { "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", - "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", - "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", - "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", @@ -100,7 +97,13 @@ titles = { "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.", - "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality." + "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.", + + "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", + "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", + "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", + "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." } diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4baf4d9a..12a9de3d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res): firstpass_width = math.ceil(scale * width / 64) * 64 firstpass_height = math.ceil(scale * height / 64) * 64 - hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height - res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height - res['Hires upscale'] = hr_scale + res['Hires resize-1'] = width + res['Hires resize-2'] = height def parse_generation_parameters(x: str): @@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + restore_old_hires_fix_params(res) return res diff --git a/modules/processing.py b/modules/processing.py index 47712159..9cad05f2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -662,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength self.hr_scale = hr_scale self.hr_upscaler = hr_upscaler + self.hr_second_pass_steps = hr_second_pass_steps + self.hr_resize_x = hr_resize_x + self.hr_resize_y = hr_resize_y + self.hr_upscale_to_x = hr_resize_x + self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) @@ -675,6 +680,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.width = firstphase_width self.height = firstphase_height + self.truncate_x = 0 + self.truncate_y = 0 + def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if state.job_count == -1: @@ -682,7 +690,38 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_resize_x == 0 and self.hr_resize_y == 0: + self.extra_generation_params["Hires upscale"] = self.hr_scale + self.hr_upscale_to_x = int(self.width * self.hr_scale) + self.hr_upscale_to_y = int(self.height * self.hr_scale) + else: + self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}" + + if self.hr_resize_y == 0: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + elif self.hr_resize_x == 0: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + else: + target_w = self.hr_resize_x + target_h = self.hr_resize_y + src_ratio = self.width / self.height + dst_ratio = self.hr_resize_x / self.hr_resize_y + + if src_ratio < dst_ratio: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + else: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + + self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f + self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f + + if self.hr_second_pass_steps: + self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps + if self.hr_upscaler is not None: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler @@ -699,8 +738,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: return samples - target_width = int(self.width * self.hr_scale) - target_height = int(self.height * self.hr_scale) + target_width = self.hr_upscale_to_x + target_height = self.hr_upscale_to_y def save_intermediate(image, index): """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" @@ -755,13 +794,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) return samples diff --git a/modules/txt2img.py b/modules/txt2img.py index e189a899..38b5f591 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, hr_scale=hr_scale, hr_upscaler=hr_upscaler, + hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 44f4f3a4..04091e67 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -637,10 +637,10 @@ def create_sampler_and_steps_selection(choices, tabname): with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") sampler_index.save_to_config = True - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") return steps, sampler_index @@ -709,10 +709,16 @@ def create_ui(): enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") elif category == "hires_fix": - with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options: - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") elif category == "batch": if not opts.dimensions_and_batch_together: @@ -753,6 +759,9 @@ def create_ui(): denoising_strength, hr_scale, hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, ] + custom_inputs, outputs=[ @@ -804,6 +813,9 @@ def create_ui(): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) -- cgit v1.2.3 From d4fd2418efb0986a8226add0b800fb5c73ffb58c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 14:57:47 +0300 Subject: add an option to use old hiresfix width/height behavior add a visual effect to inactive hires fix elements --- javascript/hires_fix.js | 25 +++++++++++++++++++++++++ modules/generation_parameters_copypaste.py | 17 +++++++++++------ modules/processing.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + modules/ui.py | 23 ++++++++++++++--------- style.css | 4 ++++ 6 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 javascript/hires_fix.js (limited to 'modules/generation_parameters_copypaste.py') diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js new file mode 100644 index 00000000..07fba549 --- /dev/null +++ b/javascript/hires_fix.js @@ -0,0 +1,25 @@ + +function setInactive(elem, inactive){ + console.log(elem) + if(inactive){ + elem.classList.add('inactive') + } else{ + elem.classList.remove('inactive') + } +} + +function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ + console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) + + hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') + hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') + hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') + + gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" + + setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) + setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) + setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) + + return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] +} diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..f7f68b67 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res): firstpass_width = res.get('First pass size-1', None) firstpass_height = res.get('First pass size-2', None) + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", None)) + hires_height = int(res.get("Hires resize-2", None)) + + if hires_width is not None and hires_height is not None: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + if firstpass_width is None or firstpass_height is None: return @@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - # old algorithm for auto-calculating first pass size - desired_pixel_count = 512 * 512 - actual_pixel_count = width * height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - firstpass_width = math.ceil(scale * width / 64) * 64 - firstpass_height = math.ceil(scale * height / 64) * 64 + from modules import processing + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..f04a0e1e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: return res +def old_hires_fix_first_pass_dimensions(width, height): + """old algorithm for auto-calculating first pass size""" + + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + width = math.ceil(scale * width / 64) * 64 + height = math.ceil(scale * height / 64) * 64 + + return width, height + + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None @@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: - print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) - self.hr_scale = self.width / firstphase_width + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height self.width = firstphase_width self.height = firstphase_height self.truncate_x = 0 self.truncate_y = 0 + self.applied_old_hires_behavior_to = None def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): + self.hr_resize_x = self.width + self.hr_resize_y = self.height + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + + self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height) + self.applied_old_hires_behavior_to = (self.width, self.height) + if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) diff --git a/modules/shared.py b/modules/shared.py index a6712dae..a1e10201 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 99483130..719c26b3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): @@ -745,15 +745,20 @@ def create_ui(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - hr_resolution_preview_args = dict( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False - ) - for input in hr_resolution_preview_inputs: - input.change(**hr_resolution_preview_args) + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) diff --git a/style.css b/style.css index d796cbe9..ec5e4182 100644 --- a/style.css +++ b/style.css @@ -670,6 +670,10 @@ footer { min-width: auto; } +.inactive{ + opacity: 0.5; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 3fe9e9e54dcfc41d7c5ee6976f83b0de29fd3dda Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 02:17:33 +0300 Subject: fix broken resolution detection when pasting parameters with old hires fix enabled --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f7f68b67..620aa606 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -198,10 +198,10 @@ def restore_old_hires_fix_params(res): firstpass_height = res.get('First pass size-2', None) if shared.opts.use_old_hires_fix_width_height: - hires_width = int(res.get("Hires resize-1", None)) - hires_height = int(res.get("Hires resize-2", None)) + hires_width = int(res.get("Hires resize-1", 0)) + hires_height = int(res.get("Hires resize-2", 0)) - if hires_width is not None and hires_height is not None: + if hires_width and hires_height: res['Size-1'] = hires_width res['Size-2'] = hires_height return -- cgit v1.2.3 From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:50:09 -0800 Subject: Add script callback for fixing infotext parameters --- modules/generation_parameters_copypaste.py | 3 ++- modules/script_callbacks.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read() params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) res = [] for output, key in paste_fields: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_infotext_pasted=[], callbacks_script_unloaded=[], ) @@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + + def script_unloaded_callback(): for c in reversed(callback_map['callbacks_script_unloaded']): try: @@ -290,6 +299,15 @@ def on_image_grid(callback): add_callback(callback_map['callbacks_image_grid'], callback) +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + def on_script_unloaded(callback): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" -- cgit v1.2.3 From 89314e79da21ac71ad3133ccf5ac3e85d4c24052 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:23:16 +0300 Subject: fix an error that happens when you send an empty image from txt2img to img2img --- modules/generation_parameters_copypaste.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 593d99ef..a381ff59 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -37,6 +37,9 @@ def quote(text): def image_from_url_text(filedata): + if filedata is None: + return None + if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] -- cgit v1.2.3 From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 08:36:07 +0300 Subject: extra networks UI rework of hypernets: rather than via settings, hypernets are added directly to prompt as --- html/card-no-preview.png | Bin 0 -> 84440 bytes html/extra-networks-card.html | 11 ++ html/extra-networks-no-cards.html | 8 ++ javascript/extraNetworks.js | 60 ++++++++ javascript/hints.js | 2 + javascript/ui.js | 9 +- modules/api/api.py | 7 +- modules/extra_networks.py | 147 +++++++++++++++++++ modules/extra_networks_hypernet.py | 21 +++ modules/generation_parameters_copypaste.py | 12 +- modules/hypernetworks/hypernetwork.py | 107 +++++++++----- modules/hypernetworks/ui.py | 5 +- modules/processing.py | 24 ++-- modules/sd_hijack_optimizations.py | 10 +- modules/shared.py | 21 ++- modules/textual_inversion/textual_inversion.py | 2 + modules/ui.py | 50 ++++--- modules/ui_components.py | 10 ++ modules/ui_extra_networks.py | 149 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 34 +++++ modules/ui_extra_networks_textual_inversion.py | 32 +++++ script.js | 13 +- scripts/xy_grid.py | 29 ---- style.css | 190 +++++++++++++------------ webui.py | 26 +++- 25 files changed, 765 insertions(+), 214 deletions(-) create mode 100644 html/card-no-preview.png create mode 100644 html/extra-networks-card.html create mode 100644 html/extra-networks-no-cards.html create mode 100644 javascript/extraNetworks.js create mode 100644 modules/extra_networks.py create mode 100644 modules/extra_networks_hypernet.py create mode 100644 modules/ui_extra_networks.py create mode 100644 modules/ui_extra_networks_hypernets.py create mode 100644 modules/ui_extra_networks_textual_inversion.py (limited to 'modules/generation_parameters_copypaste.py') diff --git a/html/card-no-preview.png b/html/card-no-preview.png new file mode 100644 index 00000000..e2beb269 Binary files /dev/null and b/html/card-no-preview.png differ diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html new file mode 100644 index 00000000..7314b063 --- /dev/null +++ b/html/extra-networks-card.html @@ -0,0 +1,11 @@ +
+
+
+ +
+ {name} +
+
+ diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html new file mode 100644 index 00000000..389358d6 --- /dev/null +++ b/html/extra-networks-no-cards.html @@ -0,0 +1,8 @@ +
+

Nothing here. Add some content to the following directories:

+ + +
+ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 00000000..71e522d1 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,60 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) +} + +var activePromptTextarea = null; +var activePositivePromptTextarea = null; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(id, isNegative){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (activePromptTextarea == null){ + activePromptTextarea = textarea + } + if (activePositivePromptTextarea == null && ! isNegative){ + activePositivePromptTextarea = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea = textarea; + if(! isNegative) activePositivePromptTextarea = textarea; + }); + } + + registerPrompt('txt2img_prompt') + registerPrompt('txt2img_neg_prompt', true) + registerPrompt('img2img_prompt') + registerPrompt('img2img_neg_prompt', true) +} + +onUiLoaded(setupExtraNetworks) + +function cardClicked(textToAdd, allowNegativePrompt){ + textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea + + textarea.value = textarea.value + " " + textToAdd + updateInput(textarea) + + return false +} + +function saveCardPreview(event, tabname, filename){ + textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} diff --git a/javascript/hints.js b/javascript/hints.js index e746e20d..f4079f96 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -21,6 +21,8 @@ titles = { "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 3ba90ca8..a7e75439 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } - - opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -239,11 +237,14 @@ onUiUpdate(function(){ return } + prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", () => update_token_counter(id_button)); + textarea.addEventListener("input", function(){ + update_token_counter(id_button); + }); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -261,10 +262,8 @@ onUiUpdate(function(){ }) } } - }) - onOptionsChanged(function(){ elem = gradioApp().getElementById('sd_checkpoint_hash') sd_checkpoint_hash = opts.sd_checkpoint_hash || "" diff --git a/modules/api/api.py b/modules/api/api.py index 9814bbc2..2c371e6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -480,7 +480,7 @@ class Api: def train_hypernetwork(self, args: dict): try: shared.state.begin() - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -491,16 +491,15 @@ class Api: except Exception as e: error = e finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {error}".format(error = error)) + return TrainResponse(info="train embedding error: {error}".format(error=error)) def get_memory(self): try: diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000..1978673d --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re +from collections import defaultdict + +from modules import errors + +extra_network_registry = {} + + +def initialize(): + extra_network_registry.clear() + + +def register_extra_network(extra_network): + extra_network_registry[extra_network.name] = extra_network + + +class ExtraNetworkParams: + def __init__(self, items=None): + self.items = items or [] + + +class ExtraNetwork: + def __init__(self, name): + self.name = name + + def activate(self, p, params_list): + """ + Called by processing on every run. Whatever the extra network is meant to do should be activated here. + Passes arguments related to this extra network in params_list. + User passes arguments by specifying this in his prompt: + + + + Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments + separated by colon. + + Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list - + in this case, all effects of this extra networks should be disabled. + + Can be called multiple times before deactivate() - each new call should override the previous call completely. + + For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is: + + > "1girl, " + + params_list will be: + + [ + ExtraNetworkParams(items=["agm", "1.1"]), + ExtraNetworkParams(items=["ray"]) + ] + + """ + raise NotImplementedError + + def deactivate(self, p): + """ + Called at the end of processing for housekeeping. No need to do anything here. + """ + + raise NotImplementedError + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + print(f"Skipping unknown extra network: {extra_network_name}") + continue + + try: + extra_network.activate(p, extra_network_args) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name}") + + +def deactivate(p, extra_network_data): + """call deactivate for extra networks in extra_network_data in specified order, then call + deactivate for all remaining registered networks""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating extra network {extra_network_name}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating unmentioned extra network {extra_network_name}") + + +re_extra_net = re.compile(r"<(\w+):([^>]+)>") + + +def parse_prompt(prompt): + res = defaultdict(list) + + def found(m): + name = m.group(1) + args = m.group(2) + + res[name].append(ExtraNetworkParams(items=args.split(":"))) + + return "" + + prompt = re.sub(re_extra_net, found, prompt) + + return prompt, res + + +def parse_prompts(prompts): + res = [] + extra_data = None + + for prompt in prompts: + updated_prompt, parsed_extra_data = parse_prompt(prompt) + + if extra_data is None: + extra_data = parsed_extra_data + + res.append(updated_prompt) + + return res, extra_data + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..6a0c4ba8 --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks +from modules.hypernetworks import hypernetwork + + +class ExtraNetworkHypernet(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('hypernet') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + hypernetwork.load_hypernetworks(names, multipliers) + + def deactivate(p, self): + pass diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui settings_map = { - 'sd_hypernetwork': 'Hypernet', - 'sd_hypernetwork_strength': 'Hypernet strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'inpainting_mask_weight': 'Conditional mask weight', 'sd_model_checkpoint': 'Model hash', @@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res: res["Clip skip"] = "1" - if "Hypernet strength" not in res: - res["Hypernet strength"] = "1" - - if "Hypernet" in res: - hypernet_name = res["Hypernet"] - hypernet_hash = res.get("Hypernet hash", None) - res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res["Prompt"] += f"""""" if "Hires resize-1" not in res: res["Hires resize-1"] = 0 diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74e78582..80a47c79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} class HypernetworkModule(torch.nn.Module): - multiplier = 1.0 activation_dict = { "linear": torch.nn.Identity, "relu": torch.nn.ReLU, @@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() + self.multiplier = 1.0 + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" @@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) + return x + self.linear(x) * (self.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure -def apply_strength(value=None): - HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength - #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): if layer_structure is None: @@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters(): param.requires_grad = mode + def to(self, device): + for k, layers in self.layers.items(): + for layer in layers: + layer.to(device) + + return self + + def set_multiplier(self, multiplier): + for k, layers in self.layers.items(): + for layer in layers: + layer.multiplier = multiplier + + return self + def eval(self): for k, layers in self.layers.items(): for layer in layers: @@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None if self.optimizer_state_dict: self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print("Loaded existing optimizer from checkpoint") - print(f"Optimizer name is {self.optimizer_name}") + if shared.opts.print_hypernet_extra: + print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: self.optimizer_name = "AdamW" - print("No saved optimizer exists in checkpoint") + if shared.opts.print_hypernet_extra: + print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: @@ -306,23 +320,43 @@ def list_hypernetworks(path): return res -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - # Prevent any file named "None.pt" from being loaded. - if path is not None and filename != "None": - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) +def load_hypernetwork(name): + path = shared.hypernetworks.get(name, None) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print("Unloading hypernetwork") + if path is None: + return None + + hypernetwork = Hypernetwork() + + try: + hypernetwork.load(path) + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + return hypernetwork + + +def load_hypernetworks(names, multipliers=None): + already_loaded = {} + + for hypernetwork in shared.loaded_hypernetworks: + if hypernetwork.name in names: + already_loaded[hypernetwork.name] = hypernetwork - shared.loaded_hypernetwork = None + shared.loaded_hypernetworks.clear() + + for i, name in enumerate(names): + hypernetwork = already_loaded.get(name, None) + if hypernetwork is None: + hypernetwork = load_hypernetwork(name) + + if hypernetwork is None: + continue + + hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0) + shared.loaded_hypernetworks.append(hypernetwork) def find_closest_hypernetwork_name(search: str): @@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0] -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) +def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) if hypernetwork_layers is None: - return context, context + return context_k, context_v if layer is not None: layer.hyper_k = hypernetwork_layers[0] layer.hyper_v = hypernetwork_layers[1] - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) + context_k = hypernetwork_layers[0](context_k) + context_v = hypernetwork_layers[1](context_v) + return context_k, context_v + + +def apply_hypernetworks(hypernetworks, context, layer=None): + context_k = context + context_v = context + for hypernetwork in hypernetworks: + context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer) + return context_k, context_v @@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) + hypernetwork = Hypernetwork() + hypernetwork.load(path) + shared.loaded_hypernetworks = [hypernetwork] shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." @@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: images_dir = None - hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() initial_step = hypernetwork.step or 0 diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) + def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) @@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' @@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..b5deeacf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), - "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': - shared.reload_hypernetworks() # make onchange call for changing hypernet if k == 'sd_model_checkpoint': - sd_models.reload_model_weights() # make onchange call for changing SD model + sd_models.reload_model_weights() if k == 'sd_vae': - sd_vae.reload_vae_weights() # make onchange call for changing VAE + sd_vae.reload_vae_weights() res = process_images_inner(p) @@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() + + if k == 'sd_vae': + sd_vae.reload_vae_weights() return res @@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] + p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + extra_networks.activate(p, extra_network_data) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) @@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + extra_networks.deactivate(p, extra_network_data) devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) del context, context_k, context_v, x @@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) @@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) * self.scale v = self.to_v(context_v) del context, context_k, context_v, x @@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x @@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) diff --git a/modules/shared.py b/modules/shared.py index 2f366454..c0e11f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,6 +23,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file + parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} -loaded_hypernetwork = None +loaded_hypernetworks = [] def reload_hypernetworks(): @@ -153,8 +154,6 @@ def reload_hypernetworks(): global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - class State: @@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -661,3 +658,17 @@ mem_mon.start() def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None + self.filename = None def save(self, filename): embedding_data = { @@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] + embedding.filename = path if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/modules/ui.py b/modules/ui.py index 06c11848..d23b2b8e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): @@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) @@ -354,10 +357,10 @@ def create_toprow(is_img2img): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") @@ -395,11 +398,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id="style_create") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -616,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -794,14 +804,20 @@ def create_ui(): token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] @@ -1064,6 +1080,8 @@ def create_ui(): token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1666,10 +1684,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") @@ -1756,11 +1772,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button" +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + class FormRow(gr.Row, gr.components.FormComponent): """Same as gr.Row but fits inside gradio forms""" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..253e90f7 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,149 @@ +import os.path + +from modules import shared +import gradio as gr +import json + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + + extra_pages.append(page) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def create_html(self, tabname): + items_html = '' + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + res = "
    " + items_html + "
    " + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + args = { + "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "prompt": json.dumps(item["prompt"]), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + } + + return self.card_page.format(**args) + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = extra_pages.copy() + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) + button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..312dbaf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,34 @@ +import os + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..e4a6e3bf --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,32 @@ +import os + +from modules import ui_extra_networks, sd_hijack + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + preview_file = path + ".preview.png" + + preview = None + if os.path.isfile(preview_file): + preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": preview, + "prompt": embedding.name, + "local_preview": path + ".preview.png", + } + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/script.js b/script.js index 3345e32b..97e0bfcf 100644 --- a/script.js +++ b/script.js @@ -13,6 +13,7 @@ function get_uiCurrentTabContent() { } uiUpdateCallbacks = [] +uiLoadedCallbacks = [] uiTabChangeCallbacks = [] optionsChangedCallbacks = [] let uiCurrentTab = null @@ -20,6 +21,9 @@ let uiCurrentTab = null function onUiUpdate(callback){ uiUpdateCallbacks.push(callback) } +function onUiLoaded(callback){ + uiLoadedCallbacks.push(callback) +} function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } @@ -38,8 +42,15 @@ function executeCallbacks(queue, m) { queue.forEach(function(x){runCallback(x, m)}) } +var executedOnLoaded = false; + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ + if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { @@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() { /** * Add a ctrl+enter as a shortcut to start a generation */ - document.addEventListener('keydown', function(e) { +document.addEventListener('keydown', function(e) { var handled = false; if (e.key !== undefined) { if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6629f5d5..b1badec9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,6 @@ import modules.scripts as scripts import gradio as gr from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_hypernetwork(p, x, xs): - if x.lower() in ["", "none"]: - name = None - else: - name = hypernetwork.find_closest_hypernetwork_name(x) - if not name: - raise RuntimeError(f"Unknown hypernetwork: {x}") - hypernetwork.load_hypernetwork(name) - - -def apply_hypernetwork_strength(p, x, xs): - hypernetwork.apply_strength(x) - - -def confirm_hypernetworks(p, xs): - for x in xs: - if x.lower() in ["", "none"]: - continue - if not hypernetwork.find_closest_hypernetwork_name(x): - raise RuntimeError(f"Unknown hypernetwork: {x}") - - def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -208,8 +185,6 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), @@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.hypernetwork = opts.sd_hypernetwork self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object): modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() - hypernetwork.load_hypernetwork(self.hypernetwork) - hypernetwork.apply_strength() - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers diff --git a/style.css b/style.css index 3a515ebd..5e8bc2ca 100644 --- a/style.css +++ b/style.css @@ -132,13 +132,6 @@ } #roll_col > button { - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; margin: 0.1em 0; } @@ -146,9 +139,10 @@ min-width: 0 !important; max-width: 8em !important; margin-right: 1em; + gap: 0; } #interrogate, #deepbooru{ - margin: 0em 0.25em 0.9em 0.25em; + margin: 0em 0.25em 0.5em 0.25em; min-width: 8em; max-width: 8em; } @@ -157,8 +151,17 @@ min-width: 8em !important; } +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.5em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + #txt2img_styles, #img2img_styles{ - margin-top: 1em; + padding: 0; } #txt2img_styles ul, #img2img_styles ul{ @@ -635,17 +638,21 @@ canvas[key="mask"] { background-color: rgb(31 41 55 / var(--tw-bg-opacity)); } -.gr-button-tool{ +.gr-button-tool, .gr-button-tool-top{ max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 1.6em 0.7em 0.55em 0; } -#tab_modelmerger .gr-button-tool{ +.gr-button-tool{ margin: 0.6em 0em 0.55em 0; } +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + #modelmerger_results_container{ margin-top: 1em; overflow: visible; @@ -763,81 +770,88 @@ footer { line-height: 2.4em; } -/* The following handles localization for right-to-left (RTL) languages like Arabic. -The rtl media type will only be activated by the logic in javascript/localization.js. -If you change anything above, you need to make sure it is RTL compliant by just running -your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. -Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ -@media rtl { - /* this part was added manually */ - :host { - direction: rtl; - } - select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { - direction: ltr; - } - #script_list > label > select, - #x_type > label > select, - #y_type > label > select { - direction: rtl; - } - .gr-radio, .gr-checkbox{ - margin-left: 0.25em; - } +#txt2img_extra_networks, #img2img_extra_networks{ + margin-top: -1em; +} - /* automatically generated with few manual modifications */ - .performance .time { - margin-right: unset; - margin-left: 0; - } - .justify-center.overflow-x-scroll { - justify-content: right; - } - .justify-center.overflow-x-scroll button:first-of-type { - margin-left: unset; - margin-right: auto; - } - .justify-center.overflow-x-scroll button:last-of-type { - margin-right: unset; - margin-left: auto; - } - #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - margin-right: unset; - margin-left: 8em; - } - #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - right: unset; - left: 0; - } - .progressDiv .progress{ - padding: 0 0 0 8px; - text-align: left; - } - #lightboxModal{ - left: unset; - right: 0; - } - .modalPrev, .modalNext{ - border-radius: 3px 0 0 3px; - } - .modalNext { - right: unset; - left: 0; - border-radius: 0 3px 3px 0; - } - #imageARPreview{ - left:unset; - right:0px; - } - #txt2img_skip, #img2img_skip{ - right: unset; - left: 0px; - } - #context-menu{ - box-shadow:-1px 1px 2px #CE6400; - } - .gr-box > div > div > input.gr-text-input{ - right: unset; - left: 0.5em; - } +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; } + +.extra-network-cards .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li{ + margin-left: 0.5em; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card .actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + diff --git a/webui.py b/webui.py index 865a7300..e8dd822a 100644 --- a/webui.py +++ b/webui.py @@ -9,16 +9,18 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook, errors +from modules import import_hook, errors, extra_networks +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path import torch + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -84,10 +86,17 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) - shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -209,6 +218,15 @@ def webui(): modules.sd_models.list_models() + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:15:42 +0100 Subject: add data-dir flag and set all user data directories based on it --- modules/extensions.py | 2 +- modules/generation_parameters_copypaste.py | 4 ++-- modules/gfpgan_model.py | 5 ++--- modules/hashes.py | 4 +++- modules/interrogate.py | 2 +- modules/paths.py | 10 +++++++++- modules/processing.py | 3 ++- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 5 ++--- modules/shared.py | 11 ++++++----- modules/textual_inversion/preprocess.py | 5 ++--- modules/ui.py | 6 +++--- modules/ui_extensions.py | 2 +- modules/upscaler.py | 5 ++--- 14 files changed, 39 insertions(+), 31 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/extensions.py b/modules/extensions.py index b522125c..92ee8144 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,7 @@ import git from modules import paths, shared extensions = [] -extensions_dir = os.path.join(paths.script_path, "extensions") +extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..35f72808 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.shared import script_path +from modules.paths import data_path, script_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: - filename = os.path.join(script_path, "params.txt") + filename = os.path.join(data_path, "params.txt") if os.path.exists(filename): with open(filename, "r", encoding="utf8") as file: prompt = file.read() diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 1e2dbc32..fbe6215a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -6,12 +6,11 @@ import facexlib import gfpgan import modules.face_restoration -from modules import shared, devices, modelloader -from modules.paths import models_path +from modules import paths, shared, devices, modelloader model_dir = "GFPGAN" user_path = None -model_path = os.path.join(models_path, model_dir) +model_path = os.path.join(paths.models_path, model_dir) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None diff --git a/modules/hashes.py b/modules/hashes.py index b85a7580..819362a3 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -4,8 +4,10 @@ import os.path import filelock +from modules.paths import data_path -cache_filename = "cache.json" + +cache_filename = os.path.join(data_path, "cache.json") cache_data = None diff --git a/modules/interrogate.py b/modules/interrogate.py index c72ff694..cbb80683 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -12,7 +12,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram, modelloader, errors +from modules import devices, paths, shared, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' diff --git a/modules/paths.py b/modules/paths.py index 20b3e4d8..08e6f9b9 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -4,7 +4,15 @@ import sys import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -models_path = os.path.join(script_path, "models") + +# Parse the --data-dir flag first so we can use it as a base for our other argument default values +parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) +cmd_opts_pre = parser.parse_known_args()[0] +data_path = cmd_opts_pre.data_dir +models_path = os.path.join(data_path, "models") + +# data_path = cmd_opts_pre.data sys.path.insert(0, script_path) # search for directory of stable diffusion in following places diff --git a/modules/processing.py b/modules/processing.py index 262806a1..5072fc40 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared +import modules.paths as paths import modules.face_restoration import modules.images as images import modules.styles @@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) diff --git a/modules/sd_models.py b/modules/sd_models.py index 37dad18d..b2d48a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,13 +12,13 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} checkpoint_alisases = {} @@ -307,7 +307,7 @@ def enable_midas_autodownload(): location automatically. """ - midas_path = os.path.join(models_path, 'midas') + midas_path = os.path.join(paths.models_path, 'midas') # stable-diffusion-stability-ai hard-codes the midas model path to # a location that differs from where other scripts using this model look. diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 4ce238b8..9b00f76e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,13 +3,12 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks, sd_models -from modules.paths import models_path +from modules import paths, shared, devices, script_callbacks, sd_models import glob from copy import deepcopy -vae_path = os.path.abspath(os.path.join(models_path, "VAE")) +vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_dict = {} diff --git a/modules/shared.py b/modules/shared.py index 14be993d..474fcc42 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.styles import modules.devices as devices from modules import localization, extensions, script_loading, errors, ui_components, shared_items -from modules.paths import models_path, script_path +from modules.paths import models_path, script_path, data_path demo = None @@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") @@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") -parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") @@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) -parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index c0ac11d3..2239cb84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -6,8 +6,7 @@ import sys import tqdm import time -from modules import shared, images, deepbooru -from modules.paths import models_path +from modules import paths, shared, images, deepbooru from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop @@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) except Exception as e: print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..0117df3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -1497,8 +1497,8 @@ def create_ui(): with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 742e745e..66a41865 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url): normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.script_path, "tmp", dirname) + tmpdir = os.path.join(paths.data_path, "tmp", dirname) try: shutil.rmtree(tmpdir, True) diff --git a/modules/upscaler.py b/modules/upscaler.py index a5bf5acb..e2eaa730 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -11,7 +11,6 @@ from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) -from modules.paths import models_path class Upscaler: @@ -39,7 +38,7 @@ class Upscaler: self.mod_scale = None if self.model_path is None and self.name: - self.model_path = os.path.join(models_path, self.name) + self.model_path = os.path.join(shared.models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) @@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler): def __init__(self, dirname=None): super().__init__(False) self.name = "Nearest" - self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file + self.scalers = [UpscalerData("Nearest", None, self)] -- cgit v1.2.3 From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:00:09 +0100 Subject: clean up unused script_path imports --- modules/codeformer_model.py | 2 +- modules/generation_parameters_copypaste.py | 2 +- webui.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/generation_parameters_copypaste.py') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ab40d842..01fb7bd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,7 +8,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader -from modules.paths import script_path, models_path +from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 35f72808..773c5c0e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.paths import data_path, script_path +from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image diff --git a/webui.py b/webui.py index e1565a8d..41f32f5c 100644 --- a/webui.py +++ b/webui.py @@ -15,7 +15,6 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -from modules.paths import script_path import torch -- cgit v1.2.3