diff options
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/Lora/extra_networks_lora.py | 10 | ||||
-rw-r--r-- | extensions-builtin/Lora/network.py | 7 | ||||
-rw-r--r-- | extensions-builtin/Lora/network_norm.py | 28 | ||||
-rw-r--r-- | extensions-builtin/Lora/networks.py | 140 | ||||
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 22 | ||||
-rw-r--r-- | extensions-builtin/Lora/ui_extra_networks_lora.py | 3 | ||||
-rw-r--r-- | extensions-builtin/extra-options-section/scripts/extra_options_section.py | 12 |
7 files changed, 180 insertions, 42 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ba2945c6..005ff32c 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def __init__(self):
super().__init__('lora')
+ self.errors = {}
+ """mapping of network names to the number of errors the network had during operation"""
+
def activate(self, p, params_list):
additional = shared.opts.sd_lora
+ self.errors.clear()
+
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
- pass
+ if self.errors:
+ p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+
+ self.errors.clear()
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a18d69e..d8e8dfb7 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -133,7 +133,7 @@ class NetworkModule: return 1.0
- def finalize_updown(self, updown, orig_weight, output_shape):
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@ class NetworkModule: if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
- return updown * self.calc_scale() * self.multiplier()
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+
+ return updown * self.calc_scale() * self.multiplier(), ex_bias
def calc_updown(self, target):
raise NotImplementedError()
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 00000000..ce450158 --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7e3415ac..22fdff4a 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import logging
import os
import re
@@ -7,6 +8,7 @@ import network_hada import network_ia3
import network_lokr
import network_full
+import network_norm
import torch
from typing import Union
@@ -19,6 +21,7 @@ module_types = [ network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
+ network_norm.ModuleTypeNorm(),
]
@@ -31,6 +34,8 @@ suffix_conversion = { "resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
+ "norm1": "in_layers_0",
+ "norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
@@ -190,7 +195,7 @@ def load_network(name, network_on_disk): net.modules[key] = net_module
if keys_failed_to_match:
- print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+ logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net
@@ -203,7 +208,6 @@ def purge_networks_from_memory(): devices.torch_gc()
-
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -244,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if net is None:
failed_to_load_networks.append(name)
- print(f"Couldn't find network with name {name}")
+ logging.info(f"Couldn't find network with name {name}")
continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -253,25 +257,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No loaded_networks.append(net)
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+ sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
+ bias_backup = getattr(self, "network_bias_backup", None)
- if weights_backup is None:
+ if weights_backup is None and bias_backup is None:
return
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
+ if weights_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+ if bias_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
else:
- self.weight.copy_(weights_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
@@ -294,21 +311,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_weights_backup = weights_backup
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if bias_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
+ self.network_bias_backup = bias_backup
+
if current_names != wanted_names:
network_restore_weights_from_backup(self)
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
- with torch.no_grad():
- updown = module.calc_updown(self.weight)
-
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
- # inpainting model. zero pad updown to make channel[1] 4 to 9
- updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+ try:
+ with torch.no_grad():
+ updown, ex_bias = module.calc_updown(self.weight)
+
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+ self.weight += updown
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- self.weight += updown
- continue
+ continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
@@ -316,21 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
-
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += updown_out
- continue
+ try:
+ with torch.no_grad():
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
+
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+ continue
if module is None:
continue
- print(f'failed to calculate network weights for layer {network_layer_name}')
+ logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names
@@ -397,6 +446,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs): return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+def network_GroupNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
@@ -473,6 +552,7 @@ def infotext_pasted(infotext, params): if added:
params["Prompt"] += "\n" + "".join(added)
+extra_network_lora = None
available_networks = {}
available_network_aliases = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 6ab8b6e7..4c6e774a 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -23,9 +23,9 @@ def unload(): def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
- extra_network = extra_networks_lora.ExtraNetworkLora()
- extra_networks.register_extra_network(extra_network)
- extra_networks.register_extra_network_alias(extra_network, "lyco")
+ networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+ extra_networks.register_extra_network(networks.extra_network_lora)
+ extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
if not hasattr(torch.nn, 'Linear_forward_before_network'):
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'): if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+ torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+ torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+ torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+ torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 3629e5c0..55409a78 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): item = {
"name": name,
"filename": lora_on_disk.filename,
+ "shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename),
+ "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 588b64d2..983f87ff 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -22,22 +22,23 @@ class ExtraOptionsSection(scripts.Script): self.comps = []
self.setting_names = []
self.infotext_fields = []
+ extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
- row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
+ row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
for row in range(row_count):
with gr.Row():
for col in range(shared.opts.extra_options_cols):
index = row * shared.opts.extra_options_cols + col
- if index >= len(shared.opts.extra_options):
+ if index >= len(extra_options):
break
- setting_name = shared.opts.extra_options[index]
+ setting_name = extra_options[index]
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
@@ -64,7 +65,8 @@ class ExtraOptionsSection(scripts.Script): shared.options_templates.update(shared.options_section(('ui', "User interface"), {
- "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
+ "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+ "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
}))
|