aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2022-10-01 03:53:25 +0000
committerbrkirch <brkirch@users.noreply.github.com>2022-10-01 03:53:25 +0000
commitbdaa36c84470adbdce3e98c01a69af5e95adfb02 (patch)
tree54999ce4bcefb0b360f3fe72969c12c248b6aa40
parent84e97a98c5233119d0f444e0a3a0f6391da23677 (diff)
downloadstable-diffusion-webui-gfx803-bdaa36c84470adbdce3e98c01a69af5e95adfb02.tar.gz
stable-diffusion-webui-gfx803-bdaa36c84470adbdce3e98c01a69af5e95adfb02.tar.bz2
stable-diffusion-webui-gfx803-bdaa36c84470adbdce3e98c01a69af5e95adfb02.zip
When device is MPS, use CPU for GFPGAN instead
GFPGAN will not work if the device is MPS, so default to CPU instead.
-rw-r--r--modules/devices.py2
-rw-r--r--modules/gfpgan_model.py6
2 files changed, 4 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 07bb2339..08bb26d6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
-device_codeformer = cpu if has_mps else device
+device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device
def randn(seed, shape):
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index bb30d733..fcd8544a 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(shared.device)
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
if gfpgan_constructor is None:
@@ -36,8 +36,8 @@ def gfpgann():
else:
print("Unable to load gfpgan model!")
return None
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
- model.gfpgan.to(shared.device)
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
+ model.gfpgan.to(devices.device_gfpgan)
loaded_gfpgan_model = model
return model