diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-01-06 07:50:06 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-06 07:50:06 +0000 |
commit | a4ee64050a15263c73884ed7797c5332c9f559c1 (patch) | |
tree | 50be0ea6b7575a39e87a9b9af45b95889ab0f347 /extensions-builtin/Lora/network.py | |
parent | 942617f82830df279f69da95797dd17780451545 (diff) | |
parent | 44744d6005da5e424267698ee3279caa597dfebc (diff) | |
download | stable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.tar.gz stable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.tar.bz2 stable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.zip |
Merge pull request #14547 from AUTOMATIC1111/lyco-forward
Implement general forward method for all method in built-in lora ext
Diffstat (limited to 'extensions-builtin/Lora/network.py')
-rw-r--r-- | extensions-builtin/Lora/network.py | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a62e5eff..b8fd9194 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,9 @@ import os from collections import namedtuple
import enum
+import torch.nn as nn
+import torch.nn.functional as F
+
from modules import sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule: if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
+ self.ops = None
+ self.extra_kwargs = {}
+ if isinstance(self.sd_module, nn.Conv2d):
+ self.ops = F.conv2d
+ self.extra_kwargs = {
+ 'stride': self.sd_module.stride,
+ 'padding': self.sd_module.padding
+ }
+ elif isinstance(self.sd_module, nn.Linear):
+ self.ops = F.linear
+ elif isinstance(self.sd_module, nn.LayerNorm):
+ self.ops = F.layer_norm
+ self.extra_kwargs = {
+ 'normalized_shape': self.sd_module.normalized_shape,
+ 'eps': self.sd_module.eps
+ }
+ elif isinstance(self.sd_module, nn.GroupNorm):
+ self.ops = F.group_norm
+ self.extra_kwargs = {
+ 'num_groups': self.sd_module.num_groups,
+ 'eps': self.sd_module.eps
+ }
+
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -155,5 +181,10 @@ class NetworkModule: raise NotImplementedError()
def forward(self, x, y):
- raise NotImplementedError()
+ """A general forward implementation for all modules"""
+ if self.ops is None:
+ raise NotImplementedError()
+ else:
+ updown, ex_bias = self.calc_updown(self.sd_module.weight)
+ return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
|