Files
pytorch/tools/optional_submodules.py
Nikita Shulga ee56e9f8a8 [BE] Make Eigen an optional dependency (#155955)
Eigen's version is now controlled by `eigen_pin.txt`, and it will be installed only if no BLAS provider can be found.
Why this is good for CI: we never really build with Eigen, and GitLab can be down while GitHub is up, which has caused spurious CI failures in the past.

Remove eigen submodule and replace it with eigen_pin.txt

Fixes https://github.com/pytorch/pytorch/issues/108773
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155955
Approved by: https://github.com/atalman
2025-06-21 03:02:02 +00:00

64 lines
1.7 KiB
Python

import os
from pathlib import Path
from subprocess import check_call
# Repository root: two directory levels above this file.
repo_root = Path(__file__).absolute().parent.parent
# Directory that the optional dependencies below are cloned into.
third_party_path = repo_root.joinpath("third_party")
def _read_file(path: Path) -> str:
    """Return the UTF-8 text stored at ``path`` with surrounding whitespace removed.

    Args:
        path: File to read.

    Returns:
        The file contents, stripped of leading and trailing whitespace.
    """
    return path.read_text(encoding="utf-8").strip()
def _checkout_by_tag(repo: str, tag: str) -> None:
    """Shallow-clone ``repo`` at ``tag`` from inside the ``third_party`` directory.

    Args:
        repo: URL of the git repository to clone.
        tag: Name handed to ``git clone --branch``.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
    """
    # --depth 1 keeps the clone shallow: no history beyond the requested tag.
    clone_cmd = ["git", "clone", "--depth", "1", "--branch", tag, repo]
    check_call(clone_cmd, cwd=third_party_path)
def read_nccl_pin() -> str:
    """Return the contents of the NCCL pin file matching the target CUDA version.

    The CUDA version is taken from the ``DESIRED_CUDA`` environment variable;
    ``CUDA_VERSION`` is consulted only when ``DESIRED_CUDA`` is not set at all.
    A version string starting with ``"11"`` selects ``nccl-cu11.txt``; anything
    else (including no version) selects ``nccl-cu12.txt``.

    Returns:
        The stripped text of the selected file under
        ``.ci/docker/ci_commit_pins``.
    """
    cuda_version = os.getenv("DESIRED_CUDA", os.getenv("CUDA_VERSION", ""))
    pin_name = "nccl-cu11.txt" if cuda_version.startswith("11") else "nccl-cu12.txt"
    pins_dir = repo_root / ".ci" / "docker" / "ci_commit_pins"
    return _read_file(pins_dir / pin_name)
def checkout_nccl() -> None:
    """Clone NCCL at its pinned tag into ``third_party/nccl`` if it is absent.

    The pin is read and the status line is printed on every call, even when
    the directory already exists and nothing is cloned.
    """
    pinned_tag = read_nccl_pin()
    print(f"-- Checkout nccl release tag: {pinned_tag}")
    if (third_party_path / "nccl").exists():
        # An existing checkout is left untouched, whatever revision it holds.
        return
    _checkout_by_tag("https://github.com/NVIDIA/nccl", pinned_tag)
def checkout_eigen() -> None:
    """Clone Eigen at its pinned tag into ``third_party/eigen`` if it is absent.

    The tag is read from ``third_party/eigen_pin.txt``. The status line is
    printed on every call, even when the directory already exists and nothing
    is cloned.
    """
    pinned_tag = _read_file(third_party_path / "eigen_pin.txt")
    print(f"-- Checkout Eigen release tag: {pinned_tag}")
    if (third_party_path / "eigen").exists():
        # An existing checkout is left untouched, whatever revision it holds.
        return
    _checkout_by_tag("https://gitlab.com/libeigen/eigen", pinned_tag)
if __name__ == "__main__":
    # Command-line entry point:
    #   no arguments    -> check out every optional dependency
    #   <function name> -> call that top-level function only
    import sys

    if len(sys.argv) == 1:
        # If no arguments are given, check out all optional dependencies.
        checkout_nccl()
        checkout_eigen()
    else:
        # Otherwise call the single top-level function named on the command
        # line. Look the name up defensively so that a typo, or the name of a
        # non-callable global such as ``repo_root``, yields a clear error
        # message instead of a raw KeyError/TypeError traceback.
        action_name = sys.argv[1]
        action = globals().get(action_name)
        if not callable(action):
            # sys.exit with a string prints it to stderr and exits with code 1.
            sys.exit(f"Unknown top-level function: {action_name!r}")
        action()