aboutsummaryrefslogtreecommitdiffstats
path: root/modules/swinir_model.py
blob: f515779ecd8196811571e64e45927db5e98ac470 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import contextlib
import os
import sys
import traceback

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.images
from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net

# Sub-directory of ``models_path`` that holds SwinIR checkpoints.
model_dir = "SwinIR"
# Checkpoint registered by ``setup_model`` when no local model file is found.
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
# Display name of the default upscaler; ``load_model`` also derives the
# downloaded file name from it.
model_name = "SwinIR x4"
model_path = os.path.join(models_path, model_dir)
# Extra search directory; overwritten by ``setup_model`` with its ``dirname``.
cmd_path = ""

# Context-manager factory wrapped around the forward pass in ``upscale``:
# ``torch.autocast`` when ``cmd_opts.precision`` is "autocast", otherwise a
# no-op ``contextlib.nullcontext``.
if cmd_opts.precision == "autocast":
    precision_scope = torch.autocast
else:
    precision_scope = contextlib.nullcontext


def load_model(path, scale=4):
    """Build the SwinIR-L network and load its weights from ``path``.

    Args:
        path: Local checkpoint file, or an ``http://``/``https://`` URL. URLs are
            downloaded into ``model_path`` as ``<model_name>.pth`` (spaces in
            ``model_name`` replaced by underscores).
        scale: Upscale factor passed to the network constructor. The remaining
            architecture settings are fixed, so the checkpoint must match them.

    Returns:
        The loaded model (converted to half precision unless
        ``cmd_opts.no_half`` is set), or ``None`` when no checkpoint file exists.

    Raises:
        KeyError: if the checkpoint has neither a ``"params_ema"`` nor a
            ``"params"`` entry.
    """
    # Match an explicit scheme rather than any "http" substring, so local paths
    # such as ".../http_models/x.pth" are not mistaken for URLs.
    if path.startswith(("http://", "https://")):
        dl_name = f"{model_name.replace(' ', '_')}.pth"
        filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
    else:
        filename = path
    if filename is None or not os.path.exists(filename):
        return None
    model = net(
        upscale=scale,
        in_chans=3,
        img_size=64,
        window_size=8,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
        embed_dim=240,
        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
        mlp_ratio=2,
        upsampler="nearest+conv",
        resi_connection="3conv",
    )

    # Load tensors onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only hosts; do_upscale moves the model to the target device later.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    pretrained_model = torch.load(filename, map_location="cpu")
    # Prefer the EMA weights; some SwinIR checkpoints only ship "params".
    key = "params_ema" if "params_ema" in pretrained_model else "params"
    model.load_state_dict(pretrained_model[key], strict=True)
    if not cmd_opts.no_half:
        model = model.half()
    return model


def setup_model(dirname):
    """Register one SwinIR upscaler in ``modules.shared.sd_upscalers``.

    Args:
        dirname: Extra directory searched for checkpoints; stored in the
            module-level ``cmd_path``.

    The first ``.pt``/``.pth`` file reported by ``modelloader.load_models`` is
    used. If none is found, the default ``model_url`` is registered instead,
    and ``load_model`` downloads it when the upscaler is first used. Errors
    raised while discovering or registering the model are printed to stderr
    and not re-raised.
    """
    global cmd_path
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(model_path, exist_ok=True)
    cmd_path = dirname
    model_file = ""
    try:
        models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)

        if models:
            model_file = models[0]
            name = modelloader.friendly_name(model_file)
        else:
            # Add the "default" model if none are found.
            model_file = model_url
            name = model_name

        modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
    except Exception:
        print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


def upscale(
    img,
    model,
    tile=None,
    tile_overlap=None,
    window_size=8,
    scale=4,
):
    """Upscale ``img`` with ``model`` and return the result as an RGB PIL image.

    Args:
        img: Image that ``np.array`` turns into an (H, W, C) array of 0..255
            values -- presumably an RGB PIL image (TODO confirm for callers
            other than ``UpscalerSwin.do_upscale``).
        model: SwinIR network; it should live on ``device``, where the input
            tensor is moved.
        tile: Tile edge length passed to ``inference``. ``None`` (the default)
            reads ``opts.SWIN_tile`` at call time.
        tile_overlap: Overlap between neighbouring tiles. ``None`` (the
            default) reads ``opts.SWIN_tile_overlap`` at call time.
        window_size: The input is padded up to a multiple of this size.
        scale: Upscale factor produced by ``model``.

    Returns:
        A ``PIL.Image`` in mode ``"RGB"``, cropped to ``scale`` times the
        original height and width.
    """
    # Resolve option-backed defaults per call. Signature defaults would be
    # evaluated once at import time and ignore later settings changes.
    if tile is None:
        tile = opts.SWIN_tile
    if tile_overlap is None:
        tile_overlap = opts.SWIN_tile_overlap

    img = np.array(img)
    # Reverse the channel axis (RGB -> BGR for an RGB input).
    # NOTE(review): the network therefore receives BGR input, while SwinIR's
    # reference test script feeds RGB -- confirm this ordering is intended.
    img = img[:, :, ::-1]
    img = np.moveaxis(img, 2, 0) / 255  # HWC -> CHW, scale to [0, 1]
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad(), precision_scope("cuda"):
        _, _, h_old, w_old = img.size()
        # Mirror-pad the bottom and right edges up to the next multiple of
        # window_size. A full extra window is added when the size is already
        # aligned. The mirrored copy is only h_old/w_old long, so inputs smaller
        # than window_size can end up under-padded.
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
        output = inference(img, model, tile, tile_overlap, window_size, scale)
        # Crop away the upscaled padding.
        output = output[..., : h_old * scale, : w_old * scale]
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            # CHW in the channel order the model saw (BGR) -> HWC RGB.
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        return Image.fromarray(output, "RGB")


def inference(img, model, tile, tile_overlap, window_size, scale):
    """Run ``model`` over ``img`` tile by tile and average overlapping outputs.

    Args:
        img: Batch tensor unpacked as (b, c, h, w).
        model: Callable applied to each (b, c, tile, tile) patch. Its output is
            added into a (tile * scale)-sized region, so it must upscale by
            exactly ``scale`` and keep ``c`` channels.
        tile: Requested tile edge length; clamped to ``h`` and ``w``.
        tile_overlap: Pixels shared by neighbouring tiles. Values outside
            ``[0, tile)`` fall back to non-overlapping tiles, so every output
            pixel is covered.
        window_size: The clamped tile size must be a multiple of this.
        scale: Upscale factor of ``model``.

    Returns:
        Tensor of shape (b, c, h * scale, w * scale) with ``img``'s dtype and
        device. Each pixel is the mean of all tile predictions covering it.

    Raises:
        ValueError: if the clamped tile size is not a multiple of
            ``window_size``.
    """
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    if tile % window_size != 0:
        raise ValueError("tile size should be a multiple of window_size")
    sf = scale

    # A non-positive stride would make range() raise (stride 0) or yield only
    # the last tile, leaving output regions with zero weight (0/0 -> NaN).
    stride = tile - tile_overlap if 0 <= tile_overlap < tile else tile
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    # Sum of predictions and per-pixel coverage count, on img's device/dtype.
    acc = img.new_zeros(b, c, h * sf, w * sf)
    weight = torch.zeros_like(acc)

    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
            out_patch = model(in_patch)
            region = (
                ...,
                slice(h_idx * sf, (h_idx + tile) * sf),
                slice(w_idx * sf, (w_idx + tile) * sf),
            )
            acc[region].add_(out_patch)
            weight[region].add_(1)

    return acc.div_(weight)


class UpscalerSwin(modules.images.Upscaler):
    """Upscaler entry that runs a SwinIR checkpoint over an image.

    The model is loaded from ``filename`` on every ``do_upscale`` call and
    released again afterwards.
    """

    def __init__(self, filename, title):
        # NOTE(review): the base Upscaler.__init__ is not called; confirm it
        # needs no initialisation.
        self.name = title  # display name of this upscaler
        self.filename = filename  # local checkpoint path or download URL

    def do_upscale(self, img):
        """Upscale ``img``; return it unchanged if the model cannot be loaded."""
        model = load_model(self.filename)
        if model is None:
            return img
        try:
            model = model.to(device)
            img = upscale(img, model)
        finally:
            # Drop the last reference to the weights first. Otherwise
            # empty_cache cannot hand their memory back.
            del model
            try:
                torch.cuda.empty_cache()
            except Exception:
                # Best effort only: a failure to trim the cache is not fatal.
                pass
        return img