aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>2023-01-13 18:22:23 +0000
committerGitHub <noreply@github.com>2023-01-13 18:22:23 +0000
commita407c9f0147c779865c940cbf62c7019dbc1f7b4 (patch)
tree6aec56284753a197cab94a42f644b54c08bdaa79
parenteaebcf638391071172d504568d661931f7e3c740 (diff)
downloadstable-diffusion-webui-gfx803-a407c9f0147c779865c940cbf62c7019dbc1f7b4.tar.gz
stable-diffusion-webui-gfx803-a407c9f0147c779865c940cbf62c7019dbc1f7b4.tar.bz2
stable-diffusion-webui-gfx803-a407c9f0147c779865c940cbf62c7019dbc1f7b4.zip
Automatic torch install for amd on linux
This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. It also prints the operative system and GPU in use.
-rw-r--r--launch.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/launch.py b/launch.py
index bcbb792c..668548f1 100644
--- a/launch.py
+++ b/launch.py
@@ -7,6 +7,7 @@ import shlex
import platform
import argparse
import json
+import detection
dir_repos = "repositories"
dir_extensions = "extensions"
@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
+# Get the GPU vendor and the operating system
+gpu = detection.check_gpu()
+if os.name == "posix":
+ os_name = platform.uname().system
+else:
+ os_name = os.name
def commit_hash():
global stored_commit_hash
@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment():
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+ if gpu == "AMD" and os_name !="nt":
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
+ else:
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -295,6 +306,8 @@ def tests(test_dir):
def start():
+ print(f"Operating System: {os_name}")
+ print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv: