diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 46cae66cf512..283a048e77bc 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -34,12 +34,14 @@ fi # Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878 -NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) -if [ -n "$NUMBA_CUDA_DIR" ]; then - NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch" - pushd "$NUMBA_CUDA_DIR" - patch -p4 <"$NUMBA_PATCH" - popd +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then + NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) + if [ -n "$NUMBA_CUDA_DIR" ]; then + NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch" + pushd "$NUMBA_CUDA_DIR" + patch -p4 <"$NUMBA_PATCH" + popd + fi fi echo "Environment variables:"