From 524d532b387732d4d32f237e792c7f201a934400 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 1 Jan 2023 14:07:40 +0300
Subject: moved roll artist to built-in extensions

---
 .../roll-artist/scripts/roll-artist.py             | 50 ++++++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py
new file mode 100644
index 00000000..c3bc1fd0
--- /dev/null
+++ b/extensions-builtin/roll-artist/scripts/roll-artist.py
@@ -0,0 +1,50 @@
+import random
+
+from modules import script_callbacks, shared
+import gradio as gr
+
+art_symbol = '\U0001f3a8' # 🎨
+global_prompt = None
+related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
+
+
+def roll_artist(prompt):
+    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
+    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
+
+    return prompt + ", " + artist.name if prompt != '' else artist.name
+
+
+def add_roll_button(prompt):
+    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+    roll.click(
+        fn=roll_artist,
+        _js="update_txt2img_tokens",
+        inputs=[
+            prompt,
+        ],
+        outputs=[
+            prompt,
+        ]
+    )
+
+
+def after_component(component, **kwargs):
+    global global_prompt
+
+    elem_id = kwargs.get('elem_id', None)
+    if elem_id not in related_ids:
+        return
+
+    if elem_id == "txt2img_prompt":
+        global_prompt = component
+    elif elem_id == "txt2img_clear_prompt":
+        add_roll_button(global_prompt)
+    elif elem_id == "img2img_prompt":
+        global_prompt = component
+    elif elem_id == "img2img_clear_prompt":
+        add_roll_button(global_prompt)
+
+
+script_callbacks.on_after_component(after_component)
-- cgit v1.2.3

From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 18 Jan 2023 23:04:24 +0300
Subject: add option to show/hide warnings

removed hiding warnings from LDSR
fixed/reworked a few places that produced warnings
---
 extensions-builtin/LDSR/ldsr_model_arch.py | 3 ---
 1 file changed, 3 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 0ad49f4e..bc11cc6e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,7 +1,6 @@
 import os
 import gc
 import time
-import warnings
 
 import numpy as np
 import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
 from modules import shared, sd_hijack
 
-warnings.filterwarnings("ignore", category=UserWarning)
-
 cached_ldsr_model: torch.nn.Module = None
-- cgit v1.2.3
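The patch above drops LDSR's blanket UserWarning filter so that warning visibility can be governed by a user option instead. The idea in isolation looks roughly like the following sketch; the hide_warnings flag is a hypothetical stand-in, not the webui's actual option name:

    import warnings

    hide_warnings = True  # hypothetical stand-in for the user-facing option

    def run_with_warning_policy(fn):
        # Scope the filter to this call instead of installing it globally,
        # so the rest of the process still sees its warnings.
        with warnings.catch_warnings():
            if hide_warnings:
                warnings.simplefilter("ignore", category=UserWarning)
            return fn()

    def noisy():
        warnings.warn("deprecated behaviour", UserWarning)
        return 42

    print(run_with_warning_policy(noisy))  # prints 42; the warning appears only when hide_warnings is False
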
From c12d7ddd725c485682c1caa025627c9ee936d743 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 19 Jan 2023 15:58:32 +0300
Subject: add handling to some places in javascript that can potentially cause issues #6898

---
 .../javascript/prompt-bracket-checker.js           | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index eccfb0f9..251a1f57 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -93,10 +93,12 @@ function checkBrackets(evt) {
 }
 
 var shadowRootLoaded = setInterval(function() {
-    var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
-    if(shadowTextArea.length < 1) {
-        return false;
-    }
+    var sahdowRoot = document.querySelector('gradio-app').shadowRoot;
+    if(! sahdowRoot) return false;
+
+    var shadowTextArea = sahdowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+    if(shadowTextArea.length < 1) return false;
+
 
     clearInterval(shadowRootLoaded);
 
-- cgit v1.2.3
From 20a59ab3b171f398abd09087108c1ed087dbea9b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 20 Jan 2023 10:18:41 +0300
Subject: move token counter to the location of the prompt, add token counting for the negative prompt

---
 .../javascript/prompt-bracket-checker.js           | 45 +++++++++++-----------
 1 file changed, 23 insertions(+), 22 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index 251a1f57..4a85c8eb 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,16 +4,10 @@
 // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
 // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
 
-function checkBrackets(evt) {
-    textArea = evt.target;
-    tabName = evt.target.parentElement.parentElement.id.split("_")[0];
-    counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
-
-    promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
-
-    errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
-    errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
-    errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
+function checkBrackets(evt, textArea, counterElt) {
+    errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
+    errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
+    errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
 
     openBracketRegExp = /\(/g;
     closeBracketRegExp = /\)/g;
@@ -86,24 +80,31 @@ function checkBrackets(evt) {
     }
 
     if(counterElt.title != '') {
-        counterElt.style = 'color: #FF5555;';
+        counterElt.classList.add('error');
     } else {
-        counterElt.style = '';
+        counterElt.classList.remove('error');
     }
 }
 
-var shadowRootLoaded = setInterval(function() {
-    var sahdowRoot = document.querySelector('gradio-app').shadowRoot;
-    if(! sahdowRoot) return false;
+function setupBracketChecking(id_prompt, id_counter){
+    var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+    var counter = gradioApp().getElementById(id_counter)
+    textarea.addEventListener("input", function(evt){
+        checkBrackets(evt, textarea, counter)
+    });
+}
 
-    var shadowTextArea = sahdowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
-    if(shadowTextArea.length < 1) return false;
+var shadowRootLoaded = setInterval(function() {
+    var shadowRoot = document.querySelector('gradio-app').shadowRoot;
+    if(! shadowRoot) return false;
 
+    var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+    if(shadowTextArea.length < 1) return false;
 
-    clearInterval(shadowRootLoaded);
+    clearInterval(shadowRootLoaded);
 
-    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
+    setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
+    setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
+    setupBracketChecking('img2img_prompt', 'img2img_token_counter')
+    setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
 }, 1000);
-- cgit v1.2.3
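The counting logic that checkBrackets implements is worth seeing outside the browser: it compares occurrence counts of each bracket kind rather than checking nesting order. A rough Python equivalent of the same idea (illustrative only, not part of the webui):

    def bracket_errors(prompt):
        # Compare the number of opening and closing brackets of each kind,
        # mirroring the error strings used by the JS checker above.
        pairs = [("(", ")", "parentheses"), ("[", "]", "square brackets"), ("{", "}", "curly brackets")]
        errors = []
        for open_ch, close_ch, label in pairs:
            if prompt.count(open_ch) != prompt.count(close_ch):
                errors.append(f"{open_ch}...{close_ch} - Different number of opening and closing {label} detected.")
        return errors

    print(bracket_errors("a ((masterpiece, [detailed)"))  # flags both the ( ) and the [ ] mismatch
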
From 6d805b669e86233432f56ee1892d062103abe501 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 09:14:27 +0300
Subject: make CLIP interrogator download original text files if the directory does not exist

remove random artist built-in extension (to be re-added as a normal extension on demand)
remove artists.csv (but what does it mean????????????????????)
make interrogate buttons show Loading... when you click them
---
 .../roll-artist/scripts/roll-artist.py             | 50 ----------------------
 1 file changed, 50 deletions(-)
 delete mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py
deleted file mode 100644
index c3bc1fd0..00000000
--- a/extensions-builtin/roll-artist/scripts/roll-artist.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import random
-
-from modules import script_callbacks, shared
-import gradio as gr
-
-art_symbol = '\U0001f3a8' # 🎨
-global_prompt = None
-related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
-
-
-def roll_artist(prompt):
-    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
-    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
-    return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
-def add_roll_button(prompt):
-    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
-
-    roll.click(
-        fn=roll_artist,
-        _js="update_txt2img_tokens",
-        inputs=[
-            prompt,
-        ],
-        outputs=[
-            prompt,
-        ]
-    )
-
-
-def after_component(component, **kwargs):
-    global global_prompt
-
-    elem_id = kwargs.get('elem_id', None)
-    if elem_id not in related_ids:
-        return
-
-    if elem_id == "txt2img_prompt":
-        global_prompt = component
-    elif elem_id == "txt2img_clear_prompt":
-        add_roll_button(global_prompt)
-    elif elem_id == "img2img_prompt":
-        global_prompt = component
-    elif elem_id == "img2img_clear_prompt":
-        add_roll_button(global_prompt)
-
-
-script_callbacks.on_after_component(after_component)
-- cgit v1.2.3
From 855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 16:15:53 +0300
Subject: Lora support!

update readme to reflect some recent changes
---
 extensions-builtin/Lora/extra_networks_lora.py    |  20 +++
 extensions-builtin/Lora/lora.py                   | 198 ++++++++++++++++++++++
 extensions-builtin/Lora/scripts/lora_script.py    |  30 ++++
 extensions-builtin/Lora/ui_extra_networks_lora.py |  35 ++++
 4 files changed, 283 insertions(+)
 create mode 100644 extensions-builtin/Lora/extra_networks_lora.py
 create mode 100644 extensions-builtin/Lora/lora.py
 create mode 100644 extensions-builtin/Lora/scripts/lora_script.py
 create mode 100644 extensions-builtin/Lora/ui_extra_networks_lora.py

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
new file mode 100644
index 00000000..8f2e753e
--- /dev/null
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -0,0 +1,20 @@
+from modules import extra_networks
+import lora
+
+class ExtraNetworkLora(extra_networks.ExtraNetwork):
+    def __init__(self):
+        super().__init__('lora')
+
+    def activate(self, p, params_list):
+        names = []
+        multipliers = []
+        for params in params_list:
+            assert len(params.items) > 0
+
+            names.append(params.items[0])
+            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+        lora.load_loras(names, multipliers)
+
+    def deactivate(self, p):
+        pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
new file mode 100644
index 00000000..7a3ad9a9
--- /dev/null
+++ b/extensions-builtin/Lora/lora.py
@@ -0,0 +1,198 @@
+import glob
+import os
+import re
+import torch
+
+from modules import shared, devices, sd_models
+
+re_digits = re.compile(r"\d+")
+re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
+re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
+def convert_diffusers_name_to_compvis(key):
+    def match(match_list, regex):
+        r = re.match(regex, key)
+        if not r:
+            return False
+
+        match_list.clear()
+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+        return True
+
+    m = []
+
+    if match(m, re_unet_down_blocks):
+        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+
+    if match(m, re_unet_mid_blocks):
+        return f"diffusion_model_middle_block_1_{m[1]}"
+
+    if match(m, re_unet_up_blocks):
+        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+
+    if match(m, re_text_block):
+        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+    return key
+
+
+class LoraOnDisk:
+    def __init__(self, name, filename):
+        self.name = name
+        self.filename = filename
+
+
+class LoraModule:
+    def __init__(self, name):
+        self.name = name
+        self.multiplier = 1.0
+        self.modules = {}
+        self.mtime = None
+
+
+class LoraUpDownModule:
+    def __init__(self):
+        self.up = None
+        self.down = None
+
+
+def assign_lora_names_to_compvis_modules(sd_model):
+    lora_layer_mapping = {}
+
+    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        lora_name = name.replace(".", "_")
+        lora_layer_mapping[lora_name] = module
+        module.lora_layer_name = lora_name
+
+    for name, module in shared.sd_model.model.named_modules():
+        lora_name = name.replace(".", "_")
+        lora_layer_mapping[lora_name] = module
+        module.lora_layer_name = lora_name
+
+    sd_model.lora_layer_mapping = lora_layer_mapping
+
+
+def load_lora(name, filename):
+    lora = LoraModule(name)
+    lora.mtime = os.path.getmtime(filename)
+
+    sd = sd_models.read_state_dict(filename)
+
+    keys_failed_to_match = []
+
+    for key_diffusers, weight in sd.items():
+        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+        key, lora_key = fullkey.split(".", 1)
+
+        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        if sd_module is None:
+            keys_failed_to_match.append(key_diffusers)
+            continue
+
+        if type(sd_module) == torch.nn.Linear:
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.Conv2d:
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        else:
+            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
+        with torch.no_grad():
+            module.weight.copy_(weight)
+
+        module.to(device=devices.device, dtype=devices.dtype)
+
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "lora_up.weight":
+            lora_module.up = module
+        elif lora_key == "lora_down.weight":
+            lora_module.down = module
+        else:
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+
+    if len(keys_failed_to_match) > 0:
+        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+
+    return lora
+
+
+def load_loras(names, multipliers=None):
+    already_loaded = {}
+
+    for lora in loaded_loras:
+        if lora.name in names:
+            already_loaded[lora.name] = lora
+
+    loaded_loras.clear()
+
+    loras_on_disk = [available_loras.get(name, None) for name in names]
+    if any([x is None for x in loras_on_disk]):
+        list_available_loras()
+
+        loras_on_disk = [available_loras.get(name, None) for name in names]
+
+    for i, name in enumerate(names):
+        lora = already_loaded.get(name, None)
+
+        lora_on_disk = loras_on_disk[i]
+        if lora_on_disk is not None:
+            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
+                lora = load_lora(name, lora_on_disk.filename)
+
+        if lora is None:
+            print(f"Couldn't find Lora with name {name}")
+            continue
+
+        lora.multiplier = multipliers[i] if multipliers else 1.0
+        loaded_loras.append(lora)
+
+
+def lora_forward(module, input, res):
+    if len(loaded_loras) == 0:
+        return res
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is not None:
+            res = res + module.up(module.down(input)) * lora.multiplier
+
+    return res
+
+
+def lora_Linear_forward(self, input):
+    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+
+
+def lora_Conv2d_forward(self, input):
+    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+
+
+def list_available_loras():
+    available_loras.clear()
+
+    os.makedirs(lora_dir, exist_ok=True)
+
+    candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True)
+
+    for filename in sorted(candidates):
+        if os.path.isdir(filename):
+            continue
+
+        name = os.path.splitext(os.path.basename(filename))[0]
+
+        available_loras[name] = LoraOnDisk(name, filename)
+
+
+lora_dir = os.path.join(shared.models_path, "Lora")
+available_loras = {}
+loaded_loras = []
+
+list_available_loras()
+
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
new file mode 100644
index 00000000..60b9eb64
--- /dev/null
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -0,0 +1,30 @@
+import torch
+
+import lora
+import extra_networks_lora
+import ui_extra_networks_lora
+from modules import script_callbacks, ui_extra_networks, extra_networks
+
+
+def unload():
+    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+
+
+def before_ui():
+    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
+    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+
+
+if not hasattr(torch.nn, 'Linear_forward_before_lora'):
+    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+
+if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
+    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+
+torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+
+script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_script_unloaded(unload)
+script_callbacks.on_before_ui(before_ui)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
new file mode 100644
index 00000000..65397890
--- /dev/null
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -0,0 +1,35 @@
+import os
+import lora
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Lora')
+
+    def refresh(self):
+        lora.list_available_loras()
+
+    def list_items(self):
+        for name, lora_on_disk in lora.available_loras.items():
+            path, ext = os.path.splitext(lora_on_disk.filename)
+            previews = [path + ".png", path + ".preview.png"]
+
+            preview = None
+            for file in previews:
+                if os.path.isfile(file):
+                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    break
+
+            yield {
+                "name": name,
+                "filename": path,
+                "preview": preview,
+                "prompt": f"<lora:{name}:1.0>",
+                "local_preview": path + ".png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return [lora.lora_dir]
+
-- cgit v1.2.3
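At the heart of lora.py is lora_forward: each patched layer's output gets an additive low-rank correction, up(down(x)), scaled by a per-lora multiplier. The same arithmetic in isolation, on a plain Linear layer with illustrative shapes and names:

    import torch

    torch.manual_seed(0)

    base = torch.nn.Linear(16, 16, bias=False)  # stands in for a layer inside the SD model
    down = torch.nn.Linear(16, 4, bias=False)   # lora_down.weight, rank-4 bottleneck
    up = torch.nn.Linear(4, 16, bias=False)     # lora_up.weight
    multiplier = 0.8

    x = torch.randn(1, 16)
    res = base(x)                               # what Linear_forward_before_lora would return
    res = res + up(down(x)) * multiplier        # the same composition lora_forward applies

    print(res.shape)  # torch.Size([1, 16])
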
From a2749ec655af93d96ea7ebed85e8c1e7c5072b02 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 18:52:45 +0300
Subject: load Lora from .ckpt also

---
 extensions-builtin/Lora/lora.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 7a3ad9a9..6d860224 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -179,7 +179,10 @@ def list_available_loras():
 
     os.makedirs(lora_dir, exist_ok=True)
 
-    candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True)
+    candidates = \
+        glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
+        glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
+        glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
 
     for filename in sorted(candidates):
         if os.path.isdir(filename):
@@ -195,4 +198,3 @@ available_loras = {}
 loaded_loras = []
 
 list_available_loras()
-
-- cgit v1.2.3
From 500d9a32c7b1f877c8f44159a9a10c594b545a80 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 23:11:37 +0300
Subject: add --lora-dir commandline option

---
 extensions-builtin/Lora/lora.py                   | 9 ++++-----
 extensions-builtin/Lora/preload.py                | 6 ++++++
 extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +-
 3 files changed, 11 insertions(+), 6 deletions(-)
 create mode 100644 extensions-builtin/Lora/preload.py

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 6d860224..da1797dc 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -177,12 +177,12 @@ def lora_Conv2d_forward(self, input):
 def list_available_loras():
     available_loras.clear()
 
-    os.makedirs(lora_dir, exist_ok=True)
+    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
     candidates = \
-        glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
 
     for filename in sorted(candidates):
         if os.path.isdir(filename):
@@ -193,7 +193,6 @@ def list_available_loras():
         available_loras[name] = LoraOnDisk(name, filename)
 
 
-lora_dir = os.path.join(shared.models_path, "Lora")
 available_loras = {}
 loaded_loras = []
 
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
new file mode 100644
index 00000000..863dc5c0
--- /dev/null
+++ b/extensions-builtin/Lora/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 65397890..4406f8a0 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -31,5 +31,5 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
             }
 
     def allowed_directories_for_previews(self):
-        return [lora.lora_dir]
+        return [shared.cmd_opts.lora_dir]
 
-- cgit v1.2.3
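The preload.py file gives the extension a chance to register its flag before the main argument parser runs. Assuming, as a simplification, that the loader imports each extension's preload module and hands every one the same ArgumentParser, the flow looks roughly like this (models_path here is a stand-in for modules.paths.models_path):

    import argparse
    import os

    models_path = "models"  # stand-in for modules.paths.models_path

    def preload(parser):
        # Same shape as extensions-builtin/Lora/preload.py above.
        parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(models_path, 'Lora'))

    parser = argparse.ArgumentParser()
    preload(parser)
    cmd_opts = parser.parse_args(["--lora-dir", "/data/loras"])
    print(cmd_opts.lora_dir)  # /data/loras; falls back to models/Lora when the flag is omitted
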
From fe7a623e6b7e04bab2cfc96e8fd6cf49b48daee1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 22 Jan 2023 00:02:41 +0300
Subject: add a slider for default value of added extra networks

---
 extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 4406f8a0..54a80d36 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -1,3 +1,4 @@
+import json
 import os
 import lora
 
@@ -26,7 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "name": name,
                 "filename": path,
                 "preview": preview,
-                "prompt": f"<lora:{name}:1.0>",
+                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                 "local_preview": path + ".png",
             }
 
-- cgit v1.2.3
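The switch to json.dumps matters because the prompt field is spliced into generated HTML as a JavaScript expression rather than shown as text: json.dumps turns each Python string into a correctly quoted and escaped JS string literal, and the middle of the concatenation leaves a reference to the default-multiplier option to be evaluated in the browser. The escaping effect in isolation:

    import json

    name = 'my "quoted" lora'  # a deliberately awkward name to show why escaping is needed
    fragment = json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">")
    print(fragment)
    # "<lora:my \"quoted\" lora:" + opts.extra_networks_default_multiplier + ">"
    # i.e. a JS expression whose string pieces are safely quoted
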
From e407d1af897a7896d8c81e32dc86e7eb753ce207 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 18:12:51 +0300
Subject: add support for loras trained on kohya's scripts 0.4.0 (alphas)

---
 extensions-builtin/Lora/lora.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index da1797dc..220e64ff 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -92,6 +92,15 @@ def load_lora(name, filename):
             keys_failed_to_match.append(key_diffusers)
             continue
 
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "alpha":
+            lora_module.alpha = weight.item()
+            continue
+
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
@@ -104,17 +113,12 @@ def load_lora(name, filename):
 
         module.to(device=devices.device, dtype=devices.dtype)
 
-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
         if lora_key == "lora_up.weight":
             lora_module.up = module
         elif lora_key == "lora_down.weight":
            lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -161,7 +165,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier
+            res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1]
 
     return res
 
-- cgit v1.2.3

From c6f20f72629f3c417f10db2289d131441c6832f5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 18:52:55 +0300
Subject: make loras before 0.4.0 ALSO work

---
 extensions-builtin/Lora/lora.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 220e64ff..137e58f7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -57,6 +57,7 @@ class LoraUpDownModule:
     def __init__(self):
         self.up = None
         self.down = None
+        self.alpha = None
 
 
 def assign_lora_names_to_compvis_modules(sd_model):
@@ -165,7 +166,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1]
+            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
 
-- cgit v1.2.3
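Taken together, the two patches above define the scaling rule for a lora weight: the low-rank correction is multiplied by alpha divided by the rank (up.weight.shape[1], the bottleneck width) when the file stores an alpha, and by 1.0 for pre-0.4.0 files that store none. The rule in isolation, with illustrative numbers:

    def effective_scale(multiplier, alpha, rank):
        # alpha is None -> pre-0.4.0 kohya file, no rescaling (c6f20f72)
        # alpha is set  -> kohya 0.4.0+ convention, scale by alpha / rank (e407d1af)
        return multiplier * (alpha / rank if alpha else 1.0)

    print(effective_scale(1.0, None, 4))  # 1.0 (old file, correction used as-is)
    print(effective_scale(1.0, 4, 8))     # 0.5 (alpha=4 on a rank-8 lora halves the correction)
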
From f99352582084890b9167c1bf8699865bea0cef5f Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 21:50:59 -0500
Subject: Make SwinIR interruptible

---
 extensions-builtin/SwinIR/scripts/swinir_model.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..3479760a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
 
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
+            if state.interrupted:
+                break
+
             for w_idx in w_idx_list:
+                if state.interrupted:
+                    break
+
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                 out_patch = model(in_patch)
                 out_patch_mask = torch.ones_like(out_patch)
-- cgit v1.2.3

From 3c47b050367ee220dcfed7be7883878417735614 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 22:00:27 -0500
Subject: Also make SwinIR skippable

---
 extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 3479760a..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -145,11 +145,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
-            if state.interrupted:
+            if state.interrupted or state.skipped:
                 break
 
             for w_idx in w_idx_list:
-                if state.interrupted:
+                if state.interrupted or state.skipped:
                     break
 
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-- cgit v1.2.3
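Both SwinIR patches use the same cooperative-cancellation pattern: the tile loop polls shared flags at the top of each iteration and bails out early instead of finishing the whole grid. A stripped-down sketch; the State class here is a stand-in for the webui's shared state object, not its real definition:

    class State:
        interrupted = False
        skipped = False

    state = State()

    def upscale_tiles(h_idx_list, w_idx_list):
        done = []
        for h_idx in h_idx_list:
            if state.interrupted or state.skipped:
                break
            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break
                done.append((h_idx, w_idx))   # the real loop runs the model on one tile here
                if len(done) == 3:
                    state.interrupted = True  # simulate the user pressing Interrupt mid-grid
        return done

    print(upscale_tiles(range(2), range(4)))  # [(0, 0), (0, 1), (0, 2)]: stops after 3 of 8 tiles
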
From 1bfec873fa13d803f3d4ac2a12bf6983838233fe Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 11:29:46 +0300
Subject: add an experimental option to apply loras to outputs rather than inputs

---
 extensions-builtin/Lora/lora.py                | 5 ++++-
 extensions-builtin/Lora/scripts/lora_script.py | 7 ++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

(limited to 'extensions-builtin')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            else:
+                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
 
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 60b9eb64..544b228d 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -3,7 +3,7 @@ import torch
 import lora
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 
 def unload():
@@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
-- cgit v1.2.3
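The experimental option changes which tensor feeds the low-rank branch: up(down(res)) applies the correction to the layer's own output instead of its input, which is only well-defined when the two shapes match, hence the res.shape == input.shape guard. Both variants side by side, with illustrative shapes:

    import torch

    torch.manual_seed(0)
    down = torch.nn.Linear(16, 4, bias=False)
    up = torch.nn.Linear(4, 16, bias=False)
    multiplier = 1.0

    x = torch.randn(1, 16)     # the layer's input
    res = torch.randn(1, 16)   # the layer's output; same shape here, so both modes apply

    res_inputs = res + up(down(x)) * multiplier     # default: correction computed from the input
    res_outputs = res + up(down(res)) * multiplier  # experimental: correction computed from the output

    print(torch.allclose(res_inputs, res_outputs))  # False: the two modes generally differ
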