mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
build: set -DNDEBUG in Release (#32719)
Summary: This might lead to silent undefined behaviour (e.g. with out-of-bound indices). This affects `test_multinomial_invalid_probs_cuda` which is now removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32719 Test Plan: * Build with VERBOSE=1 and manually inspect `less ndebug.build.log | grep 'c++' | grep -v -- -DNDEBUG` (only with nina on Linux) * CI Fixes https://github.com/pytorch/pytorch/issues/22745 Differential Revision: D20104340 Pulled By: yf225 fbshipit-source-id: 2ebfd7ddae632258a36316999eeb5c968fb7642c
This commit is contained in:
committed by
Facebook Github Bot
parent
93e30c16cb
commit
8aa09de19e
@ -35,13 +35,13 @@ __global__ void renormRowsL1(scalar_t* dist, long rows, long cols) {
|
||||
scalar_t sum = static_cast<scalar_t>(0);
|
||||
for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
|
||||
val = dist[row * cols + col];
|
||||
assert(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
|
||||
CUDA_ALWAYS_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
|
||||
sum = sum + val;
|
||||
}
|
||||
|
||||
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<scalar_t>(), zero);
|
||||
if (threadIdx.x == 0) {
|
||||
assert(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
|
||||
CUDA_ALWAYS_ASSERT(!THCNumerics<scalar_t>::lt(val, zero)); // ! < 0 for NaN handling
|
||||
smem[0] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
@ -61,7 +61,7 @@ void renormRows(Tensor& t) {
|
||||
int64_t cols = t.size(1);
|
||||
|
||||
auto props = at::cuda::getCurrentDeviceProperties();
|
||||
assert(props != NULL);
|
||||
CUDA_ALWAYS_ASSERT(props != NULL);
|
||||
int numSM = props->multiProcessorCount;
|
||||
int maxThreads = props->maxThreadsPerBlock;
|
||||
|
||||
@ -84,7 +84,7 @@ __device__ int binarySearchForMultinomial(scalar_t* cumdist,
|
||||
int start = 0;
|
||||
int end = size;
|
||||
// cumdist[size - 1] = 0 => all zero prob dist
|
||||
assert(cumdist[size - 1] > static_cast<scalar_t>(0));
|
||||
CUDA_ALWAYS_ASSERT(cumdist[size - 1] > static_cast<scalar_t>(0));
|
||||
|
||||
while (end - start > 0) {
|
||||
int mid = start + (end - start) / 2;
|
||||
@ -240,9 +240,9 @@ sampleMultinomialOnce(int64_t* dest,
|
||||
scalar_t val;
|
||||
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
|
||||
val = dist[curDist * stride_dist + cat * stride_categories];
|
||||
assert(val >= zero);
|
||||
assert(!THCNumerics<scalar_t>::isinf(val));
|
||||
assert(!THCNumerics<scalar_t>::isnan(val));
|
||||
CUDA_ALWAYS_ASSERT(val >= zero);
|
||||
CUDA_ALWAYS_ASSERT(!THCNumerics<scalar_t>::isinf(val));
|
||||
CUDA_ALWAYS_ASSERT(!THCNumerics<scalar_t>::isnan(val));
|
||||
sum = sum + static_cast<accscalar_t>(val);
|
||||
}
|
||||
|
||||
@ -252,8 +252,8 @@ sampleMultinomialOnce(int64_t* dest,
|
||||
// Broadcast sum and sample value
|
||||
if (threadIdx.x == 0) {
|
||||
// Make sure the sum of our distribution didn't overflow
|
||||
assert(!THCNumerics<accscalar_t>::isinf(sum));
|
||||
assert(sum > accZero);
|
||||
CUDA_ALWAYS_ASSERT(!THCNumerics<accscalar_t>::isinf(sum));
|
||||
CUDA_ALWAYS_ASSERT(sum > accZero);
|
||||
|
||||
asmem[0] = sum;
|
||||
smem[0] = sampled[curDist];
|
||||
@ -363,7 +363,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self_v.scalar_type(), "multinomial_kernel_cuda", [&] {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
auto props = at::cuda::getCurrentDeviceProperties();
|
||||
assert(props != NULL);
|
||||
CUDA_ALWAYS_ASSERT(props != NULL);
|
||||
int numSM = props->multiProcessorCount;
|
||||
int maxThreads = props->maxThreadsPerBlock;
|
||||
int maxShared = props->sharedMemPerBlock;
|
||||
|
@ -192,6 +192,10 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
|
||||
#define C10_WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER) && _MSC_VER <= 1900
|
||||
#define __func__ __FUNCTION__
|
||||
#endif
|
||||
|
||||
// CUDA_KERNEL_ASSERT is a macro that wraps an assert() call inside cuda
|
||||
// kernels. This is not supported by Apple platforms so we special case it.
|
||||
// See http://docs.nvidia.com/cuda/cuda-c-programming-guide/#assertion
|
||||
@ -201,6 +205,35 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
|
||||
#define CUDA_KERNEL_ASSERT(...) assert(__VA_ARGS__)
|
||||
#endif // __APPLE__
|
||||
|
||||
// CUDA_ALWAYS_ASSERT is similar to CUDA_KERNEL_ASSERT but checks the assertion
|
||||
// even when NDEBUG is defined. This is useful for important assertions in CUDA
|
||||
// code that should still be checked even when building in Release mode.
|
||||
#if defined(__APPLE__) || defined(__HIP_PLATFORM_HCC__)
|
||||
// These platforms do not support assert()
|
||||
#define CUDA_ALWAYS_ASSERT(cond)
|
||||
#elif defined(_MSC_VER)
|
||||
// TODO: This should be defined but I don't have the environment to properly
|
||||
// test it. See e.g., https://github.com/pytorch/pytorch/pull/32719#discussion_r379918384
|
||||
#define CUDA_ALWAYS_ASSERT(cond)
|
||||
#else // __APPLE__, _MSC_VER
|
||||
#if defined(NDEBUG)
|
||||
extern "C" {
|
||||
[[noreturn]]
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__)
|
||||
__host__ __device__
|
||||
#endif // __CUDA_ARCH__
|
||||
void __assert_fail(const char *assertion, const char *file,
|
||||
unsigned int line, const char *function)
|
||||
throw();
|
||||
}
|
||||
#endif // NDEBUG
|
||||
#define CUDA_ALWAYS_ASSERT(cond) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
__assert_fail(#cond, __FILE__, static_cast<unsigned int>(__LINE__), \
|
||||
__func__); \
|
||||
}
|
||||
#endif // __APPLE__
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <TargetConditionals.h>
|
||||
#endif
|
||||
|
@ -1257,16 +1257,21 @@ if (NOT INTERN_BUILD_MOBILE)
|
||||
MESSAGE(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")
|
||||
ENDIF()
|
||||
|
||||
OPTION(NDEBUG "disable asserts (WARNING: this may result in silent UB e.g. with out-of-bound indices)")
|
||||
IF (NOT NDEBUG)
|
||||
STRING(APPEND CMAKE_C_FLAGS_RELEASE " -DNDEBUG")
|
||||
STRING(APPEND CMAKE_CXX_FLAGS_RELEASE " -DNDEBUG")
|
||||
IF (NOT GENERATOR_IS_MULTI_CONFIG)
|
||||
IF (${CMAKE_BUILD_TYPE} STREQUAL "Release")
|
||||
MESSAGE(STATUS "Adding -DNDEBUG to compile flags")
|
||||
STRING(APPEND CMAKE_C_FLAGS " -DNDEBUG")
|
||||
STRING(APPEND CMAKE_CXX_FLAGS " -DNDEBUG")
|
||||
ELSE()
|
||||
MESSAGE(STATUS "Removing -DNDEBUG from compile flags")
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS "" ${CMAKE_C_FLAGS})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS_DEBUG "" ${CMAKE_C_FLAGS_DEBUG})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS_RELEASE "" ${CMAKE_C_FLAGS_RELEASE})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS "" ${CMAKE_CXX_FLAGS})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS_DEBUG "" ${CMAKE_CXX_FLAGS_DEBUG})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "" ${CMAKE_CXX_FLAGS_RELEASE})
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS_DEBUG "" ${CMAKE_C_FLAGS_DEBUG})
|
||||
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS_DEBUG "" ${CMAKE_CXX_FLAGS_DEBUG})
|
||||
|
||||
SET(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE OFF)
|
||||
|
||||
|
Reference in New Issue
Block a user