aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorhidenorly <twitte.harold@gmail.com>2023-11-20 16:13:53 +0000
committerhidenorly <twitte.harold@gmail.com>2023-11-20 16:13:53 +0000
commit58c19545c83fa6925c9ce2216ee64964eb5129ce (patch)
tree54db6d2c25e297fff45088bb3c12cd4a86ac9474
parent5f36f6ab21228235021c2441a404f7d297ef6737 (diff)
downloadstable-diffusion-webui-gfx803-58c19545c83fa6925c9ce2216ee64964eb5129ce.tar.gz
stable-diffusion-webui-gfx803-58c19545c83fa6925c9ce2216ee64964eb5129ce.tar.bz2
stable-diffusion-webui-gfx803-58c19545c83fa6925c9ce2216ee64964eb5129ce.zip
Add FP32 fallback support on sd_vae_approx
This tries to execute interpolate with FP32 if it failed. Background is that on some environment such as Mx chip MacOS devices, we get error as follows: ``` "torch/nn/functional.py", line 3931, in interpolate return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half' ``` In this case, ```--no-half``` doesn't help to solve. Therefore this commits add the FP32 fallback execution to solve it. Note that the submodule may require additional modifications. The following is the example modification on the other submodule. ```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py class Upsample(nn.Module): ..snip.. def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" ) else: try: x = F.interpolate(x, scale_factor=2, mode="nearest") except: x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype) if self.use_conv: x = self.conv(x) return x ..snip.. ``` You can see the FP32 fallback execution as same as sd_vae_approx.py.
-rw-r--r--modules/sd_vae_approx.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 3965e223..8370493f 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -21,7 +21,13 @@ class VAEApprox(nn.Module):
def forward(self, x):
extra = 11
- x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ try:
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ except RuntimeError as e:
+ if "not implemented for" in str(e) and "Half" in str(e):
+ x = nn.functional.interpolate(x.to(torch.float32), (x.shape[2] * 2, x.shape[3] * 2)).to(x.dtype)
+ else:
+ print(f"An unexpected RuntimeError occurred: {str(e)}")
x = nn.functional.pad(x, (extra, extra, extra, extra))
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: