diff options
author | Nuullll <vfirst218@gmail.com> | 2023-12-06 12:55:42 +0000 |
---|---|---|
committer | Nuullll <vfirst218@gmail.com> | 2023-12-06 12:55:47 +0000 |
commit | 746783f7a47f38f728f221cc26fe04035d3ca66b (patch) | |
tree | 69a7b1c6aa2f2e20db0994ae58204b26694a296e /modules/xpu_specific.py | |
parent | f92d61497a426a19818625c3ccdaae9beeb82b31 (diff) | |
download | stable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.tar.gz stable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.tar.bz2 stable-diffusion-webui-gfx803-746783f7a47f38f728f221cc26fe04035d3ca66b.zip |
[IPEX] Fix embedding
Cast `torch.bmm` args into same `dtype`.
Fixes the following error when using Text Inversion embedding (#14224):
```
RuntimeError: could not create a primitive descriptor for a matmul
primitive
```
Diffstat (limited to 'modules/xpu_specific.py')
-rw-r--r-- | modules/xpu_specific.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790..ec1ad100 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) |