diff options
author | C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> | 2022-09-20 13:36:20 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2022-09-20 20:31:06 +0000 |
commit | 948eff4b3caa237334389a5a08adda130e2b43a5 (patch) | |
tree | 34578575c2519ae73780839f5bdb5790bce5662a | |
parent | 7267b7d2d91c559626eaf43e3c0cd9c5918918dd (diff) | |
download | stable-diffusion-webui-gfx803-948eff4b3caa237334389a5a08adda130e2b43a5.tar.gz stable-diffusion-webui-gfx803-948eff4b3caa237334389a5a08adda130e2b43a5.tar.bz2 stable-diffusion-webui-gfx803-948eff4b3caa237334389a5a08adda130e2b43a5.zip |
make swinir actually useful
-rw-r--r-- | modules/swinir.py (renamed from swinir.py) | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/swinir.py b/modules/swinir.py index cb2bbe3d..6c7f0a2d 100644 --- a/swinir.py +++ b/modules/swinir.py @@ -12,7 +12,13 @@ import modules.images from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(cmd_opts.esrgan_models_path))):
+def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4):
+
+ try:
+ modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin"))
+ except Exception:
+ print(f"Error loading ESRGAN model", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
if not large_model:
# use 'nearest+conv' to avoid block artifacts
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
@@ -26,12 +32,16 @@ def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(c mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
pretrained_model = torch.load(model_path)
- model.load_state_dict(pretrained_model, strict=True)
+ model.load_state_dict(pretrained_model["params_ema"], strict=True)
return model.half().to(device)
def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4):
- img = cv2.imread(img, cv2.IMREAD_COLOR).astype(np.float16) / 255.
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.moveaxis(img, 2, 0) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(device)
model = load_model()
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
@@ -45,7 +55,7 @@ def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, w if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return output
+ return Image.fromarray(output, 'RGB')
def inference(img, model, tile, tile_overlap, window_size, scale):
@@ -71,4 +81,12 @@ def inference(img, model, tile, tile_overlap, window_size, scale): W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
output = E.div_(W)
- return output
\ No newline at end of file + return output
+
+class UpscalerSwin(modules.images.Upscaler):
+ def __init__(self, title):
+ self.name = title
+
+ def do_upscale(self, img):
+ img = upscale(img)
+ return img
|