mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Refactor functions from optional_submodules (#155954)
And use `pathlib.Path` instead of `os.path` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155954 Approved by: https://github.com/Skylion007 ghstack dependencies: #155947
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf90c9f8d1
commit
4886ba64dc
@ -1,43 +1,46 @@
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from subprocess import check_call
|
||||
|
||||
|
||||
repo_root = Path(__file__).absolute().parent.parent
|
||||
third_party_path = os.path.join(repo_root, "third_party")
|
||||
third_party_path = repo_root / "third_party"
|
||||
|
||||
|
||||
def _read_file(path: Path) -> str:
|
||||
with path.open(encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def _checkout_by_tag(repo: str, tag: str) -> None:
|
||||
check_call(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--depth",
|
||||
"1",
|
||||
"--branch",
|
||||
tag,
|
||||
repo,
|
||||
],
|
||||
cwd=third_party_path,
|
||||
)
|
||||
|
||||
|
||||
def read_nccl_pin() -> str:
|
||||
nccl_file = "nccl-cu12.txt"
|
||||
if os.getenv("DESIRED_CUDA", "").startswith("11") or os.getenv(
|
||||
"CUDA_VERSION", ""
|
||||
).startswith("11"):
|
||||
if os.getenv("DESIRED_CUDA", os.getenv("CUDA_VERSION", "")).startswith("11"):
|
||||
nccl_file = "nccl-cu11.txt"
|
||||
nccl_pin_path = os.path.join(
|
||||
repo_root, ".ci", "docker", "ci_commit_pins", nccl_file
|
||||
)
|
||||
with open(nccl_pin_path) as f:
|
||||
return f.read().strip()
|
||||
nccl_pin_path = repo_root / ".ci" / "docker" / "ci_commit_pins" / nccl_file
|
||||
return _read_file(nccl_pin_path)
|
||||
|
||||
|
||||
def checkout_nccl() -> None:
|
||||
release_tag = read_nccl_pin()
|
||||
print(f"-- Checkout nccl release tag: {release_tag}")
|
||||
nccl_basedir = os.path.join(third_party_path, "nccl")
|
||||
if not os.path.exists(nccl_basedir):
|
||||
subprocess.check_call(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--depth",
|
||||
"1",
|
||||
"--branch",
|
||||
release_tag,
|
||||
"https://github.com/NVIDIA/nccl.git",
|
||||
"nccl",
|
||||
],
|
||||
cwd=third_party_path,
|
||||
)
|
||||
nccl_basedir = third_party_path / "nccl"
|
||||
if not nccl_basedir.exists():
|
||||
_checkout_by_tag("https://github.com/NVIDIA/nccl", release_tag)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user