Only add triton dependency to CUDA and ROCm binaries if it hasn't been set as an installation requirement yet (#108424)

The dependency was added twice before in CUDA and ROCm binaries, one as an installation dependency from builder and the later as an extra dependency for dynamo, for example:

```
Requires-Python: >=3.8.0
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: NOTICE
Requires-Dist: filelock
Requires-Dist: typing-extensions
Requires-Dist: sympy
Requires-Dist: networkx
Requires-Dist: jinja2
Requires-Dist: fsspec
Requires-Dist: pytorch-triton (==2.1.0+e6216047b8)
Provides-Extra: dynamo
Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) ; extra == 'dynamo'
Requires-Dist: jinja2 ; extra == 'dynamo'
Provides-Extra: opt-einsum
Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum'
```

In the previous release, we needed to remove this part from `setup.py` to build release binaries https://github.com/pytorch/pytorch/pull/96010.  With this, that step isn't needed anymore because the dependency will come from builder.

### Testing

Using the draft https://github.com/pytorch/pytorch/pull/108374 for testing and manually inspect the wheels artifact at https://github.com/pytorch/pytorch/actions/runs/6045878399 (don't want to go through all `ciflow/binaries` again)

* torch-2.1.0.dev20230901+cu121-cp39-cp39-linux_x86_64
```
Requires-Python: >=3.8.0
Description-Content-Type: text/markdown
Requires-Dist: filelock
Requires-Dist: typing-extensions
Requires-Dist: sympy
Requires-Dist: networkx
Requires-Dist: jinja2
Requires-Dist: fsspec
Requires-Dist: pytorch-triton (==2.1.0+e6216047b8) <-- This will be 2.1.0 on the release branch after https://github.com/pytorch/builder/pull/1515
Provides-Extra: dynamo
Requires-Dist: jinja2 ; extra == 'dynamo'
Provides-Extra: opt-einsum
Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum'
```

* torch-2.1.0.dev20230901+cu121.with.pypi.cudnn-cp39-cp39-linux_x86_64
```
Requires-Python: >=3.8.0
Description-Content-Type: text/markdown
Requires-Dist: filelock
Requires-Dist: typing-extensions
Requires-Dist: sympy
Requires-Dist: networkx
Requires-Dist: jinja2
Requires-Dist: fsspec
Requires-Dist: pytorch-triton (==2.1.0+e6216047b8)
Requires-Dist: nvidia-cuda-nvrtc-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cuda-runtime-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cuda-cupti-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cudnn-cu12 (==8.9.2.26) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cublas-cu12 (==12.1.3.1) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cufft-cu12 (==11.0.2.54) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-curand-cu12 (==10.3.2.106) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cusolver-cu12 (==11.4.5.107) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cusparse-cu12 (==12.1.0.106) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-nccl-cu12 (==2.18.1) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-nvtx-cu12 (==12.1.105) ; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: triton (==2.1.0) ; platform_system == "Linux" and platform_machine == "x86_64" <--This is 2.1.0 because it already has https://github.com/pytorch/pytorch/pull/108423, but the package doesn't exist yet atm
Provides-Extra: dynamo
Requires-Dist: jinja2 ; extra == 'dynamo'
Provides-Extra: opt-einsum
Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum'
```

* torch-2.1.0.dev20230901+rocm5.6-cp38-cp38-linux_x86_64
```
Requires-Python: >=3.8.0
Description-Content-Type: text/markdown
Requires-Dist: filelock
Requires-Dist: typing-extensions
Requires-Dist: sympy
Requires-Dist: networkx
Requires-Dist: jinja2
Requires-Dist: fsspec
Requires-Dist: pytorch-triton-rocm (==2.1.0+34f8189eae) <-- This will be 2.1.0 on the release branch after https://github.com/pytorch/builder/pull/1515
Provides-Extra: dynamo
Requires-Dist: jinja2 ; extra == 'dynamo'
Provides-Extra: opt-einsum
Requires-Dist: opt-einsum (>=3.3) ; extra == 'opt-einsum'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108424
Approved by: https://github.com/atalman
This commit is contained in:
Huy Do
2023-09-02 01:16:18 +00:00
committed by PyTorch MergeBot
parent 2e3fce5450
commit 4084d039b7

View File

@ -1072,6 +1072,50 @@ def configure_extension_build():
return extensions, cmdclass, packages, entry_points, extra_install_requires
def add_triton(install_requires, extras_require) -> None:
"""
Add triton package as a dependency when it's needed
"""
# NB: If the installation requirments list already includes triton dependency,
# there is no need to add it one more time as an extra dependency. In nightly
# or when release PyTorch, that is done by setting PYTORCH_EXTRA_INSTALL_REQUIREMENTS
# environment variable on pytorch/builder
has_triton = any("triton" in pkg for pkg in install_requires)
if has_triton:
return
cmake_cache_vars = get_cmake_cache_vars()
use_rocm = cmake_cache_vars["USE_ROCM"]
use_cuda = cmake_cache_vars["USE_CUDA"]
# Triton is only needed for CUDA or ROCm
if not use_rocm and not use_cuda:
return
if use_rocm:
triton_text_file = "triton-rocm.txt"
triton_package_name = "pytorch-triton-rocm"
else:
triton_text_file = "triton.txt"
triton_package_name = "pytorch-triton"
triton_pin_file = os.path.join(
cwd, ".ci", "docker", "ci_commit_pins", triton_text_file
)
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
with open(triton_pin_file) as f:
triton_pin = f.read().strip()
with open(triton_version_file) as f:
triton_version = f.read().strip()
if "dynamo" not in extras_require:
extras_require["dynamo"] = []
extras_require["dynamo"].append(
triton_package_name + "==" + triton_version + "+" + triton_pin[:10]
)
# post run, warnings, printed at the end to make them more visible
build_update_message = """
It is no longer necessary to use the 'build' or 'rebuild' targets
@ -1105,29 +1149,6 @@ def main():
"fsspec",
]
extras_require = {"opt-einsum": ["opt-einsum>=3.3"]}
if platform.system() == "Linux":
cmake_cache_vars = get_cmake_cache_vars()
if cmake_cache_vars["USE_ROCM"]:
triton_text_file = "triton-rocm.txt"
triton_package_name = "pytorch-triton-rocm"
else:
triton_text_file = "triton.txt"
triton_package_name = "pytorch-triton"
triton_pin_file = os.path.join(
cwd, ".ci", "docker", "ci_commit_pins", triton_text_file
)
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
with open(triton_pin_file) as f:
triton_pin = f.read().strip()
with open(triton_version_file) as f:
triton_version = f.read().strip()
extras_require["dynamo"] = [
triton_package_name + "==" + triton_version + "+" + triton_pin[:10],
"jinja2",
]
# Parse the command line and check the arguments before we proceed with
# building deps and setup. We need to set values so `--help` works.
dist = Distribution()
@ -1153,6 +1174,14 @@ def main():
install_requires += extra_install_requires
extras_require = {
"opt-einsum": ["opt-einsum>=3.3"],
}
# Triton is only available on Linux atm
if platform.system() == "Linux":
extras_require["dynamo"] = ["jinja2"]
add_triton(install_requires=install_requires, extras_require=extras_require)
# Read in README.md for our long_description
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
long_description = f.read()