aboutsummaryrefslogtreecommitdiffstats
path: root/modules/swinir_model.py
diff options
context:
space:
mode:
author不会画画的中医不是好程序员 <yfszzx@gmail.com>2022-10-25 07:38:33 +0000
committerGitHub <noreply@github.com>2022-10-25 07:38:33 +0000
commit5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f (patch)
tree7e23ff3f180ada71bc0824270efa82e7cfe68397 /modules/swinir_model.py
parentff305acd51cc71c5eea8aee0f537a26a6d1ba2a1 (diff)
parent91c1e1e6a92061b99c92a5b1d548535907d2ad96 (diff)
downloadstable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.tar.gz
stable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.tar.bz2
stable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.zip
Merge branch 'AUTOMATIC1111:master' into Inspiron
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r--modules/swinir_model.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index baa02e3d..4253b66d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -7,8 +7,8 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
+from modules import modelloader, devices
+from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(device)
+ model = model.to(devices.device_swinir)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: