aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-25 12:04:39 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-25 12:04:39 +0000
commit073f6eac22357dbc8ed3f2c55f292a82c7ab25d9 (patch)
treef7c9fd5c52b1936467a7baa14e9f723be11dd45c /modules/sd_hijack.py
parent615b2fc9ce8cb0c61424aa03655f82209f425d21 (diff)
downloadstable-diffusion-webui-gfx803-073f6eac22357dbc8ed3f2c55f292a82c7ab25d9.tar.gz
stable-diffusion-webui-gfx803-073f6eac22357dbc8ed3f2c55f292a82c7ab25d9.tar.bz2
stable-diffusion-webui-gfx803-073f6eac22357dbc8ed3f2c55f292a82c7ab25d9.zip
potential fix for embeddings no loading on AMD cards
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ccbaa9ad..7b2030d4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -201,7 +201,7 @@ class StableDiffusionModelHijack:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path)
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -217,7 +217,7 @@ class StableDiffusionModelHijack:
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
- self.word_embeddings[name] = emb.detach()
+ self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]