From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:14:30 +0300 Subject: do not replace entire unet for the resolution hack --- modules/sd_hijack_unet.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 modules/sd_hijack_unet.py (limited to 'modules/sd_hijack_unet.py') diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py new file mode 100644 index 00000000..1b9d7757 --- /dev/null +++ b/modules/sd_hijack_unet.py @@ -0,0 +1,30 @@ +import torch + + +class TorchHijackForUnet: + """ + This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; + this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64 + """ + + def __getattr__(self, item): + if item == 'cat': + return self.cat + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + def cat(self, tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + if a.shape[-2:] != b.shape[-2:]: + a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") + + tensors = (a, b) + + return torch.cat(tensors, *args, **kwargs) + + +th = TorchHijackForUnet() -- cgit v1.2.3