diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-21 07:25:45 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-21 07:25:45 +0000 |
commit | a8ff0139637f67af848621968f72e8e620c8e575 (patch) | |
tree | e1ca2fffcdb311e42339d274aab6c38e619a116d /launch.py | |
parent | 4b26b50df0b417216ff4c12ce115394a828cdd05 (diff) | |
download | stable-diffusion-webui-gfx803-a8ff0139637f67af848621968f72e8e620c8e575.tar.gz stable-diffusion-webui-gfx803-a8ff0139637f67af848621968f72e8e620c8e575.tar.bz2 stable-diffusion-webui-gfx803-a8ff0139637f67af848621968f72e8e620c8e575.zip |
added --skip-torch-cuda-test to launcher for #746
Diffstat (limited to 'launch.py')
-rw-r--r-- | launch.py | 15 |
1 files changed, 13 insertions, 2 deletions
@@ -23,6 +23,16 @@ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HAS codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
+args = shlex.split(commandline_args)
+
+
+def extract_arg(args, name):
+ return [x for x in args if x != name], name in args
+
+
+args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
+
+
def repo_dir(name):
return os.path.join(dir_repos, name)
@@ -95,7 +105,8 @@ print(f"Commit hash: {commit}") if not is_installed("torch"):
run(f'"{python}" -m {torch_command}', "Installing torch", "Couldn't install torch")
-run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU'")
+if not skip_torch_cuda_test:
+ run_python("import torch; assert not torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDINE_ARGS variable to disable this check'")
if not is_installed("k_diffusion.sampling"):
run_pip(f"install {k_diffusion_package}", "k-diffusion")
@@ -115,7 +126,7 @@ if not is_installed("lpips"): run_pip(f"install -r {requirements_file}", "requirements for Web UI")
-sys.argv += shlex.split(commandline_args)
+sys.argv += args
def start_webui():
|