From c3eced22fc7b9da4fbb2f55f2d53a7e5e511cfbd Mon Sep 17 00:00:00 2001 From: Leo Mozoloa Date: Thu, 4 May 2023 16:14:33 +0200 Subject: Fix some Lora's not working --- extensions-builtin/Lora/lora.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 6f246921..bcf36d77 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -165,8 +165,10 @@ def load_lora(name, filename): module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.MultiheadAttention: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.Conv2d: + elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') continue @@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target): if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) else: updown = up @ down -- cgit v1.2.3 From 2cb3b0be1def43e0d225b45a640592a7999a0d69 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 7 May 2023 08:25:34 +0300 Subject: if present, use Lora's "ss_output_name" field to refer to it in prompt --- extensions-builtin/Lora/extra_networks_lora.py | 1 + extensions-builtin/Lora/lora.py | 13 ++++++++++--- extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 45f899fc..ccb249ac 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,6 +1,7 @@ from modules import extra_networks, shared import lora + class ExtraNetworkLora(extra_networks.ExtraNetwork): def __init__(self): super().__init__('lora') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 6f246921..e3ca7fa2 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -93,6 +93,7 @@ class LoraOnDisk: self.metadata = m self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text + self.alias = self.metadata.get('ss_output_name', self.name) class LoraModule: @@ -199,11 +200,11 @@ def load_loras(names, multipliers=None): loaded_loras.clear() - loras_on_disk = [available_loras.get(name, None) for name in names] + loras_on_disk = [available_lora_aliases.get(name, None) for name in names] if any([x is None for x in loras_on_disk]): list_available_loras() - loras_on_disk = [available_loras.get(name, None) for name in names] + loras_on_disk = [available_lora_aliases.get(name, None) for name in names] for i, name in enumerate(names): lora = already_loaded.get(name, None) @@ -343,6 +344,7 @@ def 
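The `lora_calc_updown` change in the first patch above fuses a 3x3 `down` convolution and a 1x1 `up` convolution into a single conv weight delta with `torch.nn.functional.conv2d`. A minimal shape sketch (hypothetical sizes, not from the patch) showing why the permute/conv2d trick is equivalent to running the two convolutions in sequence:

    # Shape sketch for the fused 3x3 updown computed by lora_calc_updown.
    import torch
    import torch.nn.functional as F

    rank, c_in, c_out = 4, 8, 16
    down = torch.randn(rank, c_in, 3, 3)   # LoRA "down" conv: c_in -> rank, 3x3 kernel
    up = torch.randn(c_out, rank, 1, 1)    # LoRA "up" conv: rank -> c_out, 1x1 kernel

    # The permute treats each input channel of `down` as a batch item, so the
    # 1x1 `up` conv contracts away the rank dimension; permuting back yields a
    # standard (c_out, c_in, 3, 3) conv weight delta.
    updown = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
    assert updown.shape == (c_out, c_in, 3, 3)

    # Cross-check: down-then-up equals one convolution with the fused delta.
    x = torch.randn(1, c_in, 16, 16)
    two_step = F.conv2d(F.conv2d(x, down, padding=1), up)
    one_step = F.conv2d(x, updown, padding=1)
    assert torch.allclose(two_step, one_step, atol=1e-4)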
lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): def list_available_loras(): available_loras.clear() + available_lora_aliases.clear() os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) @@ -356,11 +358,16 @@ def list_available_loras(): continue name = os.path.splitext(os.path.basename(filename))[0] + entry = LoraOnDisk(name, filename) - available_loras[name] = LoraOnDisk(name, filename) + available_loras[name] = entry + + available_lora_aliases[name] = entry + available_lora_aliases[entry.alias] = entry available_loras = {} +available_lora_aliases = {} loaded_loras = [] list_available_loras() diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 68b11332..a0edbc1e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), - "prompt": json.dumps(f""), + "prompt": json.dumps(f""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, } -- cgit v1.2.3 From 2473bafa67b2dd0077f752bf23e4bf8f89990a8c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 8 May 2023 07:28:30 +0300 Subject: read infotext params from the other extension for Lora if it's not active --- extensions-builtin/Lora/lora.py | 36 +++++++++++++++++++++++++- extensions-builtin/Lora/scripts/lora_script.py | 1 + 2 files changed, 36 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index e3ca7fa2..94ec021b 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -4,7 +4,7 @@ import re import torch from typing import Union -from modules import shared, devices, sd_models, errors +from modules import shared, devices, sd_models, errors, scripts metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} @@ -366,6 +366,40 @@ def list_available_loras(): available_lora_aliases[entry.alias] = entry +re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k, v in params.items(): + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_lora_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + available_loras = {} available_lora_aliases = {} loaded_loras = [] diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 3fc38ab9..2f2267a2 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -49,6 +49,7 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention 
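Note that the f-string bodies in this patch series (for example `added.append(f"")` above) appear to have lost their angle-bracketed contents during extraction; judging from the surrounding logic, `infotext_pasted` emits webui's `<lora:name:multiplier>` prompt tokens, one per AddNet entry. A standalone sketch under that assumption:

    import re

    re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

    def addnet_params_to_lora_tokens(params):
        tokens = []
        for k, name in params.items():
            if not k.startswith("AddNet Model "):
                continue
            num = k[len("AddNet Model "):]
            if params.get("AddNet Module " + num) != "LoRA":
                continue
            m = re_lora_name.match(name)
            if m:  # strip a trailing "(hash)" suffix from the model name
                name = m.group(1)
            multiplier = params.get("AddNet Weight A " + num, "1.0")
            tokens.append(f"<lora:{name}:{multiplier}>")
        return tokens

    # Prints ['<lora:myLora:0.8>'] -- myLora and 0.8 are made-up example values.
    print(addnet_params_to_lora_tokens({
        "AddNet Module 1": "LoRA",
        "AddNet Model 1": "myLora(a1b2c3d4)",
        "AddNet Weight A 1": "0.8",
    }))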
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) script_callbacks.on_before_ui(before_ui) +script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { -- cgit v1.2.3 From 083dc3c76ab7dbc7b2b04f3396d4f5280b002906 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 8 May 2023 11:33:45 +0300 Subject: directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for create HTML for extra network pages only on demand allow directories starting with . to still list their models for lora, checkpoints, etc keep "search" filter for extra networks when user refreshes the page --- extensions-builtin/Lora/lora.py | 6 +--- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 25 ++++++++++++--- modules/modelloader.py | 27 +++++----------- modules/shared.py | 17 ++++++++++ modules/ui_extra_networks.py | 69 +++++++++++++++++++++++++++++------------ 6 files changed, 97 insertions(+), 49 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 83c1c6fd..83933639 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -352,11 +352,7 @@ def list_available_loras(): os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - candidates = \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) - + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) for filename in sorted(candidates, key=str.lower): if os.path.isdir(filename): continue diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index ef4b613a..1d546217 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -6,7 +6,7 @@ - + {name} {description} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c8f6b386..c85bc79a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,4 +1,3 @@ - function setupExtraNetworksForTab(tabname){ gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') @@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){ tabs.appendChild(search) tabs.appendChild(refresh) - search.addEventListener("input", function(){ + var applyFilter = function(){ var searchTerm = search.value.toLowerCase() gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ + var searchOnly = elem.querySelector('.search_only') var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() - elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" + + var visible = text.indexOf(searchTerm) != -1 + + if(searchOnly && searchTerm.length < 4){ + visible = false + } + + elem.style.display = visible ? 
"" : "none" }) - }); + } + + search.addEventListener("input", applyFilter); + applyFilter(); + + extraNetworksApplyFilter[tabname] = applyFilter; +} + +function applyExtraNetworkFilter(tabname){ + setTimeout(extraNetworksApplyFilter[tabname], 1); } +var extraNetworksApplyFilter = {} var activePromptTextarea = {}; function setupExtraNetworks(){ diff --git a/modules/modelloader.py b/modules/modelloader.py index 522affc6..f2274488 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -22,9 +22,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None """ output = [] - if ext_filter is None: - ext_filter = [] - try: places = [] @@ -39,22 +36,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None places.append(model_path) for place in places: - if os.path.exists(place): - for file in glob.iglob(place + '**/**', recursive=True): - full_path = file - if os.path.isdir(full_path): - continue - if os.path.islink(full_path) and not os.path.exists(full_path): - print(f"Skipping broken symlink: {full_path}") - continue - if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]): - continue - if len(ext_filter) != 0: - model_name, extension = os.path.splitext(file) - if extension not in ext_filter: - continue - if file not in output: - output.append(full_path) + for full_path in shared.walk_files(place, allowed_extensions=ext_filter): + if os.path.islink(full_path) and not os.path.exists(full_path): + print(f"Skipping broken symlink: {full_path}") + continue + if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]): + continue + if full_path not in output: + output.append(full_path) if model_url is not None and len(output) == 0: if download_name is not None: diff --git a/modules/shared.py b/modules/shared.py index 91aac1a3..dd374713 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -726,3 +726,20 @@ def html(filename): return file.read() return "" + + +def walk_files(path, allowed_extensions=None): + if not os.path.exists(path): + return + + if allowed_extensions is not None: + allowed_extensions = set(allowed_extensions) + + for root, dirs, files in os.walk(path): + for filename in files: + if allowed_extensions is not None: + _, ext = os.path.splitext(filename) + if ext not in allowed_extensions: + continue + + yield os.path.join(root, filename) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index aa2f5d1b..86c05a55 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -89,19 +89,22 @@ class ExtraNetworksPage: subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): - if not os.path.isdir(x): - continue + for root, dirs, files in os.walk(parentdir): + for dirname in dirs: + x = os.path.join(root, dirname) - subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - while subdir.startswith("/"): - subdir = subdir[1:] + if not os.path.isdir(x): + continue - is_empty = len(os.listdir(x)) == 0 - if not is_empty and not subdir.endswith("/"): - subdir = subdir + "/" + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] - subdirs[subdir] = 1 + is_empty = len(os.listdir(x)) == 0 + if not is_empty and not subdir.endswith("/"): + subdir = subdir + "/" + + subdirs[subdir] = 1 if subdirs: subdirs = {"": 1, **subdirs} @@ -157,8 +160,20 @@ 
class ExtraNetworksPage: if metadata: metadata_button = f"" + local_path = "" + filename = item.get("filename", "") + for reldir in self.allowed_directories_for_previews(): + absdir = os.path.abspath(reldir) + + if filename.startswith(absdir): + local_path = filename[len(absdir):] + + # if this is true, the item must not be show in the default view, and must instead only be + # shown when searching for it + serach_only = "/." in local_path or "\\." in local_path + args = { - "style": f"'{height}{width}{background_image}'", + "style": f"'display: none; {height}{width}{background_image}'", "prompt": item.get("prompt", None), "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), @@ -168,6 +183,7 @@ class ExtraNetworksPage: "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, + "serach_only": " search_only" if serach_only else "", } return self.card_page.format(**args) @@ -209,6 +225,11 @@ def intialize(): class ExtraNetworksUi: def __init__(self): self.pages = None + """gradio HTML components related to extra networks' pages""" + + self.page_contents = None + """HTML content of the above; empty initially, filled when extra pages have to be shown""" + self.stored_extra_pages = None self.button_save_preview = None @@ -236,17 +257,22 @@ def pages_in_preferred_order(pages): def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] + ui.pages_contents = [] ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")): + page_id = page.title.lower().replace(" ", "_") - page_elem = gr.HTML(page.create_html(ui.tabname)) + with gr.Tab(page.title, id=page_id): + elem_id = f"{tabname}_{page_id}_cards_html" + page_elem = gr.HTML('', elem_id=elem_id) ui.pages.append(page_elem) - filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + + gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) @@ -254,19 +280,22 @@ def create_ui(container, button, tabname): def toggle_visibility(is_visible): is_visible = not is_visible - return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) + + if is_visible and not ui.pages_contents: + refresh() + + return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents state_visible = gr.State(value=False) - button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button]) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages]) def refresh(): - res = [] - for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname)) - return res + ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] 
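The `toggle_visibility`/`refresh` wiring in this hunk defers building card HTML until the extra-networks panel is first opened, then caches it in `ui.pages_contents`. A gradio-free sketch of the same build-on-first-open pattern (class and method names are illustrative):

    class LazyPages:
        def __init__(self, pages):
            self.pages = pages      # objects exposing refresh() and create_html()
            self.contents = []      # cached HTML; stays empty until first shown

        def refresh(self):
            for page in self.pages:
                page.refresh()
            self.contents = [page.create_html() for page in self.pages]
            return self.contents

        def show(self):
            if not self.contents:   # build the expensive HTML on first open only
                self.refresh()
            return self.contents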
+ + return ui.pages_contents button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) -- cgit v1.2.3 From ec0da07236d286f37c86f9cd92642e24381dd6a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 8 May 2023 12:07:43 +0300 Subject: Lora: add an option to use old method of applying loras --- extensions-builtin/Lora/lora.py | 56 +++++++++++++++++++++++--- extensions-builtin/Lora/scripts/lora_script.py | 5 +++ 2 files changed, 55 insertions(+), 6 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 83933639..d488b5ae 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target): return updown +def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "lora_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): """ Applies the currently selected set of Loras to the weights of torch layer self. @@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu self.lora_weights_backup = weights_backup if current_names != wanted_names: - if weights_backup is not None: - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) + lora_restore_weights_from_backup(self) for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) @@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu setattr(self, "lora_current_names", wanted_names) +def lora_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. 
+ """ + + if len(loaded_loras) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + lora_restore_weights_from_backup(module) + lora_reset_cached_weight(module) + + res = original_forward(module, input) + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue + + module.up.to(device=devices.device) + module.down.to(device=devices.device) + + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return res + + def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): setattr(self, "lora_current_names", ()) setattr(self, "lora_weights_backup", None) def lora_Linear_forward(self, input): + if shared.opts.lora_functional: + return lora_forward(self, input, torch.nn.Linear_forward_before_lora) + lora_apply_weights(self) return torch.nn.Linear_forward_before_lora(self, input) @@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs): def lora_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora) + lora_apply_weights(self) return torch.nn.Conv2d_forward_before_lora(self, input) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2f2267a2..a67b8a69 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -55,3 +55,8 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), })) + + +shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { + "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), +})) -- cgit v1.2.3 From 34a82a345abe89faafbd43fa34f40dd110559071 Mon Sep 17 00:00:00 2001 From: Sayo Date: Mon, 8 May 2023 19:55:05 +0800 Subject: Add api method to get LoRA models --- extensions-builtin/Lora/lora.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index d488b5ae..8fc1ddca 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -2,7 +2,9 @@ import glob import os import re import torch -from typing import Union +from typing import Union, List, Optional +from fastapi import FastAPI +import gradio as gr from modules import shared, devices, sd_models, errors, scripts @@ -443,9 +445,19 @@ def infotext_pasted(infotext, params): if added: params["Prompt"] += "\n" + "".join(added) +def api(_: gr.Blocks, app: FastAPI): + @app.get("/sdapi/v1/loras") + async def getloras(): + return [{"name": name, "path": available_loras[name].filename, "prompt": ""} for name in available_loras] + available_loras = {} available_lora_aliases = {} loaded_loras = [] list_available_loras() +try: + import modules.script_callbacks as script_callbacks + script_callbacks.on_app_started(api) +except: + pass \ No newline at end of file -- cgit 
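The `lora_functional` option above switches between adding `up(down(x))` on every forward (the old kohya-ss-style method) and folding `up @ down` into the layer weight once. A minimal sketch in plain PyTorch (hypothetical sizes; the patch's alpha scaling is folded into the single multiplier here) showing the two paths agree:

    import torch

    d_in, d_out, rank, mult = 8, 8, 4, 0.7
    layer = torch.nn.Linear(d_in, d_out, bias=False)
    down = torch.nn.Linear(d_in, rank, bias=False)
    up = torch.nn.Linear(rank, d_out, bias=False)
    x = torch.randn(2, d_in)

    # Functional path: extra matmuls on every forward, per active Lora.
    res_functional = layer(x) + up(down(x)) * mult

    # Merged path: fold up @ down into the weight once, then forward as usual.
    with torch.no_grad():
        merged = torch.nn.Linear(d_in, d_out, bias=False)
        merged.weight.copy_(layer.weight + (up.weight @ down.weight) * mult)
    res_merged = merged(x)

    assert torch.allclose(res_functional, res_merged, atol=1e-5)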
v1.2.3 From f9abe4cddcdc6704be02633d9d5ed9640d6b9008 Mon Sep 17 00:00:00 2001 From: Sayo Date: Mon, 8 May 2023 20:38:10 +0800 Subject: Add api method to get LoRA models with prompt --- extensions-builtin/Lora/lora.py | 13 +++---------- extensions-builtin/Lora/scripts/api.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) create mode 100644 extensions-builtin/Lora/scripts/api.py (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 8fc1ddca..05162e41 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -2,9 +2,8 @@ import glob import os import re import torch -from typing import Union, List, Optional -from fastapi import FastAPI -import gradio as gr +from typing import Union +import scripts.api as api from modules import shared, devices, sd_models, errors, scripts @@ -445,12 +444,6 @@ def infotext_pasted(infotext, params): if added: params["Prompt"] += "\n" + "".join(added) -def api(_: gr.Blocks, app: FastAPI): - @app.get("/sdapi/v1/loras") - async def getloras(): - return [{"name": name, "path": available_loras[name].filename, "prompt": ""} for name in available_loras] - - available_loras = {} available_lora_aliases = {} loaded_loras = [] @@ -458,6 +451,6 @@ loaded_loras = [] list_available_loras() try: import modules.script_callbacks as script_callbacks - script_callbacks.on_app_started(api) + script_callbacks.on_app_started(api.api) except: pass \ No newline at end of file diff --git a/extensions-builtin/Lora/scripts/api.py b/extensions-builtin/Lora/scripts/api.py new file mode 100644 index 00000000..f1f2e2fc --- /dev/null +++ b/extensions-builtin/Lora/scripts/api.py @@ -0,0 +1,31 @@ +from fastapi import FastAPI +import gradio as gr +import json +import os +import lora + +def get_lora_prompts(path): + directory, filename = os.path.split(path) + name_without_ext = os.path.splitext(filename)[0] + new_filename = name_without_ext + '.civitai.info' + try: + new_path = os.path.join(directory, new_filename) + if os.path.exists(new_path): + with open(new_path, 'r') as f: + data = json.load(f) + trained_words = data.get('trainedWords', []) + if len(trained_words) > 0: + result = ','.join(trained_words) + return result + else: + return '' + else: + return '' + except Exception as e: + return '' + +def api(_: gr.Blocks, app: FastAPI): + @app.get("/sdapi/v1/loras") + async def get_loras(): + return [{"name": name, "path": lora.available_loras[name].filename, "prompt": get_lora_prompts(lora.available_loras[name].filename)} for name in lora.available_loras] + -- cgit v1.2.3 From eb95809501068a38f2b6bdb01b6ae5b86ff7ae87 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 9 May 2023 11:25:46 +0300 Subject: rework loras api --- extensions-builtin/Lora/lora.py | 6 ----- extensions-builtin/Lora/scripts/api.py | 31 -------------------------- extensions-builtin/Lora/scripts/lora_script.py | 21 ++++++++++++++++- 3 files changed, 20 insertions(+), 38 deletions(-) delete mode 100644 extensions-builtin/Lora/scripts/api.py (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 05162e41..ba1293df 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,6 @@ import os import re import torch from typing import Union -import scripts.api as api from modules import shared, devices, sd_models, errors, scripts @@ -449,8 +448,3 @@ 
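`get_lora_prompts` in the new `scripts/api.py` above looks for a `<model>.civitai.info` JSON sidecar next to each Lora file and joins its `trainedWords` list into a prompt string. A compact sketch of the same lookup with narrower error handling (it assumes, as the patch does, that `trainedWords` holds strings):

    import json
    import os

    def get_lora_prompt(path):
        # "foo.safetensors" -> "foo.civitai.info" next to it
        info_path = os.path.splitext(path)[0] + '.civitai.info'
        try:
            with open(info_path, 'r') as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError):
            return ''
        return ','.join(data.get('trainedWords', []))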
available_lora_aliases = {} loaded_loras = [] list_available_loras() -try: - import modules.script_callbacks as script_callbacks - script_callbacks.on_app_started(api.api) -except: - pass \ No newline at end of file diff --git a/extensions-builtin/Lora/scripts/api.py b/extensions-builtin/Lora/scripts/api.py deleted file mode 100644 index f1f2e2fc..00000000 --- a/extensions-builtin/Lora/scripts/api.py +++ /dev/null @@ -1,31 +0,0 @@ -from fastapi import FastAPI -import gradio as gr -import json -import os -import lora - -def get_lora_prompts(path): - directory, filename = os.path.split(path) - name_without_ext = os.path.splitext(filename)[0] - new_filename = name_without_ext + '.civitai.info' - try: - new_path = os.path.join(directory, new_filename) - if os.path.exists(new_path): - with open(new_path, 'r') as f: - data = json.load(f) - trained_words = data.get('trainedWords', []) - if len(trained_words) > 0: - result = ','.join(trained_words) - return result - else: - return '' - else: - return '' - except Exception as e: - return '' - -def api(_: gr.Blocks, app: FastAPI): - @app.get("/sdapi/v1/loras") - async def get_loras(): - return [{"name": name, "path": lora.available_loras[name].filename, "prompt": get_lora_prompts(lora.available_loras[name].filename)} for name in lora.available_loras] - diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index a67b8a69..7db971fd 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,12 +1,12 @@ import torch import gradio as gr +from fastapi import FastAPI import lora import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared - def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora @@ -60,3 +60,22 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), })) + + +def create_lora_json(obj: lora.LoraOnDisk): + return { + "name": obj.name, + "alias": obj.alias, + "path": obj.filename, + "metadata": obj.metadata, + } + + +def api_loras(_: gr.Blocks, app: FastAPI): + @app.get("/sdapi/v1/loras") + async def get_loras(): + return [create_lora_json(obj) for obj in lora.available_loras.values()] + + +script_callbacks.on_app_started(api_loras) + -- cgit v1.2.3
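After the rework, the extension registers `GET /sdapi/v1/loras`, returning one object per Lora with the fields built by `create_lora_json`. A usage sketch against a running webui instance (the default address `http://127.0.0.1:7860` is an assumption; adjust for your setup):

    import requests

    resp = requests.get("http://127.0.0.1:7860/sdapi/v1/loras", timeout=10)
    resp.raise_for_status()
    for entry in resp.json():
        # keys mirror create_lora_json: name, alias, path, metadata
        print(entry["name"], entry["alias"], entry["path"])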