mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] Disabling Kernel Asserts for ROCm by default - fix and clean up and refactoring (#114660)
Related to #103973 #110532 #108404 #94891
**Context:**
As commented in 6ae0554d11/cmake/Dependencies.cmake (L1198)
Kernel asserts are enabled by default for CUDA and disabled for ROCm.
However it is somewhat broken, and Kernel assert was still enabled for ROCm.
Disabling kernel assert is also needed for users who do not have PCIe atomics support. These community users have verified that disabling the kernel assert in PyTorch/ROCm platform fixed their pytorch workflow, like torch.sum script, stable-diffusion. (see the related issues)
**Changes:**
This pull request serves the following purposes:
* Refactor and clean up the logic, make it simpler for ROCm to enable and disable Kernel Asserts
* Fix the bug that Kernel Asserts for ROCm was not disabled by default.
Specifically,
- Renamed `TORCH_DISABLE_GPU_ASSERTS` to `C10_USE_ROCM_KERNEL_ASSERT` for the following reasons:
(1) This variable only applies to ROCm.
(2) The new name is more align with #define CUDA_KERNEL_ASSERT function.
(3) With USE_ in front of the name, we can easily control it with environment variable to turn on and off this feature during build (e.g. `USE_ROCM_KERNEL_ASSERT=1 python setup.py develop` will enable kernel assert for ROCm build).
- Get rid of the `ROCM_FORCE_ENABLE_GPU_ASSERTS' to simplify the logic and make it easier to understand and maintain
- Added `#cmakedefine` to carry over the CMake variable to C++
**Tests:**
(1) build with default mode and verify that USE_ROCM_KERNEL_ASSERT is OFF(0), and kernel assert is disabled:
```
python setup.py develop
```
Verify CMakeCache.txt has correct value.
```
/xxxx/pytorch/build$ grep USE_ROCM_KERNEL_ASSERT CMakeCache.txt
USE_ROCM_KERNEL_ASSERT:BOOL=0
```
Tested the following code in ROCm build and CUDA build, and expected the return code differently.
```
subprocess.call([sys.executable, '-c', "import torch;torch._assert_async(torch.tensor(0,device='cuda'));torch.cuda.synchronize()"])
```
This piece of code is adapted from below unit test to get around the limitation that this unit test now was skipped for ROCm. (We will check to enable this unit test in the future)
```
python test/test_cuda_expandable_segments.py -k test_fixed_cuda_assert_async
```
Ran the following script, expecting r ==0 since the CUDA_KERNEL_ASSERT is defined as nothing:
```
>> import sys
>>> import subprocess
>>> r=subprocess.call([sys.executable, '-c', "import torch;torch._assert_async(torch.tensor(0,device='cuda'));torch.cuda.synchronize()"])
>>> r
0
```
(2) Enable the kernel assert by building with USE_ROCM_KERNEL_ASSERT=1, or USE_ROCM_KERNEL_ASSERT=ON
```
USE_ROCM_KERNEL_ASSERT=1 python setup.py develop
```
Verify `USE_ROCM_KERNEL_ASSERT` is `1`
```
/xxxx/pytorch/build$ grep USE_ROCM_KERNEL_ASSERT CMakeCache.txt
USE_ROCM_KERNEL_ASSERT:BOOL=1
```
Run the assert test, and expected return code not equal to 0.
```
>> import sys
>>> import subprocess
>>> r=subprocess.call([sys.executable, '-c', "import torch;torch._assert_async(torch.tensor(0,device='cuda'));torch.cuda.synchronize()"])
>>>/xxxx/pytorch/aten/src/ATen/native/hip/TensorCompare.hip:108: _assert_async_cuda_kernel: Device-side assertion `input[0] != 0' failed.
:0:rocdevice.cpp :2690: 2435301199202 us: [pid:206019 tid:0x7f6cf0a77700] Callback: Queue 0x7f64e8400000 aborting with error : HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation resulted in a hardware exception. code: 0x1016
>>> r
-6
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114660
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/jithunnair-amd
This commit is contained in:
committed by
PyTorch MergeBot
parent
fb80f05ee2
commit
66a76516bf
@ -288,7 +288,7 @@ option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the
|
||||
option(USE_XNNPACK "Use XNNPACK" ON)
|
||||
option(USE_ZMQ "Use ZMQ" OFF)
|
||||
option(USE_ZSTD "Use ZSTD" OFF)
|
||||
option(TORCH_DISABLE_GPU_ASSERTS "Disable GPU asserts by default" OFF)
|
||||
option(USE_ROCM_KERNEL_ASSERT "Use Kernel Assert for ROCm" OFF)
|
||||
# Ensure that an ITT build is the default for x86 CPUs
|
||||
cmake_dependent_option(
|
||||
USE_ITT "Use Intel(R) VTune Profiler ITT functionality" ON
|
||||
|
@ -24,7 +24,7 @@ def define_targets(rules):
|
||||
"CAFFE2_USE_CUDNN",
|
||||
"USE_MKLDNN",
|
||||
"CAFFE2_USE_ITT",
|
||||
"TORCH_DISABLE_GPU_ASSERTS",
|
||||
"USE_ROCM_KERNEL_ASSERT",
|
||||
"EIGEN_MPL2_ONLY",
|
||||
],
|
||||
)
|
||||
|
@ -18,6 +18,7 @@ set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in
|
||||
set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in
|
||||
set(C10_USE_NUMA ${USE_NUMA})
|
||||
set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME})
|
||||
set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT})
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in
|
||||
${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h)
|
||||
|
@ -374,9 +374,7 @@ extern SYCL_EXTERNAL void __assert_fail(
|
||||
unsigned int line,
|
||||
const char* func);
|
||||
#else // __SYCL_DEVICE_ONLY__
|
||||
#if ( \
|
||||
defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__)) && \
|
||||
!defined(TORCH_DISABLE_GPU_ASSERTS))
|
||||
#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__)))
|
||||
// CUDA supports __assert_fail function which are common for both device
|
||||
// and host side code.
|
||||
__host__ __device__
|
||||
@ -393,18 +391,14 @@ __host__ __device__
|
||||
unsigned int line,
|
||||
const char* function) noexcept __attribute__((__noreturn__));
|
||||
|
||||
#if (defined(__HIP_ARCH__) || defined(__HIP__)) && \
|
||||
!defined(TORCH_DISABLE_GPU_ASSERTS)
|
||||
// ROCm supports __assert_fail only as a device side function.
|
||||
__device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail(
|
||||
const char* assertion,
|
||||
const char* file,
|
||||
unsigned int line,
|
||||
const char* function);
|
||||
#endif // defined(__HIP_ARCH__) || defined(__HIP__)
|
||||
#endif // __SYCL_DEVICE_ONLY__
|
||||
}
|
||||
#endif // NDEBUG
|
||||
// ROCm disable kernel assert by default
|
||||
#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM)
|
||||
#define CUDA_KERNEL_ASSERT(cond)
|
||||
#define SYCL_KERNEL_ASSERT(cond)
|
||||
#else
|
||||
#define CUDA_KERNEL_ASSERT(cond) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
__assert_fail( \
|
||||
@ -415,6 +409,7 @@ __device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail(
|
||||
__assert_fail( \
|
||||
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
|
||||
}
|
||||
#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM
|
||||
#endif // __APPLE__
|
||||
|
||||
#ifdef __APPLE__
|
||||
|
@ -9,5 +9,6 @@
|
||||
#cmakedefine C10_USE_GFLAGS
|
||||
#cmakedefine C10_USE_NUMA
|
||||
#cmakedefine C10_USE_MSVC_STATIC_RUNTIME
|
||||
#cmakedefine C10_USE_ROCM_KERNEL_ASSERT
|
||||
|
||||
#endif // C10_MACROS_CMAKE_MACROS_H_
|
||||
|
@ -104,6 +104,7 @@ def define_ovrsource_targets():
|
||||
("#cmakedefine C10_BUILD_SHARED_LIBS", ""),
|
||||
("#cmakedefine C10_USE_NUMA", ""),
|
||||
("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""),
|
||||
("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""),
|
||||
]
|
||||
|
||||
mobile_c10_cmake_defines = [
|
||||
|
@ -26,13 +26,13 @@
|
||||
#cmakedefine CAFFE2_USE_NVTX
|
||||
#cmakedefine CAFFE2_USE_ITT
|
||||
#cmakedefine CAFFE2_USE_TRT
|
||||
#cmakedefine TORCH_DISABLE_GPU_ASSERTS
|
||||
|
||||
#ifndef EIGEN_MPL2_ONLY
|
||||
#cmakedefine EIGEN_MPL2_ONLY
|
||||
#endif
|
||||
|
||||
// Useful build settings that are recorded in the compiled binary
|
||||
// torch.__build__.show()
|
||||
#define CAFFE2_BUILD_STRINGS { \
|
||||
{"TORCH_VERSION", "${TORCH_VERSION}"}, \
|
||||
{"CXX_COMPILER", "${CMAKE_CXX_COMPILER}"}, \
|
||||
@ -68,5 +68,5 @@
|
||||
{"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \
|
||||
{"USE_ITT", "${CAFFE2_USE_ITT}"}, \
|
||||
{"USE_TRT", "${CAFFE2_USE_TRT}"}, \
|
||||
{"TORCH_DISABLE_GPU_ASSERTS", "${TORCH_DISABLE_GPU_ASSERTS}"}, \
|
||||
{"USE_ROCM_KERNEL_ASSERT", "${USE_ROCM_KERNEL_ASSERT}"}, \
|
||||
}
|
||||
|
@ -1192,16 +1192,6 @@ if(ANDROID)
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS log)
|
||||
endif()
|
||||
|
||||
# ---[ Kernel asserts
|
||||
# Kernel asserts are enabled by default for CUDA and disabled for ROCm.
|
||||
# For ROCm, it can be enabled by setting ROCM_FORCE_ENABLE_GPU_ASSERTS
|
||||
if(USE_ROCM AND ROCM_FORCE_ENABLE_GPU_ASSERTS)
|
||||
message(STATUS "Forcefully enabling kernel asserts on ROCM")
|
||||
elseif(USE_ROCM AND NOT ROCM_FORCE_ENABLE_GPU_ASSERTS)
|
||||
message(STATUS "Disabling kernel asserts for ROCm")
|
||||
caffe2_update_option(TORCH_DISABLE_GPU_ASSERTS ON)
|
||||
endif()
|
||||
|
||||
# ---[ LLVM
|
||||
if(USE_LLVM)
|
||||
message(STATUS "Looking for LLVM in ${USE_LLVM}")
|
||||
@ -1249,6 +1239,7 @@ if(USE_ROCM)
|
||||
caffe2_update_option(USE_SYSTEM_NCCL ON)
|
||||
endif()
|
||||
|
||||
|
||||
list(APPEND HIP_CXX_FLAGS -fPIC)
|
||||
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
|
||||
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
|
||||
@ -1291,6 +1282,15 @@ if(USE_ROCM)
|
||||
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
|
||||
roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
|
||||
|
||||
# ---[ Kernel asserts
|
||||
# Kernel asserts is disabled for ROCm by default.
|
||||
# It can be turned on by turning on the env USE_ROCM_KERNEL_ASSERT to the build system.
|
||||
if(USE_ROCM_KERNEL_ASSERT)
|
||||
message(STATUS "Enabling Kernel Assert for ROCm")
|
||||
else()
|
||||
message(STATUS "Disabling Kernel Assert for ROCm")
|
||||
endif()
|
||||
|
||||
else()
|
||||
caffe2_update_option(USE_ROCM OFF)
|
||||
endif()
|
||||
|
@ -198,5 +198,5 @@ function(caffe2_print_configuration_summary)
|
||||
# coreml
|
||||
message(STATUS " USE_COREML_DELEGATE : ${USE_COREML_DELEGATE}")
|
||||
message(STATUS " BUILD_LAZY_TS_BACKEND : ${BUILD_LAZY_TS_BACKEND}")
|
||||
message(STATUS " TORCH_DISABLE_GPU_ASSERTS : ${TORCH_DISABLE_GPU_ASSERTS}")
|
||||
message(STATUS " USE_ROCM_KERNEL_ASSERT : ${USE_ROCM_KERNEL_ASSERT}")
|
||||
endfunction()
|
||||
|
3
setup.py
3
setup.py
@ -160,6 +160,9 @@
|
||||
# USE_ZSTD
|
||||
# Enables use of ZSTD, if the libraries are found
|
||||
#
|
||||
# USE_ROCM_KERNEL_ASSERT=1
|
||||
# Enable kernel assert in ROCm platform
|
||||
#
|
||||
# Environment variables we respect (these environment variables are
|
||||
# conventional and are often understood/set by other software.)
|
||||
#
|
||||
|
Reference in New Issue
Block a user