From 9d1ab4f4bb508a72c7f549f0b5219c4601944ba1 Mon Sep 17 00:00:00 2001 From: Ken Date: Sat, 4 Oct 2025 02:39:04 +0000 Subject: [PATCH] [CI] Limit Numba CUDA-13 patch to CUDA environments only (#164607) The patch introduced in https://github.com/pytorch/pytorch/pull/163111 caused issues in ROCm environments. This change guards the patching logic to CUDA environments only, thus ameliorating test failures in ROCm environments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164607 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/pytorch/test.sh | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 46cae66cf512..283a048e77bc 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -34,12 +34,14 @@ fi # Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878 -NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) -if [ -n "$NUMBA_CUDA_DIR" ]; then - NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch" - pushd "$NUMBA_CUDA_DIR" - patch -p4 <"$NUMBA_PATCH" - popd +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then + NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) + if [ -n "$NUMBA_CUDA_DIR" ]; then + NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch" + pushd "$NUMBA_CUDA_DIR" + patch -p4 <"$NUMBA_PATCH" + popd + fi fi echo "Environment variables:"