From e5d3ae2bf4e9d39c35e6edc96d6449fd42528e55 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 20:39:04 +0300 Subject: user metadata system for custom networks --- extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index da49790b..29b16c1c 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): yield { "name": name, - "filename": path, + "filename": lora_on_disk.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), -- cgit v1.2.3 From 11f339733de860b0b51adebe15dc945df7189edf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 00:56:53 +0300 Subject: add lora user metadata editor dialog inspired by MrKuenning's mockup from #7458 --- extensions-builtin/Lora/ui_edit_user_metadata.py | 187 ++++++++++++++++++++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 17 +- javascript/extraNetworks.js | 18 ++- modules/ui_extra_networks_user_metadata.py | 23 ++- style.css | 9 +- 5 files changed, 241 insertions(+), 13 deletions(-) create mode 100644 extensions-builtin/Lora/ui_edit_user_metadata.py (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py new file mode 100644 index 00000000..c7dbd1c1 --- /dev/null +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -0,0 +1,187 @@ +import html +import json +import random + +import gradio as gr +import re + +from modules import ui_extra_networks_user_metadata + + +def is_non_comma_tagset(tags): + average_tag_length = sum(len(x) for x in tags.keys()) / len(tags) + + return average_tag_length >= 16 + + +re_word = re.compile(r"[-_\w']+") +re_comma = re.compile(r" *, *") + + +def build_tags(metadata): + tags = {} + + for _, tags_dict in metadata.get("ss_tag_frequency", {}).items(): + for tag, tag_count in tags_dict.items(): + tag = tag.strip() + tags[tag] = tags.get(tag, 0) + int(tag_count) + + if tags and is_non_comma_tagset(tags): + new_tags = {} + + for text, text_count in tags.items(): + for word in re.findall(re_word, text): + if len(word) < 3: + continue + + new_tags[word] = new_tags.get(word, 0) + text_count + + tags = new_tags + + ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True) + + return [(tag, tags[tag]) for tag in ordered_tags] + + +class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.taginfo = None + self.edit_activation_text = None + self.slider_preferred_weight = None + self.edit_notes = None + + def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["activation text"] = activation_text + user_metadata["preferred weight"] = preferred_weight + user_metadata["notes"] = notes + + self.write_user_metadata(name, user_metadata) + + def get_metadata_table(self, name): + table = super().get_metadata_table(name) + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + + keys = [ + ('ss_sd_model_name', "Model:"), + ('ss_resolution', "Resolution:"), + ('ss_clip_skip', "Clip skip:"), + ] + + for key, label in keys: + value = metadata.get(key, None) + if value is not None and str(value) != "None": + table.append((label, html.escape(value))) + + image_count = 0 + for _, params in metadata.get("ss_dataset_dirs", {}).items(): + image_count += int(params.get("img_count", 0)) + + if image_count: + table.append(("Dataset size:", image_count)) + + return table + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + + tags = build_tags(metadata) + gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] + + return [ + *values[0:4], + gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), + user_metadata.get('activation text', ''), + float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('notes', ''), + gr.update(visible=True if tags else False), + gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), + ] + + def generate_random_prompt(self, name): + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + tags = build_tags(metadata) + + return self.generate_random_prompt_from_tags(tags) + + def generate_random_prompt_from_tags(self, tags): + max_count = None + res = [] + for tag, count in tags: + if not max_count: + max_count = count + + v = random.random() * max_count + if count > v: + res.append(tag) + + return ", ".join(sorted(res)) + + def create_editor(self): + self.create_default_editor_elems() + + self.taginfo = gr.HighlightedText(label="Tags") + self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") + self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) + + with gr.Row() as row_random_prompt: + with gr.Column(scale=8): + random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) + + with gr.Column(scale=1, min_width=120): + generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg") + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt]) + + def select_tag(activation_text, evt: gr.SelectData): + tag = evt.value[0] + + words = re.split(re_comma, activation_text) + if tag in words: + words = [x for x in words if x != tag and x.strip()] + return ", ".join(words) + + return activation_text + ", " + tag if activation_text else tag + + self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + self.html_preview, + self.taginfo, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_notes, + row_random_prompt, + random_prompt, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_notes, + ] + + self.button_save\ + .click(fn=self.save_lora_user_metadata, inputs=[self.edit_name_input, *edited_components], outputs=[]) \ + .then(fn=None, _js="extraNetworksReloadAll") diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 29b16c1c..95296275 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -3,6 +3,7 @@ import os import lora from modules import shared, ui_extra_networks +from ui_edit_user_metadata import LoraUserMetadataEditor class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): @@ -18,19 +19,29 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): alias = lora_on_disk.get_alias() - yield { + item = { "name": name, "filename": lora_on_disk.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), - "prompt": json.dumps(f""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, - } + self.read_user_metadata(item) + activation_text = item["user_metadata"].get("activation text") + preferred_weight = item["user_metadata"].get("preferred weight", 0.0) + item["prompt"] = json.dumps(f"") + + if activation_text: + item["prompt"] += " + " + json.dumps(" " + activation_text) + + yield item + def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir] + def create_user_metadata_editor(self, ui, tabname): + return LoraUserMetadataEditor(ui, tabname, self) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 7007b353..8b67bf2b 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -113,7 +113,7 @@ function setupExtraNetworks() { onUiLoaded(setupExtraNetworks); -var re_extranet = /<([^:]+:[^:]+):[\d.]+>/; +var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text) { @@ -121,15 +121,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var replaced = false; var newTextareaText; if (m) { + var extraTextAfterNet = m[2]; var partToSearch = m[1]; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) { + var foundAtPosition = -1; + newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { m = found.match(re_extranet); if (m[1] == partToSearch) { replaced = true; + foundAtPosition = pos; return ""; } return found; }); + + if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + } } else { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { if (found == text) { @@ -288,3 +295,10 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { event.stopPropagation(); } + +function extraNetworksReloadAll() { + closePopup(); + + gradioApp().getElementById('txt2img_extra_refresh').click(); + gradioApp().getElementById('img2img_extra_refresh').click(); +} diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 8d20d026..0dbd7419 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -52,7 +52,7 @@ class UserMetadataEditor: def create_default_buttons(self): - with gr.Row(): + with gr.Row(elem_classes="edit-user-metadata-buttons"): self.button_cancel = gr.Button('Cancel') self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') @@ -88,8 +88,7 @@ class UserMetadataEditor: stats = os.stat(filename) params = [ ('File size: ', sysinfo.pretty_bytes(stats.st_size)), - ('Created: ', datetime.datetime.fromtimestamp(stats.st_ctime).strftime('%Y-%m-%d %H:%M')), - ('Last modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), + ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), ] return params @@ -100,7 +99,12 @@ class UserMetadataEditor: def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) - params = self.get_metadata_table(name) + try: + params = self.get_metadata_table(name) + except Exception as e: + errors.display(e, f"reading metadata info for {name}") + params = [] + table = '' + "".join(f"" for name, value in params) + '' return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) @@ -128,7 +132,9 @@ class UserMetadataEditor: .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save.click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[]).then(fn=None, _js="closePopup") + self.button_save\ + .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ + .then(fn=None, _js="extraNetworksReloadAll") def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -142,7 +148,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + [page.create_html(self.ui.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() item = self.page.items.get(name, {}) @@ -156,7 +162,10 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() + + def regenerate_ui_pages(self): + return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] def setup_ui(self, gallery): self.button_replace_preview.click( diff --git a/style.css b/style.css index 4431c1aa..af6344a8 100644 --- a/style.css +++ b/style.css @@ -1004,7 +1004,7 @@ footer { } div.block.gradio-box.edit-user-metadata { - min-width: 56em; + width: 56em; background: var(--body-background-fill); padding: 2em !important; } @@ -1021,3 +1021,10 @@ div.block.gradio-box.edit-user-metadata { .edit-user-metadata .wrap.translucent{ background: var(--body-background-fill); } +.edit-user-metadata .gradio-highlightedtext span{ + word-break: break-word; +} + +.edit-user-metadata-buttons{ + margin-top: 1.5em; +} -- cgit v1.2.3 From a1d6ada69ac686a628e79b61b8f86d01592a7209 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 08:38:23 +0300 Subject: allow refreshing single card after editing user metadata instead of all cards --- extensions-builtin/Lora/ui_edit_user_metadata.py | 4 +- extensions-builtin/Lora/ui_extra_networks_lora.py | 54 +++++++++++++---------- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 19 +++++--- modules/ui_extra_networks.py | 20 +++++++++ modules/ui_extra_networks_checkpoints.py | 31 +++++++------ modules/ui_extra_networks_hypernets.py | 31 +++++++------ modules/ui_extra_networks_textual_inversion.py | 30 +++++++------ modules/ui_extra_networks_user_metadata.py | 38 ++++++++++------ 9 files changed, 142 insertions(+), 87 deletions(-) (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7dbd1c1..2aa65223 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -182,6 +182,4 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_notes, ] - self.button_save\ - .click(fn=self.save_lora_user_metadata, inputs=[self.edit_name_input, *edited_components], outputs=[]) \ - .then(fn=None, _js="extraNetworksReloadAll") + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 95296275..80e741dc 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -13,31 +13,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def refresh(self): lora.list_available_loras() - def list_items(self): - for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()): - path, ext = os.path.splitext(lora_on_disk.filename) - - alias = lora_on_disk.get_alias() - - item = { - "name": name, - "filename": lora_on_disk.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename), - "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, - "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, - } - - self.read_user_metadata(item) - activation_text = item["user_metadata"].get("activation text") - preferred_weight = item["user_metadata"].get("preferred weight", 0.0) - item["prompt"] = json.dumps(f"") - - if activation_text: - item["prompt"] += " + " + json.dumps(" " + activation_text) + def create_item(self, name, index=None): + lora_on_disk = lora.available_loras.get(name) + + path, ext = os.path.splitext(lora_on_disk.filename) + + alias = lora_on_disk.get_alias() + + item = { + "name": name, + "filename": lora_on_disk.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(lora_on_disk.filename), + "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + } + self.read_user_metadata(item) + activation_text = item["user_metadata"].get("activation text") + preferred_weight = item["user_metadata"].get("preferred weight", 0.0) + item["prompt"] = json.dumps(f"") + + if activation_text: + item["prompt"] += " + " + json.dumps(" " + activation_text) + + return item + + def list_items(self): + for index, name in enumerate(lora.available_loras): + item = self.create_item(name, index) yield item def allowed_directories_for_previews(self): diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index fb787ffe..eb8b1a67 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,4 @@ -
+
{background_image}
{edit_button} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 8b67bf2b..e453094a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -296,9 +296,18 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { event.stopPropagation(); } -function extraNetworksReloadAll() { - closePopup(); - - gradioApp().getElementById('txt2img_extra_refresh').click(); - gradioApp().getElementById('img2img_extra_refresh').click(); +function extraNetworksRefreshSingleCard(page, tabname, name) { + requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { + if (data && data.html) { + var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function + + var newDiv = document.createElement('DIV'); + newDiv.innerHTML = data.html; + var newCard = newDiv.firstElementChild; + + newCard.style = ''; + card.parentElement.insertBefore(newCard, card); + card.parentElement.removeChild(card); + } + }); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index eaae6217..42c4d0ac 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -51,9 +51,26 @@ def get_metadata(page: str = "", item: str = ""): return JSONResponse({"metadata": metadata}) +def get_single_card(page: str = "", tabname: str = "", name: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + + try: + item = page.create_item(name) + except Exception as e: + errors.display(e, "creating item for extra network") + item = page.items.get(name) + + item_html = page.create_html_for_item(item, tabname) + + return JSONResponse({"html": item_html}) + + def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) + app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) class ExtraNetworksPage: @@ -168,6 +185,9 @@ class ExtraNetworksPage: return res + def create_item(self, name, index=None): + raise NotImplementedError() + def list_items(self): raise NotImplementedError() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index bb5071e6..ef8cdf35 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -12,21 +12,24 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.refresh_checkpoints() + def create_item(self, name, index=None): + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name) + path, ext = os.path.splitext(checkpoint.filename) + return { + "name": checkpoint.name_for_extra, + "filename": checkpoint.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": f"{path}.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + + } + def list_items(self): - checkpoint: sd_models.CheckpointInfo - for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()): - path, ext = os.path.splitext(checkpoint.filename) - yield { - "name": checkpoint.name_for_extra, - "filename": checkpoint.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', - "local_preview": f"{path}.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - - } + for index, name in enumerate(sd_models.checkpoints_list): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index ea0b7a44..8dae23c6 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.reload_hypernetworks() + def create_item(self, name, index=None): + full_path = shared.hypernetworks[name] + path, ext = os.path.splitext(full_path) + + return { + "name": name, + "filename": full_path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(path), + "prompt": json.dumps(f""), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, + } + def list_items(self): - for index, (name, full_path) in enumerate(shared.hypernetworks.items()): - path, ext = os.path.splitext(full_path) - - yield { - "name": name, - "filename": full_path, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, - - } + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 58a61c55..159f2d64 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def refresh(self): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + def create_item(self, name, index=None): + embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + + path, ext = os.path.splitext(embedding.filename) + return { + "name": name, + "filename": embedding.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": json.dumps(embedding.name), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + } + def list_items(self): - for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()): - path, ext = os.path.splitext(embedding.filename) - yield { - "name": embedding.name, - "filename": embedding.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, - - } + for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 0dbd7419..01ff4e4b 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -23,8 +23,10 @@ class UserMetadataEditor: self.edit_name = None self.edit_description = None + self.edit_notes = None self.html_filedata = None self.html_preview = None + self.html_status = None self.button_cancel = None self.button_replace_preview = None @@ -57,6 +59,8 @@ class UserMetadataEditor: self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') + self.html_status = gr.HTML(elem_classes="edit-user-metadata-status") + self.button_cancel.click(fn=None, _js="closePopup") def get_card_html(self, name): @@ -107,7 +111,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) @@ -117,24 +121,30 @@ class UserMetadataEditor: with open(basename + '.json', "w", encoding="utf8") as file: json.dump(metadata, file) - def save_user_metadata(self, name, desc): + def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) + def setup_save_handler(self, button, func, components): + button\ + .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\ + .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[]) + def create_editor(self): self.create_default_editor_elems() + self.edit_notes = gr.TextArea(label='Notes', lines=4) + self.create_default_buttons() self.button_edit\ - .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save\ - .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ - .then(fn=None, _js="extraNetworksReloadAll") + self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes]) def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -147,8 +157,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: - print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + self.regenerate_ui_pages() + return self.get_card_html(name), "There is no image in gallery to save as a preview." item = self.page.items.get(name, {}) @@ -162,17 +171,20 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + self.regenerate_ui_pages() - - def regenerate_ui_pages(self): - return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return self.get_card_html(name), '' def setup_ui(self, gallery): self.button_replace_preview.click( fn=self.save_preview, _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[self.edit_name_input, gallery, self.edit_name_input], - outputs=[self.html_preview, *self.ui.pages] + outputs=[self.html_preview, self.html_status] + ).then( + fn=None, + _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", + inputs=[self.edit_name_input], + outputs=[] ) + -- cgit v1.2.3 From 47d9dd0240872dc70fd26bc1bf309f49fe17c104 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 09:25:32 +0300 Subject: speedup extra networks listing --- extensions-builtin/Lora/lora.py | 12 +++++++--- extensions-builtin/Lora/ui_edit_user_metadata.py | 9 ++++---- extensions-builtin/Lora/ui_extra_networks_lora.py | 9 ++++---- modules/cache.py | 27 ++++++++++++----------- modules/ui_extra_networks.py | 20 +++++++++++------ modules/ui_extra_networks_checkpoints.py | 4 ++-- modules/ui_extra_networks_hypernets.py | 4 ++-- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 8 files changed, 51 insertions(+), 38 deletions(-) (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cd46e6c7..c8710922 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,7 @@ import re import torch from typing import Union -from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes +from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} @@ -78,9 +78,16 @@ class LoraOnDisk: self.metadata = {} self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text + + return metadata + if self.is_safetensors: try: - self.metadata = sd_models.read_metadata_from_safetensors(filename) + #self.metadata = sd_models.read_metadata_from_safetensors(filename) + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) except Exception as e: errors.display(e, f"reading lora {filename}") @@ -91,7 +98,6 @@ class LoraOnDisk: self.metadata = m - self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text self.alias = self.metadata.get('ss_output_name', self.name) self.hash = None diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2aa65223..6db63b09 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -1,5 +1,4 @@ import html -import json import random import gradio as gr @@ -64,7 +63,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def get_metadata_table(self, name): table = super().get_metadata_table(name) item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} keys = [ ('ss_sd_model_name', "Model:"), @@ -91,7 +90,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) values = super().put_values_into_components(name) item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} tags = build_tags(metadata) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] @@ -108,7 +107,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def generate_random_prompt(self, name): item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} tags = build_tags(metadata) return self.generate_random_prompt_from_tags(tags) @@ -142,7 +141,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_notes = gr.TextArea(label='Notes', lines=4) - generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt]) + generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False) def select_tag(activation_text, evt: gr.SelectData): tag = evt.value[0] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 80e741dc..b2bc1810 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,8 +1,8 @@ -import json import os import lora from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -20,6 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): alias = lora_on_disk.get_alias() + # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string item = { "name": name, "filename": lora_on_disk.filename, @@ -27,17 +28,17 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, } self.read_user_metadata(item) activation_text = item["user_metadata"].get("activation text") preferred_weight = item["user_metadata"].get("preferred weight", 0.0) - item["prompt"] = json.dumps(f"") + item["prompt"] = quote_js(f"") if activation_text: - item["prompt"] += " + " + json.dumps(" " + activation_text) + item["prompt"] += " + " + quote_js(" " + activation_text) return item diff --git a/modules/cache.py b/modules/cache.py index 4c2db604..07180602 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,12 +1,12 @@ import json import os.path - -import filelock +import threading from modules.paths import data_path, script_path cache_filename = os.path.join(data_path, "cache.json") cache_data = None +cache_lock = threading.Lock() def dump_cache(): @@ -14,7 +14,7 @@ def dump_cache(): Saves all cache data to a file. """ - with filelock.FileLock(f"{cache_filename}.lock"): + with cache_lock: with open(cache_filename, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) @@ -33,17 +33,18 @@ def cache(subsection): global cache_data if cache_data is None: - with filelock.FileLock(f"{cache_filename}.lock"): - if not os.path.isfile(cache_filename): - cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + with cache_lock: + if cache_data is None: + if not os.path.isfile(cache_filename): cache_data = {} + else: + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 42c4d0ac..f9d1fa31 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,6 +73,12 @@ def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) +def quote_js(s): + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + return f'"{s}"' + + class ExtraNetworksPage: def __init__(self, title): self.title = title @@ -203,7 +209,7 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' @@ -211,9 +217,9 @@ class ExtraNetworksPage: metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" - edit_button = f"
" + edit_button = f"
" local_path = "" filename = item.get("filename", "") @@ -239,12 +245,12 @@ class ExtraNetworksPage: "background_image": background_image, "style": f"'display: none; {height}{width}'", "prompt": item.get("prompt", None), - "tabname": json.dumps(tabname), - "local_preview": json.dumps(item["local_preview"]), + "tabname": quote_js(tabname), + "local_preview": quote_js(item["local_preview"]), "name": item["name"], "description": (item.get("description") or ""), "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, "edit_button": edit_button, @@ -359,7 +365,7 @@ def create_ui(container, button, tabname): page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ef8cdf35..e73b5b1f 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,8 +1,8 @@ import html -import json import os from modules import shared, ui_extra_networks, sd_models +from modules.ui_extra_networks import quote_js class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8dae23c6..e53ccb42 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,7 +1,7 @@ -import json import os from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), + "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 159f2d64..d1794e50 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ -import json import os from modules import ui_extra_networks, sd_hijack, shared +from modules.ui_extra_networks import quote_js class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): @@ -22,7 +22,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), + "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, } -- cgit v1.2.3 From b75b004fe62826455f1aa77e849e7da13902cb17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 23:13:55 +0300 Subject: lora extension rework to include other types of networks --- extensions-builtin/Lora/extra_networks_lora.py | 18 +- extensions-builtin/Lora/lora.py | 537 ---------------------- extensions-builtin/Lora/lyco_helpers.py | 15 + extensions-builtin/Lora/network.py | 98 ++++ extensions-builtin/Lora/network_hada.py | 59 +++ extensions-builtin/Lora/network_lora.py | 70 +++ extensions-builtin/Lora/network_lyco.py | 39 ++ extensions-builtin/Lora/networks.py | 443 ++++++++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 79 ++-- extensions-builtin/Lora/ui_extra_networks_lora.py | 8 +- 10 files changed, 777 insertions(+), 589 deletions(-) delete mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/lyco_helpers.py create mode 100644 extensions-builtin/Lora/network.py create mode 100644 extensions-builtin/Lora/network_hada.py create mode 100644 extensions-builtin/Lora/network_lora.py create mode 100644 extensions-builtin/Lora/network_lyco.py create mode 100644 extensions-builtin/Lora/networks.py (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 66ee9c85..8a6639cf 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,5 +1,5 @@ from modules import extra_networks, shared -import lora +import networks class ExtraNetworkLora(extra_networks.ExtraNetwork): @@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): + if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) - lora.load_loras(names, multipliers) + networks.load_networks(names, multipliers) if shared.opts.lora_add_hashes_to_infotext: - lora_hashes = [] - for item in lora.loaded_loras: - shorthash = item.lora_on_disk.shorthash + network_hashes = [] + for item in networks.loaded_networks: + shorthash = item.network_on_disk.shorthash if not shorthash: continue @@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): alias = alias.replace(":", "").replace(",", "") - lora_hashes.append(f"{alias}: {shorthash}") + network_hashes.append(f"{alias}: {shorthash}") - if lora_hashes: - p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes) + if network_hashes: + p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) def deactivate(self, p): pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py deleted file mode 100644 index 9cdff6ed..00000000 --- a/extensions-builtin/Lora/lora.py +++ /dev/null @@ -1,537 +0,0 @@ -import os -import re -import torch -from typing import Union - -from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache - -metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} - -re_digits = re.compile(r"\d+") -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") -re_compiled = {} - -suffix_conversion = { - "attentions": {}, - "resnets": { - "conv1": "in_layers_2", - "conv2": "out_layers_3", - "time_emb_proj": "emb_layers_1", - "conv_shortcut": "skip_connection", - } -} - - -def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex_text): - regex = re_compiled.get(regex_text) - if regex is None: - regex = re.compile(regex_text) - re_compiled[regex_text] = regex - - r = re.match(regex, key) - if not r: - return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True - - m = [] - - if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - - if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): - return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - - if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): - return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" - - if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): - if is_sd2: - if 'mlp_fc1' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" - - if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): - if 'mlp_fc1' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return key - - -class LoraOnDisk: - def __init__(self, name, filename): - self.name = name - self.filename = filename - self.metadata = {} - self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - - def read_metadata(): - metadata = sd_models.read_metadata_from_safetensors(filename) - metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text - - return metadata - - if self.is_safetensors: - try: - self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) - except Exception as e: - errors.display(e, f"reading lora {filename}") - - if self.metadata: - m = {} - for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): - m[k] = v - - self.metadata = m - - self.alias = self.metadata.get('ss_output_name', self.name) - - self.hash = None - self.shorthash = None - self.set_hash( - self.metadata.get('sshs_model_hash') or - hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or - '' - ) - - def set_hash(self, v): - self.hash = v - self.shorthash = self.hash[0:12] - - if self.shorthash: - available_lora_hash_lookup[self.shorthash] = self - - def read_hash(self): - if not self.hash: - self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') - - def get_alias(self): - if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases: - return self.name - else: - return self.alias - - -class LoraModule: - def __init__(self, name, lora_on_disk: LoraOnDisk): - self.name = name - self.lora_on_disk = lora_on_disk - self.multiplier = 1.0 - self.modules = {} - self.mtime = None - - self.mentioned_name = None - """the text that was used to add lora to prompt - can be either name or an alias""" - - -class LoraUpDownModule: - def __init__(self): - self.up = None - self.down = None - self.alpha = None - - -def assign_lora_names_to_compvis_modules(sd_model): - lora_layer_mapping = {} - - if shared.sd_model.is_sdxl: - for i, embedder in enumerate(shared.sd_model.conditioner.embedders): - if not hasattr(embedder, 'wrapped'): - continue - - for name, module in embedder.wrapped.named_modules(): - lora_name = f'{i}_{name.replace(".", "_")}' - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - else: - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - for name, module in shared.sd_model.model.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - sd_model.lora_layer_mapping = lora_layer_mapping - - -def load_lora(name, lora_on_disk): - lora = LoraModule(name, lora_on_disk) - lora.mtime = os.path.getmtime(lora_on_disk.filename) - - sd = sd_models.read_state_dict(lora_on_disk.filename) - - # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 - if not hasattr(shared.sd_model, 'lora_layer_mapping'): - assign_lora_names_to_compvis_modules(shared.sd_model) - - keys_failed_to_match = {} - is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping - - for key_lora, weight in sd.items(): - key_lora_without_lora_parts, lora_key = key_lora.split(".", 1) - - key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2) - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - - if sd_module is None: - m = re_x_proj.match(key) - if m: - sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) - - # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" - if sd_module is None and "lora_unet" in key_lora_without_lora_parts: - key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model") - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts: - key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model") - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - - if sd_module is None: - keys_failed_to_match[key_lora] = key - continue - - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - - if lora_key == "alpha": - lora_module.alpha = weight.item() - continue - - if type(sd_module) == torch.nn.Linear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.MultiheadAttention: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) - else: - print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}') - continue - raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}") - - with torch.no_grad(): - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - - if lora_key == "lora_up.weight": - lora_module.up = module - elif lora_key == "lora_down.weight": - lora_module.down = module - else: - raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha") - - if keys_failed_to_match: - print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") - - return lora - - -def load_loras(names, multipliers=None): - already_loaded = {} - - for lora in loaded_loras: - if lora.name in names: - already_loaded[lora.name] = lora - - loaded_loras.clear() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - if any(x is None for x in loras_on_disk): - list_available_loras() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - - failed_to_load_loras = [] - - for i, name in enumerate(names): - lora = already_loaded.get(name, None) - - lora_on_disk = loras_on_disk[i] - - if lora_on_disk is not None: - if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: - try: - lora = load_lora(name, lora_on_disk) - except Exception as e: - errors.display(e, f"loading Lora {lora_on_disk.filename}") - continue - - lora.mentioned_name = name - - lora_on_disk.read_hash() - - if lora is None: - failed_to_load_loras.append(name) - print(f"Couldn't find Lora with name {name}") - continue - - lora.multiplier = multipliers[i] if multipliers else 1.0 - loaded_loras.append(lora) - - if failed_to_load_loras: - sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) - - -def lora_calc_updown(lora, module, target): - with torch.no_grad(): - up = module.up.weight.to(target.device, dtype=target.dtype) - down = module.down.weight.to(target.device, dtype=target.dtype) - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return updown - - -def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - weights_backup = getattr(self, "lora_weights_backup", None) - - if weights_backup is None: - return - - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) - - -def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - """ - Applies the currently selected set of Loras to the weights of torch layer self. - If weights already have this particular set of loras applied, does nothing. - If not, restores orginal weights from backup and alters weights according to loras. - """ - - lora_layer_name = getattr(self, 'lora_layer_name', None) - if lora_layer_name is None: - return - - current_names = getattr(self, "lora_current_names", ()) - wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) - - weights_backup = getattr(self, "lora_weights_backup", None) - if weights_backup is None: - if isinstance(self, torch.nn.MultiheadAttention): - weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) - else: - weights_backup = self.weight.to(devices.cpu, copy=True) - - self.lora_weights_backup = weights_backup - - if current_names != wanted_names: - lora_restore_weights_from_backup(self) - - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None and hasattr(self, 'weight'): - self.weight += lora_calc_updown(lora, module, self.weight) - continue - - module_q = lora.modules.get(lora_layer_name + "_q_proj", None) - module_k = lora.modules.get(lora_layer_name + "_k_proj", None) - module_v = lora.modules.get(lora_layer_name + "_v_proj", None) - module_out = lora.modules.get(lora_layer_name + "_out_proj", None) - - if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: - updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) - updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) - updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) - updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - - self.in_proj_weight += updown_qkv - self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) - continue - - if module is None: - continue - - print(f'failed to calculate lora weights for layer {lora_layer_name}') - - self.lora_current_names = wanted_names - - -def lora_forward(module, input, original_forward): - """ - Old way of applying Lora by executing operations during layer's forward. - Stacking many loras this way results in big performance degradation. - """ - - if len(loaded_loras) == 0: - return original_forward(module, input) - - input = devices.cond_cast_unet(input) - - lora_restore_weights_from_backup(module) - lora_reset_cached_weight(module) - - res = original_forward(module, input) - - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is None: - continue - - module.up.to(device=devices.device) - module.down.to(device=devices.device) - - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return res - - -def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): - self.lora_current_names = () - self.lora_weights_backup = None - - -def lora_Linear_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Linear_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Linear_forward_before_lora(self, input) - - -def lora_Linear_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_Conv2d_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Conv2d_forward_before_lora(self, input) - - -def lora_Conv2d_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_forward(self, *args, **kwargs): - lora_apply_weights(self) - - return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) - - -def list_available_loras(): - available_loras.clear() - available_lora_aliases.clear() - forbidden_lora_aliases.clear() - available_lora_hash_lookup.clear() - forbidden_lora_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - - candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) - for filename in candidates: - if os.path.isdir(filename): - continue - - name = os.path.splitext(os.path.basename(filename))[0] - try: - entry = LoraOnDisk(name, filename) - except OSError: # should catch FileNotFoundError and PermissionError etc. - errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True) - continue - - available_loras[name] = entry - - if entry.alias in available_lora_aliases: - forbidden_lora_aliases[entry.alias.lower()] = 1 - - available_lora_aliases[name] = entry - available_lora_aliases[entry.alias] = entry - - -re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") - - -def infotext_pasted(infotext, params): - if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: - return # if the other extension is active, it will handle those fields, no need to do anything - - added = [] - - for k in params: - if not k.startswith("AddNet Model "): - continue - - num = k[13:] - - if params.get("AddNet Module " + num) != "LoRA": - continue - - name = params.get("AddNet Model " + num) - if name is None: - continue - - m = re_lora_name.match(name) - if m: - name = m.group(1) - - multiplier = params.get("AddNet Weight A " + num, "1.0") - - added.append(f"") - - if added: - params["Prompt"] += "\n" + "".join(added) - - -available_loras = {} -available_lora_aliases = {} -available_lora_hash_lookup = {} -forbidden_lora_aliases = {} -loaded_loras = [] - -list_available_loras() diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py new file mode 100644 index 00000000..9ea499fb --- /dev/null +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -0,0 +1,15 @@ +import torch + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +def rebuild_conventional(up, down, shape, dyn_dim=None): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + if dyn_dim is not None: + up = up[:, :dyn_dim] + down = down[:dyn_dim, :] + return (up @ down).reshape(shape) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py new file mode 100644 index 00000000..a1fe6bbf --- /dev/null +++ b/extensions-builtin/Lora/network.py @@ -0,0 +1,98 @@ +import os +from collections import namedtuple + +import torch + +from modules import devices, sd_models, cache, errors, hashes, shared + +NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +class NetworkOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text + + return metadata + + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.alias = self.metadata.get('ss_output_name', self.name) + + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + import networks + networks.available_network_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import networks + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: + return self.name + else: + return self.alias + + +class Network: # LoraModule + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + self.mentioned_name = None + """the text that was used to add the network to prompt - can be either name or an alias""" + + +class ModuleType: + def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: + return None + + +class NetworkModule: + def __init__(self, net: Network, weights: NetworkWeights): + self.network = net + self.network_key = weights.network_key + self.sd_key = weights.sd_key + self.sd_module = weights.sd_module + + def calc_updown(self, target): + raise NotImplementedError() + + def forward(self, x, y): + raise NotImplementedError() + diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py new file mode 100644 index 00000000..15e7ffd8 --- /dev/null +++ b/extensions-builtin/Lora/network_hada.py @@ -0,0 +1,59 @@ +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeHada(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): + return NetworkModuleHada(net, weights) + + return None + + +class NetworkModuleHada(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["hada_w1_a"] + self.w1b = weights.w["hada_w1_b"] + self.dim = self.w1b.shape[0] + self.w2a = weights.w["hada_w2_a"] + self.w2b = weights.w["hada_w2_b"] + + self.t1 = weights.w.get("hada_t1") + self.t2 = weights.w.get("hada_t2") + + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + + if self.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) + + if self.t2 is not None: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + else: + updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) + + updown = updown1 * updown2 + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py new file mode 100644 index 00000000..b2d96537 --- /dev/null +++ b/extensions-builtin/Lora/network_lora.py @@ -0,0 +1,70 @@ +import torch + +import network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + + return None + + +class NetworkModuleLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.up = self.create_module(weights.w["lora_up.weight"]) + self.down = self.create_module(weights.w["lora_down.weight"]) + self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + + def create_module(self, weight, none_ok=False): + if weight is None and none_ok: + return None + + if type(self.sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) + else: + print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + return None + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + + def calc_updown(self, target): + up = self.up.weight.to(target.device, dtype=target.dtype) + down = self.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + return updown + + def forward(self, x, y): + self.up.to(device=devices.device) + self.down.to(device=devices.device) + + return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py new file mode 100644 index 00000000..18a822fa --- /dev/null +++ b/extensions-builtin/Lora/network_lyco.py @@ -0,0 +1,39 @@ +import torch + +import lyco_helpers +import network +from modules import devices + + +class NetworkModuleLyco(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + scale = ( + self.scale if self.scale is not None + else self.alpha / self.dim if self.dim is not None and self.alpha is not None + else 1.0 + ) + + return updown * scale * self.network.multiplier + diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 00000000..5b0ddfb6 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,443 @@ +import os +import re + +import network +import network_lora +import network_hada + +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors, scripts, sd_hijack + +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), +] + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + + +def assign_network_names_to_compvis_modules(sd_model): + network_layer_mapping = {} + + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + network_name = f'{i}_{name.replace(".", "_")}' + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + for name, module in shared.sd_model.model.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + sd_model.network_layer_mapping = network_layer_mapping + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + sd = sd_models.read_state_dict(network_on_disk.filename) + + # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + if not hasattr(shared.sd_model, 'network_layer_mapping'): + assign_network_names_to_compvis_modules(shared.sd_model) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping + + matched_networks = {} + + for key_network, weight in sd.items(): + key_network_without_network_parts, network_part = key_network.split(".", 1) + + key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None) + + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + + matched_networks[key].w[network_part] = weight + + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + break + + if net_module is None: + raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") + + net.modules[key] = net_module + + if keys_failed_to_match: + print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + + return net + + +def load_networks(names, multipliers=None): + already_loaded = {} + + for net in loaded_networks: + if net.name in names: + already_loaded[net.name] = net + + loaded_networks.clear() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + + failed_to_load_networks = [] + + for i, name in enumerate(names): + net = already_loaded.get(name, None) + + network_on_disk = networks_on_disk[i] + + if network_on_disk is not None: + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + + net.mentioned_name = name + + network_on_disk.read_hash() + + if net is None: + failed_to_load_networks.append(name) + print(f"Couldn't find network with name {name}") + continue + + net.multiplier = multipliers[i] if multipliers else 1.0 + loaded_networks.append(net) + + if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "network_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores orginal weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.network_weights_backup = weights_backup + + if current_names != wanted_names: + network_restore_weights_from_backup(self) + + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is not None and hasattr(self, 'weight'): + with torch.no_grad(): + updown = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + + module_q = net.modules.get(network_layer_name + "_q_proj", None) + module_k = net.modules.get(network_layer_name + "_k_proj", None) + module_v = net.modules.get(network_layer_name + "_v_proj", None) + module_out = net.modules.get(network_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + with torch.no_grad(): + updown_q = module_q.calc_updown(self.in_proj_weight) + updown_k = module_k.calc_updown(self.in_proj_weight) + updown_v = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate network weights for layer {network_layer_name}') + + self.network_current_names = wanted_names + + +def network_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_networks) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + network_restore_weights_from_backup(module) + network_reset_cached_weight(module) + + y = original_forward(module, input) + + network_layer_name = getattr(module, 'network_layer_name', None) + for lora in loaded_networks: + module = lora.modules.get(network_layer_name, None) + if module is None: + continue + + y = module.forward(y, input) + + return y + + +def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + self.network_current_names = () + self.network_weights_backup = None + + +def network_Linear_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Linear_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Linear_forward_before_network(self, input) + + +def network_Linear_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Conv2d_forward_before_network(self, input) + + +def network_Conv2d_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_forward(self, *args, **kwargs): + network_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index e650f469..81e6572a 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,18 +4,19 @@ import torch import gradio as gr from fastapi import FastAPI -import lora +import network +import networks import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora + torch.nn.Linear.forward = torch.nn.Linear_forward_before_network + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network def before_ui(): @@ -23,50 +24,50 @@ def before_ui(): extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) -if not hasattr(torch.nn, 'Linear_forward_before_lora'): - torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_forward_before_network'): + torch.nn.Linear_forward_before_network = torch.nn.Linear.forward -if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): - torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict +if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): + torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict -if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): - torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_forward_before_network'): + torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): - torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): + torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): - torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): + torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): - torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): + torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict -torch.nn.Linear.forward = lora.lora_Linear_forward -torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict -torch.nn.Conv2d.forward = lora.lora_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict +torch.nn.Linear.forward = networks.network_Linear_forward +torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict +torch.nn.Conv2d.forward = networks.network_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict -script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) script_callbacks.on_before_ui(before_ui) -script_callbacks.on_infotext_pasted(lora.infotext_pasted) +script_callbacks.on_infotext_pasted(networks.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { - "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras), + "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), })) shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { - "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), + "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), })) -def create_lora_json(obj: lora.LoraOnDisk): +def create_lora_json(obj: network.NetworkOnDisk): return { "name": obj.name, "alias": obj.alias, @@ -75,17 +76,17 @@ def create_lora_json(obj: lora.LoraOnDisk): } -def api_loras(_: gr.Blocks, app: FastAPI): +def api_networks(_: gr.Blocks, app: FastAPI): @app.get("/sdapi/v1/loras") async def get_loras(): - return [create_lora_json(obj) for obj in lora.available_loras.values()] + return [create_lora_json(obj) for obj in networks.available_networks.values()] @app.post("/sdapi/v1/refresh-loras") async def refresh_loras(): - return lora.list_available_loras() + return networks.list_available_networks() -script_callbacks.on_app_started(api_loras) +script_callbacks.on_app_started(api_networks) re_lora = re.compile(" Date: Mon, 17 Jul 2023 18:56:14 +0300 Subject: hide cards for networks of incompatible stable diffusion version in Lora extra networks interface --- extensions-builtin/Lora/network.py | 20 +++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 2 ++ extensions-builtin/Lora/ui_edit_user_metadata.py | 20 +++++++++---- extensions-builtin/Lora/ui_extra_networks_lora.py | 34 +++++++++++++++++++---- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 2 +- modules/sd_models.py | 3 ++ modules/ui_extra_networks.py | 3 +- modules/ui_extra_networks_user_metadata.py | 7 ++++- style.css | 6 +++- 10 files changed, 84 insertions(+), 15 deletions(-) (limited to 'extensions-builtin/Lora/ui_extra_networks_lora.py') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index fe42dbdd..8ecfa29a 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -1,5 +1,6 @@ import os from collections import namedtuple +import enum from modules import sd_models, cache, errors, hashes, shared @@ -8,6 +9,13 @@ NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 + + class NetworkOnDisk: def __init__(self, name, filename): self.name = name @@ -44,6 +52,18 @@ class NetworkOnDisk: '' ) + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', "")) == "True": + return SdVersion.SD2 + elif len(self.metadata): + return SdVersion.SD1 + + return SdVersion.Unknown + def set_hash(self, v): self.hash = v self.shorthash = self.hash[0:12] diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index f478f718..cd28afc9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -63,6 +63,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), + "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), + "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), })) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 354a1d68..c8730443 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -46,14 +46,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) + self.select_sd_version = None + self.taginfo = None self.edit_activation_text = None self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["notes"] = notes @@ -112,11 +115,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] return [ - *values[0:4], + *values[0:5], + item.get("sd_version", "Unknown"), gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), - user_metadata.get('notes', ''), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -141,10 +144,15 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) return ", ".join(sorted(res)) + def create_extra_default_items_in_left_column(self): + + # this would be a lot better as gr.Radio but I can't make it work + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + def create_editor(self): self.create_default_editor_elems() - self.taginfo = gr.HighlightedText(label="Tags") + self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) @@ -178,10 +186,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_description, self.html_filedata, self.html_preview, + self.edit_notes, + self.select_sd_version, self.taginfo, self.edit_activation_text, self.slider_preferred_weight, - self.edit_notes, row_random_prompt, random_prompt, ] @@ -192,6 +201,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) edited_components = [ self.edit_description, + self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, self.edit_notes, diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index b6171a26..4b32098b 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,7 +1,9 @@ import os + +import network import networks -from modules import shared, ui_extra_networks +from modules import shared, ui_extra_networks, paths from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -13,14 +15,13 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def refresh(self): networks.list_available_networks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) path, ext = os.path.splitext(lora_on_disk.filename) alias = lora_on_disk.get_alias() - # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string item = { "name": name, "filename": lora_on_disk.filename, @@ -30,6 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + "sd_version": lora_on_disk.sd_version.name, } self.read_user_metadata(item) @@ -40,15 +42,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: + sd_version = lora_on_disk.sd_version + + if shared.opts.lora_show_all or not enable_filter: + pass + elif sd_version == network.SdVersion.Unknown: + model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 + if model_version.name in shared.opts.lora_hide_unknown_for_versions: + return None + elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL: + return None + elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2: + return None + elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1: + return None + return item def list_items(self): for index, name in enumerate(networks.available_networks): item = self.create_item(name, index) - yield item + + if item is not None: + yield item def allowed_directories_for_previews(self): - return [shared.cmd_opts.lora_dir] + return [shared.cmd_opts.lora_dir, os.path.join(paths.models_path, "LyCORIS")] def create_user_metadata_editor(self, ui, tabname): return LoraUserMetadataEditor(ui, tabname, self) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index eb8b1a67..39674666 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@
{background_image}
- {edit_button} {metadata_button} + {edit_button}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index e453094a..5582a6e5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -213,7 +213,7 @@ function popup(contents) { globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); + gradioApp().querySelector('.main').appendChild(globalPopup); } globalPopupInner.innerHTML = ''; diff --git a/modules/sd_models.py b/modules/sd_models.py index 729f03d7..4d9382dd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer state_dict = get_checkpoint_state_dict(checkpoint_info, timer) model.is_sdxl = hasattr(model, 'conditioner') + model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') + model.is_sd1 = not model.is_sdxl and not model.is_sd2 + if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6c73998f..49612298 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): page = next(iter([x for x in extra_pages if x.name == page]), None) try: - item = page.create_item(name) + item = page.create_item(name, enable_filter=False) + page.items[name] = item except Exception as e: errors.display(e, "creating item for extra network") item = page.items.get(name) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 01ff4e4b..63d4b503 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -42,6 +42,9 @@ class UserMetadataEditor: return user_metadata + def create_extra_default_items_in_left_column(self): + pass + def create_default_editor_elems(self): with gr.Row(): with gr.Column(scale=2): @@ -49,6 +52,8 @@ class UserMetadataEditor: self.edit_description = gr.Textbox(label="Description", lines=4) self.html_filedata = gr.HTML() + self.create_extra_default_items_in_left_column() + with gr.Column(scale=1, min_width=0): self.html_preview = gr.HTML() @@ -111,7 +116,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) diff --git a/style.css b/style.css index 8a66c3d2..e249cfd3 100644 --- a/style.css +++ b/style.css @@ -841,7 +841,7 @@ footer { .extra-network-cards .card .card-button { text-shadow: 2px 2px 3px black; - padding: 0.25em; + padding: 0.25em 0.1em; font-size: 200%; width: 1.5em; } @@ -957,6 +957,10 @@ div.block.gradio-box.edit-user-metadata { text-align: left; } +.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{ + padding: 0.3em 1em; +} + .edit-user-metadata .wrap.translucent{ background: var(--body-background-fill); } -- cgit v1.2.3