Summary: The NaN check is done through a device-side assert, so no copy from GPU to CPU is needed.

Test Plan: Unit test for collectives that should experience a runtime error:

(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (38f5143e)]$ python test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_nan_assert
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)` failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)` failed.
[rank0]:[E507 17:31:56.885473996 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered
(the same six assertions repeat on rank 1)
[rank1]:[E507 17:31:56.128961534 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered
.
----------------------------------------------------------------------
Ran 1 test in 7.723s

OK

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125726
Approved by: https://github.com/kwen2501
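For context, a minimal standalone CUDA sketch (not part of this PR) of the mechanism the summary relies on: a device-side assert fires inside the kernel, and the failure surfaces as cudaErrorAssert at the next synchronization, with no device-to-host copy of the checked data. The harness below is hypothetical; only the assert-in-kernel pattern mirrors the file that follows.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Each thread asserts that its element is not NaN, mirroring
// CUDA_KERNEL_ASSERT(!isnan(...)) in Utils.cu below.
__global__ void assertNotNaN(const float* data, size_t size) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    assert(!isnan(data[i]));
  }
}

int main() {
  float host[4] = {1.0f, 2.0f, nanf(""), 4.0f};
  float* dev = nullptr;
  cudaMalloc(&dev, sizeof(host));
  cudaMemcpy(dev, host, sizeof(host), cudaMemcpyHostToDevice);
  assertNotNaN<<<1, 4>>>(dev, 4);
  // The failed assert is reported asynchronously: it shows up here as
  // cudaErrorAssert ("device-side assert triggered"), as in the test log.
  cudaError_t err = cudaDeviceSynchronize();
  printf("sync result: %s\n", cudaGetErrorString(err));
  return err == cudaSuccess ? 0 : 1;
}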
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/torch.h>

#include <algorithm>

namespace c10d {

// CUDA kernel to check if data has NaN; a device-side assert
// is raised if NaN is found.
template <typename T>
__global__ void checkForNaN(T* data, size_t size) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;

  // Grid-stride loop: each thread checks every stride-th element, so the
  // kernel covers tensors larger than the number of launched threads.
  for (size_t i = tid; i < size; i += stride) {
    CUDA_KERNEL_ASSERT(!isnan(data[i]));
  }
}

// Check if a tensor contains NaN in any of its elements.
void checkForNan(const at::Tensor& tensor) {
  // Skip the check for non-floating-point types.
  if (!torch::is_floating_point(tensor)) {
    return;
  }

  // Cap the launch at 24 blocks of 512 threads; the grid-stride loop in the
  // kernel handles any elements beyond the launched thread count.
  const size_t maxNumThreadsPerBlock = 512;
  const size_t maxNumBlocks = 24;
  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

  const size_t numBlocks = std::min<size_t>(
      maxNumBlocks,
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] {
    checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>(
        tensor.data_ptr<scalar_t>(), tensor.numel());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace c10d
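A hypothetical call-site sketch, not part of this file: how a collective entry point might invoke checkForNan before handing a buffer to NCCL. The function name preCollectiveNanCheck and the nanCheckEnabled flag are illustrative only; the actual wiring in ProcessGroupNCCL, and the option that turns the check on, live outside this file.

#include <torch/csrc/distributed/c10d/Utils.hpp>

// Sketch only: gate the device-side NaN check on a caller-provided flag,
// then proceed with the collective. Assumes checkForNan is declared in
// Utils.hpp as defined above.
void preCollectiveNanCheck(const at::Tensor& input, bool nanCheckEnabled) {
  if (nanCheckEnabled) {
    c10d::checkForNan(input); // device-side assert fires if any element is NaN
  }
  // ... enqueue the NCCL collective on `input` ...
}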