aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9779c30c..2d26b5f7 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -49,6 +49,8 @@ class StableDiffusionModelHijack:
fixes = None
comments = []
dir_mtime = None
+ layers = None
+ circular_enabled = False
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
@@ -105,6 +107,24 @@ class StableDiffusionModelHijack:
if cmd_opts.opt_split_attention:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ def flatten(el):
+ flattened = [flatten(children) for children in el.children()]
+ res = [el]
+ for c in flattened:
+ res += c
+ return res
+
+ self.layers = flatten(m)
+
+ def apply_circular(self, enable):
+ if self.circular_enabled == enable:
+ return
+
+ self.circular_enabled = enable
+
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
+ layer.padding_mode = 'circular' if enable else 'zeros'
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):