aboutsummaryrefslogtreecommitdiffstats
path: root/modules/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py42
1 files changed, 24 insertions, 18 deletions
diff --git a/modules/ui.py b/modules/ui.py
index d51f7a08..4958036a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -872,29 +872,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
with gr.Row():
- ckpt_name_list = sorted([x.title for x in modules.sd_models.checkpoints_list.values()])
- primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
- submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+ save_as_half = gr.Checkbox(value=False, label="Safe as float16")
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- submit.click(
- fn=run_modelmerger,
- inputs=[
- primary_model_name,
- secondary_model_name,
- interp_method,
- interp_amount
- ],
- outputs=[
- submit_result,
- ]
- )
-
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -918,6 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
return comp(label=info.label, value=fun, **(args or {}))
components = []
+ component_dict = {}
def run_settings(*args):
changed = 0
@@ -973,7 +961,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- components.append(create_setting_component(k))
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
items_displayed += 1
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
@@ -1024,6 +1014,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[result, text_settings],
)
+ modelmerger_merge.click(
+ fn=run_modelmerger,
+ inputs=[
+ primary_model_name,
+ secondary_model_name,
+ interp_method,
+ interp_amount,
+ save_as_half,
+ ],
+ outputs=[
+ submit_result,
+ primary_model_name,
+ secondary_model_name,
+ component_dict['sd_model_checkpoint'],
+ ]
+ )
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]