From 855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 16:15:53 +0300
Subject: Lora support!

update readme to reflect some recent changes
---
 extensions-builtin/Lora/lora.py | 198 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 198 insertions(+)
 create mode 100644 extensions-builtin/Lora/lora.py

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
new file mode 100644
index 00000000..7a3ad9a9
--- /dev/null
+++ b/extensions-builtin/Lora/lora.py
@@ -0,0 +1,198 @@
+import glob
+import os
+import re
+import torch
+
+from modules import shared, devices, sd_models
+
+re_digits = re.compile(r"\d+")
+re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
+re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
+def convert_diffusers_name_to_compvis(key):
+    def match(match_list, regex):
+        r = re.match(regex, key)
+        if not r:
+            return False
+
+        match_list.clear()
+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+        return True
+
+    m = []
+
+    if match(m, re_unet_down_blocks):
+        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+
+    if match(m, re_unet_mid_blocks):
+        return f"diffusion_model_middle_block_1_{m[1]}"
+
+    if match(m, re_unet_up_blocks):
+        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+
+    if match(m, re_text_block):
+        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+    return key
+
+
+class LoraOnDisk:
+    def __init__(self, name, filename):
+        self.name = name
+        self.filename = filename
+
+
+class LoraModule:
+    def __init__(self, name):
+        self.name = name
+        self.multiplier = 1.0
+        self.modules = {}
+        self.mtime = None
+
+
+class LoraUpDownModule:
+    def __init__(self):
+        self.up = None
+        self.down = None
+
+
+def assign_lora_names_to_compvis_modules(sd_model):
+    lora_layer_mapping = {}
+
+    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        lora_name = name.replace(".", "_")
+        lora_layer_mapping[lora_name] = module
+        module.lora_layer_name = lora_name
+
+    for name, module in shared.sd_model.model.named_modules():
+        lora_name = name.replace(".", "_")
+        lora_layer_mapping[lora_name] = module
+        module.lora_layer_name = lora_name
+
+    sd_model.lora_layer_mapping = lora_layer_mapping
+
+
+def load_lora(name, filename):
+    lora = LoraModule(name)
+    lora.mtime = os.path.getmtime(filename)
+
+    sd = sd_models.read_state_dict(filename)
+
+    keys_failed_to_match = []
+
+    for key_diffusers, weight in sd.items():
+        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+        key, lora_key = fullkey.split(".", 1)
+
+        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        if sd_module is None:
+            keys_failed_to_match.append(key_diffusers)
+            continue
+
+        if type(sd_module) == torch.nn.Linear:
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.Conv2d:
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        else:
+            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
+        with torch.no_grad():
+            module.weight.copy_(weight)
+
+        module.to(device=devices.device, dtype=devices.dtype)
+
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "lora_up.weight":
+            lora_module.up = module
+        elif lora_key == "lora_down.weight":
+            lora_module.down = module
+        else:
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+
+    if len(keys_failed_to_match) > 0:
+        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+
+    return lora
+
+
+def load_loras(names, multipliers=None):
+    already_loaded = {}
+
+    for lora in loaded_loras:
+        if lora.name in names:
+            already_loaded[lora.name] = lora
+
+    loaded_loras.clear()
+
+    loras_on_disk = [available_loras.get(name, None) for name in names]
+    if any([x is None for x in loras_on_disk]):
+        list_available_loras()
+
+        loras_on_disk = [available_loras.get(name, None) for name in names]
+
+    for i, name in enumerate(names):
+        lora = already_loaded.get(name, None)
+
+        lora_on_disk = loras_on_disk[i]
+        if lora_on_disk is not None:
+            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
+                lora = load_lora(name, lora_on_disk.filename)
+
+        if lora is None:
+            print(f"Couldn't find Lora with name {name}")
+            continue
+
+        lora.multiplier = multipliers[i] if multipliers else 1.0
+        loaded_loras.append(lora)
+
+
+def lora_forward(module, input, res):
+    if len(loaded_loras) == 0:
+        return res
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is not None:
+            res = res + module.up(module.down(input)) * lora.multiplier
+
+    return res
+
+
+def lora_Linear_forward(self, input):
+    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+
+
+def lora_Conv2d_forward(self, input):
+    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+
+
+def list_available_loras():
+    available_loras.clear()
+
+    os.makedirs(lora_dir, exist_ok=True)
+
+    candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True)
+
+    for filename in sorted(candidates):
+        if os.path.isdir(filename):
+            continue
+
+        name = os.path.splitext(os.path.basename(filename))[0]
+
+        available_loras[name] = LoraOnDisk(name, filename)
+
+
+lora_dir = os.path.join(shared.models_path, "Lora")
+available_loras = {}
+loaded_loras = []
+
+list_available_loras()
+
--
cgit v1.2.3
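
Note: lora.py defines the replacement forwards (lora_Linear_forward, lora_Conv2d_forward) but never installs them; it expects the original methods to have been stashed on torch.nn as Linear_forward_before_lora and Conv2d_forward_before_lora by a companion script that falls outside this path-limited diff. The sketch below shows one plausible way that wiring could look; the file name and the use of script_callbacks are assumptions, not part of this patch:

    # hypothetical companion script (e.g. scripts/lora_script.py) -- a sketch, not part of this diff
    import torch

    import lora
    from modules import script_callbacks

    # stash the original forwards once, so re-running the script doesn't wrap them twice
    if not hasattr(torch.nn, 'Linear_forward_before_lora'):
        torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

    if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
        torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

    # route every Linear/Conv2d call through lora_forward
    torch.nn.Linear.forward = lora.lora_Linear_forward
    torch.nn.Conv2d.forward = lora.lora_Conv2d_forward

    # tag model layers with lora_layer_name each time a checkpoint is loaded
    script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)

Storing the originals on torch.nn itself is what lets lora.py call them back without importing the companion script.
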
From a2749ec655af93d96ea7ebed85e8c1e7c5072b02 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 18:52:45 +0300
Subject: load Lora from .ckpt also

---
 extensions-builtin/Lora/lora.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 7a3ad9a9..6d860224 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -179,7 +179,10 @@ def list_available_loras():
 
     os.makedirs(lora_dir, exist_ok=True)
 
-    candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True)
+    candidates = \
+        glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
+        glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
+        glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
 
     for filename in sorted(candidates):
         if os.path.isdir(filename):
@@ -195,4 +198,3 @@ available_loras = {}
 loaded_loras = []
 
 list_available_loras()
-
--
cgit v1.2.3

From 500d9a32c7b1f877c8f44159a9a10c594b545a80 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 23:11:37 +0300
Subject: add --lora-dir commandline option

---
 extensions-builtin/Lora/lora.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 6d860224..da1797dc 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -177,12 +177,12 @@ def lora_Conv2d_forward(self, input):
 def list_available_loras():
     available_loras.clear()
 
-    os.makedirs(lora_dir, exist_ok=True)
+    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
     candidates = \
-        glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
 
     for filename in sorted(candidates):
         if os.path.isdir(filename):
@@ -193,7 +193,6 @@ def list_available_loras():
         available_loras[name] = LoraOnDisk(name, filename)
 
 
-lora_dir = os.path.join(shared.models_path, "Lora")
 available_loras = {}
 loaded_loras = []
 
--
cgit v1.2.3
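
Note: this commit makes lora.py read shared.cmd_opts.lora_dir, but the diff is limited to lora.py, so the registration of the new --lora-dir flag itself is not shown. A minimal sketch of how it might be declared, assuming the webui's preload.py convention for extension command-line options (the file name and help text are assumptions); the default mirrors the previously hard-coded Lora folder under the models directory:

    # hypothetical extensions-builtin/Lora/preload.py -- a sketch, not shown in this diff
    import os

    from modules import paths


    def preload(parser):
        parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.",
                            default=os.path.join(paths.models_path, "Lora"))
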
From e407d1af897a7896d8c81e32dc86e7eb753ce207 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 18:12:51 +0300
Subject: add support for loras trained on kohya's scripts 0.4.0 (alphas)

---
 extensions-builtin/Lora/lora.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index da1797dc..220e64ff 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -92,6 +92,15 @@ def load_lora(name, filename):
             keys_failed_to_match.append(key_diffusers)
             continue
 
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "alpha":
+            lora_module.alpha = weight.item()
+            continue
+
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
@@ -104,17 +113,12 @@ def load_lora(name, filename):
 
         module.to(device=devices.device, dtype=devices.dtype)
 
-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
         if lora_key == "lora_up.weight":
             lora_module.up = module
         elif lora_key == "lora_down.weight":
            lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -161,7 +165,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier
+            res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1]
 
     return res
 
--
cgit v1.2.3
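
Note: kohya's 0.4.0 training scripts store a per-module scalar "alpha" next to lora_up.weight and lora_down.weight, and the intended contribution is scaled by alpha divided by the network rank; in the forward above, module.up.weight.shape[1] is that rank (the inner dimension shared by the up and down matrices). A small self-contained sketch of the same arithmetic, with made-up shapes purely for illustration:

    # standalone sketch of the scaling used in lora_forward; shapes and values are illustrative only
    import torch

    dim, rank = 320, 4                               # layer width and LoRA rank
    down = torch.nn.Linear(dim, rank, bias=False)    # lora_down.weight has shape (rank, dim)
    up = torch.nn.Linear(rank, dim, bias=False)      # lora_up.weight has shape (dim, rank)
    alpha = 4.0                                      # stored as a 0-dim tensor in the file, read via .item()
    multiplier = 0.8                                 # user-chosen LoRA strength

    x = torch.randn(1, dim)                          # input seen by the wrapped Linear layer
    base = torch.nn.Linear(dim, dim, bias=False)(x)  # stands in for the original forward's output

    scale = multiplier * alpha / up.weight.shape[1]  # up.weight.shape[1] == rank, so this is alpha / rank
    res = base + up(down(x)) * scale                 # mirrors: res + module.up(module.down(input)) * ...

Dividing by the rank keeps the effective strength comparable across LoRAs trained at different ranks; files with no alpha key fall back to a factor of 1.0 in the next commit.
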
From c6f20f72629f3c417f10db2289d131441c6832f5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 18:52:55 +0300
Subject: make loras before 0.4.0 ALSO work

---
 extensions-builtin/Lora/lora.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 220e64ff..137e58f7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -57,6 +57,7 @@ class LoraUpDownModule:
     def __init__(self):
         self.up = None
         self.down = None
+        self.alpha = None
 
 
 def assign_lora_names_to_compvis_modules(sd_model):
@@ -165,7 +166,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1]
+            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
 
--
cgit v1.2.3

From 1bfec873fa13d803f3d4ac2a12bf6983838233fe Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 11:29:46 +0300
Subject: add an experimental option to apply loras to outputs rather than inputs

---
 extensions-builtin/Lora/lora.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'extensions-builtin/Lora/lora.py')

diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            else:
+                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
 
--
cgit v1.2.3
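
Note: the new branch reads shared.opts.lora_apply_to_outputs, an option that has to be registered outside lora.py (this view is limited to lora.py), and the res.shape == input.shape guard means the output-side path can only trigger on layers whose input and output widths match; otherwise module.down would receive a tensor of the wrong size. A sketch of how the setting might be registered, assuming the webui's on_ui_settings callback and OptionInfo API (the exact label text is an assumption):

    # hypothetical settings registration, e.g. in the companion extension script -- a sketch
    from modules import script_callbacks, shared


    def on_ui_settings():
        # False by default, so the existing behaviour (apply to inputs) is unchanged
        shared.opts.add_option("lora_apply_to_outputs",
                               shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"))


    script_callbacks.on_ui_settings(on_ui_settings)
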