mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Only add triton dependency to CUDA and ROCm binaries if it hasn't been set as an installation requirement yet (#108424)
The dependency was added twice before in CUDA and ROCm binaries, one as an installation dependency from builder and the later as an extra dependency for dynamo, for example: ``` Requires-Python: >=3.8.0 Description-Content-Type: text/markdown License-File: LICENSE License-File: NOTICE Requires-Dist: filelock Requires-Dist: typing-extensions Requires-Dist: sympy Requires-Dist: networkx Requires-Dist: jinja2 Requires-Dist: fsspec Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) Provides-Extra: dynamo Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) ; extra == 'dynamo' Requires-Dist: jinja2 ; extra == 'dynamo' Provides-Extra: opt-einsum Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum' ``` In the previous release, we needed to remove this part from `setup.py` to build release binaries https://github.com/pytorch/pytorch/pull/96010. With this, that step isn't needed anymore because the dependency will come from builder. ### Testing Using the draft https://github.com/pytorch/pytorch/pull/108374 for testing and manually inspect the wheels artifact at https://github.com/pytorch/pytorch/actions/runs/6045878399 (don't want to go through all `ciflow/binaries` again) * torch-2.1.0.dev20230901+cu121-cp39-cp39-linux_x86_64 ``` Requires-Python: >=3.8.0 Description-Content-Type: text/markdown Requires-Dist: filelock Requires-Dist: typing-extensions Requires-Dist: sympy Requires-Dist: networkx Requires-Dist: jinja2 Requires-Dist: fsspec Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) <-- This will be 2.1.0 on the release branch after https://github.com/pytorch/builder/pull/1515 Provides-Extra: dynamo Requires-Dist: jinja2 ; extra == 'dynamo' Provides-Extra: opt-einsum Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum' ``` * torch-2.1.0.dev20230901+cu121.with.pypi.cudnn-cp39-cp39-linux_x86_64 ``` Requires-Python: >=3.8.0 Description-Content-Type: text/markdown Requires-Dist: filelock Requires-Dist: typing-extensions Requires-Dist: sympy Requires-Dist: networkx Requires-Dist: jinja2 Requires-Dist: fsspec Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) Requires-Dist: nvidia-cuda-nvrtc-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cuda-runtime-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cuda-cupti-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cudnn-cu12 (==8.9.2.26) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cublas-cu12 (==12.1.3.1) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cufft-cu12 (==11.0.2.54) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-curand-cu12 (==10.3.2.106) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cusolver-cu12 (==11.4.5.107) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-cusparse-cu12 (==12.1.0.106) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-nccl-cu12 (==2.18.1) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: nvidia-nvtx-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64" Requires-Dist: triton (==2.1.0) ; platform_system == "Linux" and platform_machine == "x86_64" <--This is 2.1.0 because it already has https://github.com/pytorch/pytorch/pull/108423, but the package doesn't exist yet atm Provides-Extra: dynamo Requires-Dist: jinja2 ; extra == 'dynamo' Provides-Extra: opt-einsum Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum' ``` * torch-2.1.0.dev20230901+rocm5.6-cp38-cp38-linux_x86_64 ``` Requires-Python: >=3.8.0 Description-Content-Type: text/markdown Requires-Dist: filelock Requires-Dist: typing-extensions Requires-Dist: sympy Requires-Dist: networkx Requires-Dist: jinja2 Requires-Dist: fsspec Requires-Dist: pytorch-triton-rocm (==2.1.0+34f8189eae) <-- This will be 2.1.0 on the release branch after https://github.com/pytorch/builder/pull/1515 Provides-Extra: dynamo Requires-Dist: jinja2 ; extra == 'dynamo' Provides-Extra: opt-einsum Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/108424 Approved by: https://github.com/atalman
This commit is contained in:
75
setup.py
75
setup.py
@ -1072,6 +1072,50 @@ def configure_extension_build():
|
||||
return extensions, cmdclass, packages, entry_points, extra_install_requires
|
||||
|
||||
|
||||
def add_triton(install_requires, extras_require) -> None:
|
||||
"""
|
||||
Add triton package as a dependency when it's needed
|
||||
"""
|
||||
# NB: If the installation requirments list already includes triton dependency,
|
||||
# there is no need to add it one more time as an extra dependency. In nightly
|
||||
# or when release PyTorch, that is done by setting PYTORCH_EXTRA_INSTALL_REQUIREMENTS
|
||||
# environment variable on pytorch/builder
|
||||
has_triton = any("triton" in pkg for pkg in install_requires)
|
||||
if has_triton:
|
||||
return
|
||||
|
||||
cmake_cache_vars = get_cmake_cache_vars()
|
||||
use_rocm = cmake_cache_vars["USE_ROCM"]
|
||||
use_cuda = cmake_cache_vars["USE_CUDA"]
|
||||
|
||||
# Triton is only needed for CUDA or ROCm
|
||||
if not use_rocm and not use_cuda:
|
||||
return
|
||||
|
||||
if use_rocm:
|
||||
triton_text_file = "triton-rocm.txt"
|
||||
triton_package_name = "pytorch-triton-rocm"
|
||||
else:
|
||||
triton_text_file = "triton.txt"
|
||||
triton_package_name = "pytorch-triton"
|
||||
triton_pin_file = os.path.join(
|
||||
cwd, ".ci", "docker", "ci_commit_pins", triton_text_file
|
||||
)
|
||||
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
|
||||
|
||||
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
|
||||
with open(triton_pin_file) as f:
|
||||
triton_pin = f.read().strip()
|
||||
with open(triton_version_file) as f:
|
||||
triton_version = f.read().strip()
|
||||
|
||||
if "dynamo" not in extras_require:
|
||||
extras_require["dynamo"] = []
|
||||
extras_require["dynamo"].append(
|
||||
triton_package_name + "==" + triton_version + "+" + triton_pin[:10]
|
||||
)
|
||||
|
||||
|
||||
# post run, warnings, printed at the end to make them more visible
|
||||
build_update_message = """
|
||||
It is no longer necessary to use the 'build' or 'rebuild' targets
|
||||
@ -1105,29 +1149,6 @@ def main():
|
||||
"fsspec",
|
||||
]
|
||||
|
||||
extras_require = {"opt-einsum": ["opt-einsum>=3.3"]}
|
||||
if platform.system() == "Linux":
|
||||
cmake_cache_vars = get_cmake_cache_vars()
|
||||
if cmake_cache_vars["USE_ROCM"]:
|
||||
triton_text_file = "triton-rocm.txt"
|
||||
triton_package_name = "pytorch-triton-rocm"
|
||||
else:
|
||||
triton_text_file = "triton.txt"
|
||||
triton_package_name = "pytorch-triton"
|
||||
triton_pin_file = os.path.join(
|
||||
cwd, ".ci", "docker", "ci_commit_pins", triton_text_file
|
||||
)
|
||||
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
|
||||
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
|
||||
with open(triton_pin_file) as f:
|
||||
triton_pin = f.read().strip()
|
||||
with open(triton_version_file) as f:
|
||||
triton_version = f.read().strip()
|
||||
extras_require["dynamo"] = [
|
||||
triton_package_name + "==" + triton_version + "+" + triton_pin[:10],
|
||||
"jinja2",
|
||||
]
|
||||
|
||||
# Parse the command line and check the arguments before we proceed with
|
||||
# building deps and setup. We need to set values so `--help` works.
|
||||
dist = Distribution()
|
||||
@ -1153,6 +1174,14 @@ def main():
|
||||
|
||||
install_requires += extra_install_requires
|
||||
|
||||
extras_require = {
|
||||
"opt-einsum": ["opt-einsum>=3.3"],
|
||||
}
|
||||
# Triton is only available on Linux atm
|
||||
if platform.system() == "Linux":
|
||||
extras_require["dynamo"] = ["jinja2"]
|
||||
add_triton(install_requires=install_requires, extras_require=extras_require)
|
||||
|
||||
# Read in README.md for our long_description
|
||||
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
Reference in New Issue
Block a user