mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[CI] Limit Numba CUDA-13 patch to CUDA environments only (#164607)
The patch introduced in https://github.com/pytorch/pytorch/pull/163111 caused issues in ROCm environments. This change guards the patching logic to CUDA environments only, thus ameliorating test failures in ROCm environments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164607 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@ -34,12 +34,14 @@ fi
|
||||
|
||||
|
||||
# Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878
|
||||
NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true)
|
||||
if [ -n "$NUMBA_CUDA_DIR" ]; then
|
||||
NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch"
|
||||
pushd "$NUMBA_CUDA_DIR"
|
||||
patch -p4 <"$NUMBA_PATCH"
|
||||
popd
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
|
||||
NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true)
|
||||
if [ -n "$NUMBA_CUDA_DIR" ]; then
|
||||
NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch"
|
||||
pushd "$NUMBA_CUDA_DIR"
|
||||
patch -p4 <"$NUMBA_PATCH"
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Environment variables:"
|
||||
|
Reference in New Issue
Block a user