diff options
author | 不会画画的中医不是好程序员 <yfszzx@gmail.com> | 2022-10-25 07:38:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-25 07:38:33 +0000 |
commit | 5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f (patch) | |
tree | 7e23ff3f180ada71bc0824270efa82e7cfe68397 /modules/swinir_model.py | |
parent | ff305acd51cc71c5eea8aee0f537a26a6d1ba2a1 (diff) | |
parent | 91c1e1e6a92061b99c92a5b1d548535907d2ad96 (diff) | |
download | stable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.tar.gz stable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.tar.bz2 stable-diffusion-webui-gfx803-5bfa2b23ca6dad7d2664bfab5dab4b1dfabf5b7f.zip |
Merge branch 'AUTOMATIC1111:master' into Inspiron
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r-- | modules/swinir_model.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py index baa02e3d..4253b66d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -7,8 +7,8 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm -from modules import modelloader -from modules.shared import cmd_opts, opts, device +from modules import modelloader, devices +from modules.shared import cmd_opts, opts from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler): model = self.load_model(model_file) if model is None: return img - model = model.to(device) + model = model.to(devices.device_swinir) img = upscale(img, model) try: torch.cuda.empty_cache() @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) + img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old @@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) - W = torch.zeros_like(E, dtype=torch.half, device=device) + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: |