diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-31 17:52:32 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2024-01-02 08:44:38 +0000 |
commit | 2cacbc124c49f45da5b66b79d9b0a3ab943472eb (patch) | |
tree | 27ae05c022710aaa60e7791ca6bde3e7f60b7511 /modules | |
parent | 51f1cca8524d3ffa8930b32a571d239c60d65725 (diff) | |
download | stable-diffusion-webui-gfx803-2cacbc124c49f45da5b66b79d9b0a3ab943472eb.tar.gz stable-diffusion-webui-gfx803-2cacbc124c49f45da5b66b79d9b0a3ab943472eb.tar.bz2 stable-diffusion-webui-gfx803-2cacbc124c49f45da5b66b79d9b0a3ab943472eb.zip |
load_spandrel_model: make `half` `prefer_half`
As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
Diffstat (limited to 'modules')
-rw-r--r-- | modules/modelloader.py | 20 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 2 |
2 files changed, 15 insertions, 7 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index a7194137..e100bb24 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,23 +139,31 @@ def load_upscalers(): def load_spandrel_model( - path: str, + path: str | os.PathLike, *, device: str | torch.device | None, - half: bool = False, + prefer_half: bool = False, dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel - model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) - if half: - model_descriptor.model.half() + half = False + if prefer_half: + if model_descriptor.supports_half: + model_descriptor.model.half() + half = True + else: + logger.info("Model %s does not support half precision, ignoring --half", path) if dtype: model_descriptor.model.to(dtype=dtype) model_descriptor.model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + logger.debug( + "Loaded %s from %s (device=%s, half=%s, dtype=%s)", + model_descriptor, path, device, half, dtype, + ) return model_descriptor diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 4d35b695..ff9d8ac0 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler): model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
- half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(
|