aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNuullll <vfirst218@gmail.com>2023-12-06 12:55:42 +0000
committerNuullll <vfirst218@gmail.com>2023-12-06 12:55:47 +0000
commit746783f7a47f38f728f221cc26fe04035d3ca66b (patch)
tree69a7b1c6aa2f2e20db0994ae58204b26694a296e
parentf92d61497a426a19818625c3ccdaae9beeb82b31 (diff)
downloadstable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.tar.gz
stable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.tar.bz2
stable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.zip
[IPEX] Fix embedding
Cast `torch.bmm` args into same `dtype`. Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ```
-rw-r--r--modules/xpu_specific.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index d933c790..ec1ad100 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -48,3 +48,6 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.bmm',
+ lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+ lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)