diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-21 05:36:07 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-21 05:36:07 +0000 |
commit | 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 (patch) | |
tree | 75230fd3b1c9c53593aca65d7260b62b6df2d82b /modules/hypernetworks/hypernetwork.py | |
parent | e33cace2c2074ef342d027c1f31ffc4b3c3e877e (diff) | |
download | stable-diffusion-webui-gfx803-40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75.tar.gz stable-diffusion-webui-gfx803-40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75.tar.bz2 stable-diffusion-webui-gfx803-40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75.zip |
extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 107 |
1 files changed, 75 insertions, 32 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74e78582..80a47c79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
- multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
+ self.multiplier = 1.0
+
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure
-def apply_strength(value=None):
- HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters():
param.requires_grad = mode
+ def to(self, device):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
- print("No saved optimizer exists in checkpoint")
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path): return res
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- # Prevent any file named "None.pt" from being loaded.
- if path is not None and filename != "None":
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print("Unloading hypernetwork")
+ if path is None:
+ return None
+
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
- shared.loaded_hypernetwork = None
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
+
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0]
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
- return context, context
+ return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
+ context_k = hypernetwork_layers[0](context_k)
+ context_v = hypernetwork_layers[1](context_v)
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else:
images_dir = None
- hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
|