diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-04 19:05:40 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-04 19:05:50 +0000 |
commit | 45601766409e531d2b4ee512bf1433600f140183 (patch) | |
tree | 44fc1f32cd7d8a51242d75d76c7b300d69f6d8ae | |
parent | 31a9966b9d76cb9a2dd7c09c47e236fae33836e2 (diff) | |
download | stable-diffusion-webui-gfx803-45601766409e531d2b4ee512bf1433600f140183.tar.gz stable-diffusion-webui-gfx803-45601766409e531d2b4ee512bf1433600f140183.tar.bz2 stable-diffusion-webui-gfx803-45601766409e531d2b4ee512bf1433600f140183.zip |
added VAE selection to checkpoint user metadata
-rw-r--r-- | modules/extra_networks.py | 19 | ||||
-rw-r--r-- | modules/sd_vae.py | 13 | ||||
-rw-r--r-- | modules/ui_extra_networks.py | 13 | ||||
-rw-r--r-- | modules/ui_extra_networks_checkpoints.py | 3 | ||||
-rw-r--r-- | modules/ui_extra_networks_checkpoints_user_metadata.py | 60 |
5 files changed, 96 insertions, 12 deletions
diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91..fa28ac75 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,3 +1,5 @@ +import json
+import os
import re
from collections import defaultdict
@@ -177,3 +179,20 @@ def parse_prompts(prompts): return res, extra_data
+
+def get_user_metadata(filename):
+ if filename is None:
+ return {}
+
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ return metadata
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 84271db0..0bd5e19b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +16,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f2752f10..c6390db7 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse
from pathlib import Path
-from modules import shared, ui_extra_networks_user_metadata, errors
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
@@ -101,16 +101,7 @@ class ExtraNetworksPage: def read_user_metadata(self, item):
filename = item.get("filename", None)
- basename, ext = os.path.splitext(filename)
- metadata_filename = basename + '.json'
-
- metadata = {}
- try:
- if os.path.isfile(metadata_filename):
- with open(metadata_filename, "r", encoding="utf8") as file:
- metadata = json.load(file)
- except Exception as e:
- errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+ metadata = extra_networks.get_user_metadata(filename)
desc = metadata.get("description", None)
if desc is not None:
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 891d8f2c..77885022 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models
from modules.ui_extra_networks import quote_js
+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
+ def create_user_metadata_editor(self, ui, tabname):
+ return CheckpointUserMetadataEditor(ui, tabname, self)
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py new file mode 100644 index 00000000..2c69aab8 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -0,0 +1,60 @@ +import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+ def __init__(self, ui, tabname, page):
+ super().__init__(ui, tabname, page)
+
+ self.select_vae = None
+
+ def save_user_metadata(self, name, desc, notes, vae):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["notes"] = notes
+ user_metadata["vae"] = vae
+
+ self.write_user_metadata(name, user_metadata)
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+ values = super().put_values_into_components(name)
+
+ return [
+ *values[0:5],
+ user_metadata.get('vae', ''),
+ ]
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ with gr.Row():
+ self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+ create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ self.create_default_buttons()
+
+ viewed_components = [
+ self.edit_name,
+ self.edit_description,
+ self.html_filedata,
+ self.html_preview,
+ self.edit_notes,
+ self.select_vae,
+ ]
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ edited_components = [
+ self.edit_description,
+ self.edit_notes,
+ self.select_vae,
+ ]
+
+ self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|