diff options
author | Dynamic <bradje@naver.com> | 2022-10-25 09:27:32 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-25 09:27:32 +0000 |
commit | 563fb0aa39faca32187e78c07bec695531f21f39 (patch) | |
tree | e8ba5b699b256ce90a07c52c52051e504f601659 /modules/swinir_model.py | |
parent | e595b41c9d8a596b9b29d9505324e9afca2f12b5 (diff) | |
parent | 3e15f8e0f5cc87507f77546d92435670644dbd18 (diff) | |
download | stable-diffusion-webui-gfx803-563fb0aa39faca32187e78c07bec695531f21f39.tar.gz stable-diffusion-webui-gfx803-563fb0aa39faca32187e78c07bec695531f21f39.tar.bz2 stable-diffusion-webui-gfx803-563fb0aa39faca32187e78c07bec695531f21f39.zip |
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r-- | modules/swinir_model.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py index baa02e3d..4253b66d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -7,8 +7,8 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm -from modules import modelloader -from modules.shared import cmd_opts, opts, device +from modules import modelloader, devices +from modules.shared import cmd_opts, opts from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler): model = self.load_model(model_file) if model is None: return img - model = model.to(device) + model = model.to(devices.device_swinir) img = upscale(img, model) try: torch.cuda.empty_cache() @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) + img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old @@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) - W = torch.zeros_like(E, dtype=torch.half, device=device) + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: |