From 7d5c29b674bacc5654f8613af134632b7cbdb158 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 10:27:18 -0500 Subject: Cleanup existing directories, fixes --- modules/codeformer_model.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) (limited to 'modules/codeformer_model.py') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index dc0a5eee..efd881eb 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,14 +5,13 @@ import traceback import cv2 import torch +import modules.face_restoration +import modules.shared from modules import shared, devices, modelloader from modules.paths import script_path, models_path -import modules.shared -import modules.face_restoration -from importlib import reload -# codeformer people made a choice to include modified basicsr library to their project, which makes -# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN. +# codeformer people made a choice to include modified basicsr library to their project which makes +# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. # I am making a choice to include some files from codeformer to work around this issue. model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) @@ -31,11 +30,6 @@ def setup_model(dirname): if path is None: return - - # both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own - #stored_sys_path = sys.path - #sys.path = [path] + sys.path - try: from torchvision.transforms.functional import normalize from modules.codeformer.codeformer_arch import CodeFormer @@ -67,7 +61,6 @@ def setup_model(dirname): print("Unable to load codeformer model.") return None, None net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True) checkpoint = torch.load(ckpt_path)['params_ema'] net.load_state_dict(checkpoint) net.eval() -- cgit v1.2.3