[aarch64] fix TORCH_CUDA_ARCH_LIST for cuda arm build (#144436)

Fixes #144037

The root cause is that the CUDA ARM build does not call `.ci/manywheel/build_cuda.sh`; it calls `.ci/aarch64_linux/aarch64_ci_build.sh` instead. Therefore, the architecture-list logic at https://github.com/pytorch/pytorch/blob/main/.ci/manywheel/build_cuda.sh#L56 was never executed for the CUDA ARM build.

This change adds the equivalent code to `.ci/aarch64_linux/aarch64_ci_build.sh` as a workaround (WAR).

In the future, we should aim to integrate the contents of `.ci/aarch64_linux/aarch64_ci_build.sh` back into `.ci/manywheel/build_cuda.sh`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144436
Approved by: https://github.com/atalman
This commit is contained in:
Ting Lu
2025-01-11 09:00:46 +00:00
committed by PyTorch MergeBot
parent e1d0a2ff30
commit b7bef1ca84
2 changed files with 4 additions and 13 deletions

View File

@@ -3,6 +3,9 @@ set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# cuda arm build for Grace Hopper solely
export TORCH_CUDA_ARCH_LIST="9.0"
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh

View File

@@ -53,22 +53,10 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
case ${CUDA_VERSION} in
12.6)
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
fi
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.4)
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
fi
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.1)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;