From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:14:30 +0300 Subject: do not replace entire unet for the resolution hack --- modules/sd_hijack_optimizations.py | 28 ---------------------------- 1 file changed, 28 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 8cd4c954..85909eb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) - -def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs): - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - if h.shape[-2:] != hs[-1].shape[-2:]: - h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest") - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) -- cgit v1.2.3