aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-12-10 06:14:30 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-12-10 06:14:45 +0000
commit7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d (patch)
tree1fbb4d81ced27863380f8fab80305519b8014e59 /modules/sd_hijack_optimizations.py
parent2641d1b83bca6961ab1367fd9b67ebd8a94af35b (diff)
downloadstable-diffusion-webui-gfx803-7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d.tar.gz
stable-diffusion-webui-gfx803-7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d.tar.bz2
stable-diffusion-webui-gfx803-7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d.zip
do not replace entire unet for the resolution hack
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py28
1 files changed, 0 insertions, 28 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 8cd4c954..85909eb9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- if self.num_classes is not None:
- assert y.shape == (x.shape[0],)
- emb = emb + self.label_emb(y)
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context)
- hs.append(h)
- h = self.middle_block(h, emb, context)
- for module in self.output_blocks:
- if h.shape[-2:] != hs[-1].shape[-2:]:
- h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
- h = torch.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context)
- h = h.type(x.dtype)
- if self.predict_codebook_ids:
- return self.id_predictor(h)
- else:
- return self.out(h)