aboutsummaryrefslogtreecommitdiffstats
path: root/modules/scunet_model.py
diff options
context:
space:
mode:
authorJairo Correa <jn.j41r0@gmail.com>2022-10-04 22:53:52 +0000
committerJairo Correa <jn.j41r0@gmail.com>2022-10-04 22:53:52 +0000
commit1f50971fb8c83c255c2819dd0b3f29a46b74f7d9 (patch)
treefd57f40a1ffa2b28105ec0bb3f7f3ab4a742681a /modules/scunet_model.py
parentad0cc85d1f0bd52877963f296eb1257a0c2b012b (diff)
parentef40e4cd4d383a3405e03f1da3f5b5a1820a8f53 (diff)
downloadstable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.tar.gz
stable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.tar.bz2
stable-diffusion-webui-gfx803-1f50971fb8c83c255c2819dd0b3f29a46b74f7d9.zip
Merge branch 'master' into fix-vram
Diffstat (limited to 'modules/scunet_model.py')
-rw-r--r--modules/scunet_model.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 7987ac14..fb64b740 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
@@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None:
return img
- device = shared.device
+ device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
@@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
- device = shared.device
+ device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)