diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-13 12:07:37 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-13 12:07:37 +0000 |
commit | d8419762c1454ba51baa710d9ce8e762efc056ef (patch) | |
tree | fa02a021336d1e19f64118d1f0554ba89bc3ff39 /extensions-builtin/Lora/networks.py | |
parent | da80d649fd6a6083be02aca5695367bd25abf0d5 (diff) | |
download | stable-diffusion-webui-gfx803-d8419762c1454ba51baa710d9ce8e762efc056ef.tar.gz stable-diffusion-webui-gfx803-d8419762c1454ba51baa710d9ce8e762efc056ef.tar.bz2 stable-diffusion-webui-gfx803-d8419762c1454ba51baa710d9ce8e762efc056ef.zip |
Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r-- | extensions-builtin/Lora/networks.py | 61 |
1 files changed, 37 insertions, 24 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index ba621139..c252ed9e 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import logging
import os
import re
@@ -194,7 +195,7 @@ def load_network(name, network_on_disk): net.modules[key] = net_module
if keys_failed_to_match:
- print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+ logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net
@@ -207,7 +208,6 @@ def purge_networks_from_memory(): devices.torch_gc()
-
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if net is None:
failed_to_load_networks.append(name)
- print(f"Couldn't find network with name {name}")
+ logging.info(f"Couldn't find network with name {name}")
continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No loaded_networks.append(net)
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+ sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
@@ -314,17 +314,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
- with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ try:
+ with torch.no_grad():
+ updown, ex_bias = module.calc_updown(self.weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
- # inpainting model. zero pad updown to make channel[1] 4 to 9
- updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
- if ex_bias is not None and getattr(self, 'bias', None) is not None:
- self.bias += ex_bias
- continue
+ self.weight += updown
+ if ex_bias is not None and getattr(self, 'bias', None) is not None:
+ self.bias += ex_bias
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+ continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
@@ -332,21 +337,28 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
-
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += updown_out
- continue
+ try:
+ with torch.no_grad():
+ updown_q = module_q.calc_updown(self.in_proj_weight)
+ updown_k = module_k.calc_updown(self.in_proj_weight)
+ updown_v = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out = module_out.calc_updown(self.out_proj.weight)
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+ continue
if module is None:
continue
- print(f'failed to calculate network weights for layer {network_layer_name}')
+ logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names
@@ -519,6 +531,7 @@ def infotext_pasted(infotext, params): if added:
params["Prompt"] += "\n" + "".join(added)
+extra_network_lora = None
available_networks = {}
available_network_aliases = {}
|