diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-25 21:01:02 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:30:49 +0000 |
commit | b621a63cf68c788487684250856707cb352b82d0 (patch) | |
tree | 37fc1ce21103832f126ba799fad7ede6cbdd2308 | |
parent | b0f59342346b1c8b405f97c0e0bb01c6ae05c601 (diff) | |
download | stable-diffusion-webui-gfx803-b621a63cf68c788487684250856707cb352b82d0.tar.gz stable-diffusion-webui-gfx803-b621a63cf68c788487684250856707cb352b82d0.tar.bz2 stable-diffusion-webui-gfx803-b621a63cf68c788487684250856707cb352b82d0.zip |
Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN
-rw-r--r-- | .github/workflows/run_tests.yaml | 8 | ||||
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | modules/codeformer_model.py | 158 | ||||
-rw-r--r-- | modules/face_restoration_utils.py | 163 | ||||
-rw-r--r-- | modules/gfpgan_model.py | 166 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | requirements_versions.txt | 1 | ||||
-rw-r--r-- | test/conftest.py | 15 | ||||
-rw-r--r-- | test/test_face_restorers.py | 29 | ||||
-rw-r--r-- | test/test_files/two-faces.jpg | bin | 0 -> 14768 bytes | |||
-rw-r--r-- | test/test_outputs/.gitkeep | 0 |
11 files changed, 308 insertions, 234 deletions
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3dafaf8d..cd5c3f86 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,6 +20,12 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py + - name: Cache models + id: cache-models + uses: actions/cache@v3 + with: + path: models + key: "2023-12-30" - name: Install test dependencies run: pip install wait-for-it -r requirements-test.txt env: @@ -33,6 +39,8 @@ jobs: TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu WEBUI_LAUNCH_LIVE_OUTPUT: "1" PYTHONUNBUFFERED: "1" + - name: Print installed packages + run: pip freeze - name: Start test server run: > python -m coverage run @@ -37,3 +37,4 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/test/test_outputs diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 517eadfd..ceda4bab 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,140 +1,62 @@ -import os
+from __future__ import annotations
+
+import logging
-import cv2
import torch
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
+
+logger = logging.getLogger(__name__)
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
-codeformer = None
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
-class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
def name(self):
return "CodeFormer"
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
-
- def create_models(self):
- from facexlib.detection import retinaface
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(
- model_path,
- model_url,
- self.cmd_dir,
- download_name='codeformer-v0.1.0.pth',
+ def load_net(self) -> torch.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
ext_filter=['.pth'],
- )
-
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
-
- face_helper = FaceRestoreHelper(
- upscale_factor=1,
- face_size=512,
- crop_ratio=(1, 1),
- det_model='retinaface_resnet50',
- save_ext='png',
- use_parse=True,
- device=devices.device_codeformer,
- )
-
- self.net = net
- self.face_helper = face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- from torchvision.transforms.functional import normalize
- from basicsr.utils import img2tensor, tensor2img
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
-
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
-
- self.send_model_to(devices.device_codeformer)
-
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
-
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
-
- try:
- with torch.no_grad():
- res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
- if isinstance(res, tuple):
- output = res[0]
- else:
- output = res
- if not isinstance(res, torch.Tensor):
- raise TypeError(f"Expected torch.Tensor, got {type(res)}")
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception:
- errors.report('Failed inference for CodeFormer', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
-
- self.face_helper.get_inverse_affine(None)
-
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
+ ):
+ return modelloader.load_spandrel_model(
+ model_path,
+ device=devices.device_codeformer,
+ ).model
+ raise ValueError("No codeformer model found")
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(
- restored_img,
- (0, 0),
- fx=original_resolution[1]/restored_img.shape[1],
- fy=original_resolution[0]/restored_img.shape[0],
- interpolation=cv2.INTER_LINEAR,
- )
+ def get_device(self):
+ return devices.device_codeformer
- self.face_helper.clean_all()
+ def restore(self, np_image, w: float | None = None):
+ if w is None:
+ w = getattr(shared.opts, "code_former_weight", 0.5)
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, w=w, adain=True)[0]
- return restored_img
+ return self.restore_with_helper(np_image, restore_face)
-def setup_model(dirname):
- os.makedirs(model_path, exist_ok=True)
+def setup_model(dirname: str) -> None:
+ global codeformer
try:
- global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py new file mode 100644 index 00000000..c65c85ef --- /dev/null +++ b/modules/face_restoration_utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import os +from functools import cached_property +from typing import TYPE_CHECKING, Callable + +import cv2 +import numpy as np +import torch + +from modules import devices, errors, face_restoration, shared + +if TYPE_CHECKING: + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +logger = logging.getLogger(__name__) + + +def create_face_helper(device) -> FaceRestoreHelper: + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + if hasattr(retinaface, 'device'): + retinaface.device = device + return FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=device, + ) + + +def restore_with_face_helper( + np_image: np.ndarray, + face_helper: FaceRestoreHelper, + restore_face: Callable[[np.ndarray], np.ndarray], +) -> np.ndarray: + """ + Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. + + `restore_face` should take a cropped face image and return a restored face image. + """ + from basicsr.utils import img2tensor, tensor2img + from torchvision.transforms.functional import normalize + np_image = np_image[:, :, ::-1] + original_resolution = np_image.shape[0:2] + + try: + logger.debug("Detecting faces...") + face_helper.clean_all() + face_helper.read_image(np_image) + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) + for cropped_face in face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + restored_face = tensor2img( + restore_face(cropped_face_t), + rgb2bgr=True, + min_max=(-1, 1), + ) + devices.torch_gc() + except Exception: + errors.report('Failed face-restoration inference', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + logger.debug("Merging restored faces into image") + face_helper.get_inverse_affine(None) + img = face_helper.paste_faces_to_input_image() + img = img[:, :, ::-1] + if original_resolution != img.shape[0:2]: + img = cv2.resize( + img, + (0, 0), + fx=original_resolution[1] / img.shape[1], + fy=original_resolution[0] / img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + logger.debug("Face restoration complete") + finally: + face_helper.clean_all() + return img + + +class CommonFaceRestoration(face_restoration.FaceRestoration): + net: torch.Module | None + model_url: str + model_download_name: str + + def __init__(self, model_path: str): + super().__init__() + self.net = None + self.model_path = model_path + os.makedirs(model_path, exist_ok=True) + + @cached_property + def face_helper(self) -> FaceRestoreHelper: + return create_face_helper(self.get_device()) + + def send_model_to(self, device): + if self.net: + logger.debug("Sending %s to %s", self.net, device) + self.net.to(device) + if self.face_helper: + logger.debug("Sending face helper to %s", device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def get_device(self): + raise NotImplementedError("get_device must be implemented by subclasses") + + def load_net(self) -> torch.Module: + raise NotImplementedError("load_net must be implemented by subclasses") + + def restore_with_helper( + self, + np_image: np.ndarray, + restore_face: Callable[[np.ndarray], np.ndarray], + ) -> np.ndarray: + try: + if self.net is None: + self.net = self.load_net() + except Exception: + logger.warning("Unable to load face-restoration model", exc_info=True) + return np_image + + try: + self.send_model_to(self.get_device()) + return restore_with_face_helper(np_image, self.face_helper, restore_face) + finally: + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + +def patch_facexlib(dirname: str) -> None: + import facexlib.detection + import facexlib.parsing + + det_facex_load_file_from_url = facexlib.detection.load_file_from_url + par_facex_load_file_from_url = facexlib.parsing.load_file_from_url + + def update_kwargs(kwargs): + return dict(kwargs, save_dir=dirname, model_dir=None) + + def facex_load_file_from_url(**kwargs): + return det_facex_load_file_from_url(**update_kwargs(kwargs)) + + def facex_load_file_from_url2(**kwargs): + return par_facex_load_file_from_url(**update_kwargs(kwargs)) + + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6b6f17c4..a356b56f 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,126 +1,68 @@ +from __future__ import annotations
+
+import logging
import os
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+logger = logging.getLogger(__name__)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- global model_file_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
- if len(models) == 1 and models[0].startswith("http"):
- model_file = models[0]
- elif len(models) != 0:
- gfp_models = []
- for item in models:
- if 'GFPGAN' in os.path.basename(item):
- gfp_models.append(item)
- latest_file = max(gfp_models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
-
- import facexlib.detection.retinaface
-
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model_file_path = model_file
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def get_device(self):
+ return devices.device_gfpgan
+
+ def load_net(self) -> None:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ if 'GFPGAN' in os.path.basename(model_path):
+ net = modelloader.load_spandrel_model(
+ model_path,
+ device=self.get_device(),
+ ).model
+ net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
+ return net
+ raise ValueError("No GFPGAN model found")
+
+ def restore(self, np_image):
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, return_rgb=False)[0]
+
+ return self.restore_with_helper(np_image, restore_face)
def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
+ if gfpgan_face_restorer:
+ return gfpgan_face_restorer.restore(np_image)
+ logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
+def setup_model(dirname: str) -> None:
+ global gfpgan_face_restorer
-
-def setup_model(dirname):
try:
- os.makedirs(model_path, exist_ok=True)
- import gfpgan
- import facexlib.detection
- import facexlib.parsing
-
- global user_path
- global have_gfpgan
- global gfpgan_constructor
- global model_file_path
-
- facexlib_path = model_path
-
- if dirname is not None:
- facexlib_path = dirname
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = gfpgan.GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
+ face_restoration_utils.patch_facexlib(dirname)
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+ shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/requirements.txt b/requirements.txt index 36f5674a..b1329c9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ clean-fid einops
facexlib
fastapi>=0.90.1
-gfpgan
gradio==3.41.2
inflection
jsonmerge
diff --git a/requirements_versions.txt b/requirements_versions.txt index 042fa708..edbb6db9 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -7,7 +7,6 @@ clean-fid==0.1.35 einops==0.4.1
facexlib==0.3.0
fastapi==0.94.0
-gfpgan==1.3.8
gradio==3.41.2
httpcore==0.15
inflection==0.5.1
diff --git a/test/conftest.py b/test/conftest.py index 31a5d9ea..e4fc5678 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,16 @@ +import base64 import os import pytest -import base64 - test_files_path = os.path.dirname(__file__) + "/test_files" +test_outputs_path = os.path.dirname(__file__) + "/test_outputs" + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") def file_to_base64(filename): @@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str: @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) + + +@pytest.fixture(scope="session") +def initialize() -> None: + import webui # noqa: F401 diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py new file mode 100644 index 00000000..7760d51b --- /dev/null +++ b/test/test_face_restorers.py @@ -0,0 +1,29 @@ +import os +from test.conftest import test_files_path, test_outputs_path + +import numpy as np +import pytest +from PIL import Image + + +@pytest.mark.usefixtures("initialize") +@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"]) +def test_face_restorers(restorer_name): + from modules import shared + + if restorer_name == "gfpgan": + from modules import gfpgan_model + gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path) + restorer = gfpgan_model.gfpgan_fix_faces + elif restorer_name == "codeformer": + from modules import codeformer_model + codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path) + restorer = codeformer_model.codeformer.restore + else: + raise NotImplementedError("...") + img = Image.open(os.path.join(test_files_path, "two-faces.jpg")) + np_img = np.array(img, dtype=np.uint8) + fixed_image = restorer(np_img) + assert fixed_image.shape == np_img.shape + assert not np.allclose(fixed_image, np_img) # should have visibly changed + Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png")) diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg Binary files differnew file mode 100644 index 00000000..c9d1b010 --- /dev/null +++ b/test/test_files/two-faces.jpg diff --git a/test/test_outputs/.gitkeep b/test/test_outputs/.gitkeep new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/test_outputs/.gitkeep |