From b7f0e815624dab182aff406c8f227b39ec17452f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 27 Aug 2023 08:41:26 +0300 Subject: fix error that causes some extra networks to be disabled if both and are present in the prompt --- modules/extra_networks.py | 58 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 16 deletions(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index fa28ac75..b9533677 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,6 +1,7 @@ import json import os import re +import logging from collections import defaultdict from modules import errors @@ -86,27 +87,55 @@ class ExtraNetwork: raise NotImplementedError -def activate(p, extra_network_data): - """call activate for extra networks in extra_network_data in specified order, then call - activate for all remaining registered networks with an empty argument list""" +def lookup_extra_networks(extra_network_data): + """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks. - activated = [] + Example input: + { + 'lora': [], + 'lyco': [], + 'hypernet': [] + } + + Example output: + + { + : [, ], + : [] + } + """ - for extra_network_name, extra_network_args in extra_network_data.items(): + res = {} + + for extra_network_name, extra_network_args in list(extra_network_data.items()): extra_network = extra_network_registry.get(extra_network_name, None) + alias = extra_network_aliases.get(extra_network_name, None) - if extra_network is None: - extra_network = extra_network_aliases.get(extra_network_name, None) + if alias is not None and extra_network is None: + extra_network = alias if extra_network is None: - print(f"Skipping unknown extra network: {extra_network_name}") + logging.info(f"Skipping unknown extra network: {extra_network_name}") continue + res.setdefault(extra_network, []).extend(extra_network_args) + + return res + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + activated = [] + + for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items(): + try: extra_network.activate(p, extra_network_args) activated.append(extra_network) except Exception as e: - errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): if extra_network in activated: @@ -125,19 +154,16 @@ def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks""" - for extra_network_name in extra_network_data: - extra_network = extra_network_registry.get(extra_network_name, None) - if extra_network is None: - continue + data = lookup_extra_networks(extra_network_data) + for extra_network in data: try: extra_network.deactivate(p) except Exception as e: - errors.display(e, f"deactivating extra network {extra_network_name}") + errors.display(e, f"deactivating extra network {extra_network.name}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in data: continue try: -- cgit v1.2.3