mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add triton as dependency to CUDA aarch64 build (#149584)
Aarch64 Triton build was added by: https://github.com/pytorch/pytorch/pull/148705 Hence add proper contrain to CUDA 12.8 Aarch64 build Please note we want to still use: ```platform_system == 'Linux' and platform_machine == 'x86_64'``` For all other builds. Since these are prototype binaries only used by cuda 12.8 linux aarch64 build. Which we would like to serve from download.pytorch.org Pull Request resolved: https://github.com/pytorch/pytorch/pull/149584 Approved by: https://github.com/nWEIdia, https://github.com/tinglvv, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
80dfce2cc3
commit
9b1127437e
@ -74,6 +74,12 @@ TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
|
||||
|
||||
# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
|
||||
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'"
|
||||
|
||||
# CUDA 12.8 builds have triton for Linux and Linux aarch64 binaries.
|
||||
if [[ "$DESIRED_CUDA" == cu128 ]]; then
|
||||
TRITON_CONSTRAINT="platform_system == 'Linux'"
|
||||
fi
|
||||
|
||||
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
|
||||
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
|
||||
if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
|
||||
|
Reference in New Issue
Block a user