mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add CUDA_KERNEL_ASSERT_PRINTF
, a more flexible CUDA_KERNEL_ASSERT_MSG
(#160129)
This new assertion helper bundles a printf call with the assertion. The goal is to make changes to instrument asserts with device-side information more intuitive and less error-prone. (See the printf call in ATen/native/cuda/Repeat.cu.) Parametrized error messages are a substantial improvement in debuggability because they show the mismatched device-side values. This lets us avoid a whole cycle of rebuilding + re-running failing training workflows. We include file, line number, function, and failing condition in the printf (along with the message provided by the user). The format matches the format of the message output by `__assert_fail`. There's also an easy-to-grep-for keyword `CUDA_KERNEL_ASSERT` in the message. I'm following the existing patterns of arch-specific macros - e.g., on ROCm, this is just a call to abort(), just like the other `CUDA_KERNEL_ASSERT*` variations. I'd appreciate any thoughts on architecture-specific testing (most likely on the OSS side). # Alternatives * We could just update `CUDA_KERNEL_ASSERT_MSG`. That would mean introducing `printf` calls from the kernel where there weren't any before, though. This seems like a bad idea because of the performance sensitivity. * If we want to move more slowly here, I could instrument more `CUDA_KERNEL_ASSERT` callsites without a macro, similar to https://github.com/pytorch/pytorch/pull/157996. But the main downside here is the performance hit, so let's have an organized way of doing it first. # Risks/Problems * We're shoving a lot of stuff into this printf. If a filename (at compile-time) contains `%s`, we will end up dereferencing whatever value was pushed in. On a CPU this can cause a segfault. I don't know how it behaves on a GPU. * Adding printf calls can have a performance impact because of increased register and stack usage. I did not see this play out in practice (see "benchmarks" below). 
However, there are changes to the generated PTX that could result in performance problems later (see "changes in generated PTX" below). # Benchmarks * I ran the following benchmarks several times on a host with an A100: https://gist.github.com/mjkatmeta/e5494d949204a2afe2d43c452b99424f * Results are here -- I couldn't find a significant difference before or after https://gist.github.com/mjkatmeta/0f99ec27bb91214fb2cc7f612938d431 # Change in generated PTX This is the easiest way I found to run nvcc over just Repeat.cu (this is a buck2 target that includes just a copy of Repeat.cu): ``` buck2 build --show-output scripts/mjk/ai_training/cuda_benchmarks:repeat_cuda # then use the printed .so file like this: ~/fbsource/third-party/cuda/cuda_12.8.0/x64-linux/bin/cuobjdump -ptx ../buck-out/v2/gen/fbcode/028bde1acfaba823/scripts/mjk/ai_training/cuda_benchmarks/__repeat_cuda__/libscripts_mjk_ai_training_cuda_benchmarks_repeat_cuda.so ``` ## with printf This is the version of the code that appears in this diff: https://gist.github.com/mjkatmeta/5d18d48282d46b2240d946b335052b9a ## without printf I recompiled, replacing `CUDA_KERNEL_ASSERT_PRINTF(...)` in Repeat.cu with: ``` CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]); ``` https://gist.github.com/mjkatmeta/480df4b3a122e7b326554dd15ebb7c9d (Both of these are annotated with `// CHAR ARRAY:` comments to make the string constants easier to read.) 
In `repeat_interleave`, the `output_size` argument (0) must be the same as the sum of the elements in the `repeats` tensor (10). fbcode/caffe2/aten/src/ATen/native/cuda/Repeat.cu:25: compute_cuda_kernel: block: [0,0,0], thread: [384,0,0] Assertion `result_size == cumsum_ptr[size - 1]` failed. [...] ``` Rollback Plan: Reviewed By: mradmila Differential Revision: D79310684 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160129 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d08cabe314
commit
e900a274e5
@ -17,12 +17,11 @@ __global__ static void compute_cuda_kernel(
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
int64_t result_size) {
|
||||
if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) {
|
||||
printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] "
|
||||
CUDA_KERNEL_ASSERT_PRINTF(
|
||||
result_size == cumsum_ptr[size - 1],
|
||||
"Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n",
|
||||
__FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]);
|
||||
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1])
|
||||
}
|
||||
result_size,
|
||||
cumsum_ptr[size - 1]);
|
||||
|
||||
int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
|
||||
|
@ -359,6 +359,7 @@ static inline int C10_WARP_SIZE_INTERNAL() {
|
||||
// Those platforms do not support assert()
|
||||
#define CUDA_KERNEL_ASSERT(cond)
|
||||
#define CUDA_KERNEL_ASSERT_MSG(cond, msg)
|
||||
#define CUDA_KERNEL_ASSERT_PRINTF(cond, msg, ...)
|
||||
#define SYCL_KERNEL_ASSERT(cond)
|
||||
#elif defined(_MSC_VER)
|
||||
#if defined(NDEBUG)
|
||||
@ -396,6 +397,26 @@ __host__ __device__
|
||||
static_cast<unsigned>(__LINE__)), \
|
||||
0); \
|
||||
}
|
||||
#define CUDA_KERNEL_ASSERT_PRINTF(cond, msg, ...) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
(void)(printf( \
|
||||
"[CUDA_KERNEL_ASSERT] " __FILE__ ":" C10_STRINGIZE( \
|
||||
__LINE__) ": %s: block: [%d,%d,%d], thread: [%d,%d,%d]: " \
|
||||
"Assertion failed: `" #cond "`: " msg "\n", \
|
||||
__func__, \
|
||||
blockIdx.x, \
|
||||
blockIdx.y, \
|
||||
blockIdx.z, \
|
||||
threadIdx.x, \
|
||||
threadIdx.y, \
|
||||
threadIdx.z, \
|
||||
##__VA_ARGS__)); \
|
||||
(void)(_wassert( \
|
||||
_CRT_WIDE(#cond), \
|
||||
_CRT_WIDE(__FILE__), \
|
||||
static_cast<unsigned>(__LINE__)), \
|
||||
0); \
|
||||
}
|
||||
#define SYCL_KERNEL_ASSERT(cond) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
(void)(_wassert( \
|
||||
@ -455,6 +476,10 @@ __host__ __device__
|
||||
if C10_UNLIKELY (!(cond)) { \
|
||||
abort(); \
|
||||
}
|
||||
#define CUDA_KERNEL_ASSERT_PRINTF(cond, msg, ...) \
|
||||
if C10_UNLIKELY (!(cond)) { \
|
||||
abort(); \
|
||||
}
|
||||
#define SYCL_KERNEL_ASSERT(cond) \
|
||||
if C10_UNLIKELY (!(cond)) { \
|
||||
abort(); \
|
||||
@ -470,6 +495,23 @@ __host__ __device__
|
||||
__assert_fail( \
|
||||
msg, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
|
||||
}
|
||||
#define CUDA_KERNEL_ASSERT_PRINTF(cond, msg, ...) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
printf( \
|
||||
"[CUDA_KERNEL_ASSERT] " __FILE__ ":" C10_STRINGIZE( \
|
||||
__LINE__) ": %s: block: [%d,%d,%d], thread: [%d,%d,%d]: " \
|
||||
"Assertion failed: `" #cond "`: " msg "\n", \
|
||||
__func__, \
|
||||
blockIdx.x, \
|
||||
blockIdx.y, \
|
||||
blockIdx.z, \
|
||||
threadIdx.x, \
|
||||
threadIdx.y, \
|
||||
threadIdx.z, \
|
||||
##__VA_ARGS__); \
|
||||
__assert_fail( \
|
||||
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
|
||||
}
|
||||
#define SYCL_KERNEL_ASSERT(cond) \
|
||||
if (C10_UNLIKELY(!(cond))) { \
|
||||
__assert_fail( \
|
||||
|
Reference in New Issue
Block a user