From b5d1af11b7dc718d4d91d379c75e46f4bd2e2fe6 Mon Sep 17 00:00:00 2001 From: Abdullah Barhoum Date: Sun, 11 Sep 2022 07:11:27 +0200 Subject: Modular device management --- modules/esrgan_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index e86ad775..7f3baf31 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -9,12 +9,13 @@ from PIL import Image import modules.esrgam_model_arch as arch from modules import shared from modules.shared import opts +from modules.devices import has_mps import modules.images def load_model(filename): # this code is adapted from https://github.com/xinntao/ESRGAN - pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None) + pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) if 'conv_first.weight' in pretrained_net: -- cgit v1.2.3