diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-12-30 15:06:31 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-30 15:06:31 +0000 |
commit | cd12c0e15c4dc1545cac18ba902ca17488812953 (patch) | |
tree | 9c70df74d3e426341d1189b1ceadbd8afffeae91 /extensions-builtin/ScuNET/scripts/scunet_model.py | |
parent | 05230c02606080527b65ace9eacb6fb835239877 (diff) | |
parent | 4ad0c0c0a805da4bac03cff86ea17c25a1291546 (diff) | |
download | stable-diffusion-webui-gfx803-cd12c0e15c4dc1545cac18ba902ca17488812953.tar.gz stable-diffusion-webui-gfx803-cd12c0e15c4dc1545cac18ba902ca17488812953.tar.bz2 stable-diffusion-webui-gfx803-cd12c0e15c4dc1545cac18ba902ca17488812953.zip |
Merge pull request #14425 from akx/spandrel
Use Spandrel for upscaling and face restoration architectures
Diffstat (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py')
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 13 |
1 files changed, 2 insertions, 11 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64..5f3dd08b 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? - filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): |