[c10d] Add an option for NAN check on every collective (#125726)

Summary:
The NaN check is performed with a device-side assert, so no copy from GPU
to CPU is needed.
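
For illustration only (not part of this commit's diff): a minimal sketch of how a user
would opt into the check, assuming a standard two-process NCCL setup launched with
torchrun (`torchrun --nproc_per_node=2 script.py`):

    # Hypothetical usage sketch, not taken from this PR.
    import os
    import torch
    import torch.distributed as dist

    # Must be set before the NCCL process group is created, since the flag
    # is read in the ProcessGroupNCCL constructor. Off by default.
    os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    t = torch.ones(1000, 1000, device="cuda")
    t[0, 0] = float("nan")
    # With the check enabled, this collective trips the device-side assert
    # in checkForNaN and surfaces as a RuntimeError.
    dist.all_reduce(t)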
Test Plan:
Unit test for collectives that should raise a runtime error.

(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (38f5143e)]$  python
test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_nan_assert
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)`
failed.
[rank0]:[E507 17:31:56.885473996 Utils.cu:30] CUDA error during
checkForNan: device-side assert triggered

/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)`
failed.
/home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15:
checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)`
failed.
[rank1]:[E507 17:31:56.128961534 Utils.cu:30] CUDA error during
checkForNan: device-side assert triggered

.
----------------------------------------------------------------------
Ran 1 test in 7.723s

OK

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125726
Approved by: https://github.com/kwen2501
Author: Shuqiang Zhang
Date: 2024-05-13 10:50:08 -07:00
Committed by: PyTorch MergeBot
Parent: 1e47c7b11b
Commit: 6db3271007
7 changed files with 83 additions and 0 deletions


@@ -663,6 +663,7 @@ cu_library(
    name = "torch_cuda",
    srcs = [
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/Utils.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
    copts = torch_cuda_half_options,

@@ -830,6 +831,7 @@ cc_library(
        "torch/csrc/cuda/python_nccl.cpp",
        "torch/csrc/cuda/nccl.cpp",
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/Utils.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
)) + torch_sources,


@@ -679,6 +679,7 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/UCCUtils.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cu",
    "torch/csrc/distributed/c10d/Utils.cu",
    "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]


@@ -325,6 +325,26 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        del pg

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("type", [torch.float16, torch.float32, torch.float64])
    @parametrize("size", [(10, 10), (1000, 1000)])
    def test_nan_assert(self, type, size):
        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(store, self.opts())
        device = self.rank_to_GPU[self.rank][0]
        nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
        # randomly pick an nan element
        i = random.randint(0, nan_tensor.size(0) - 1)
        j = random.randint(0, nan_tensor.size(1) - 1)
        nan_tensor[i, j] = float("nan")
        with self.assertRaises(RuntimeError):
            pg.allreduce(nan_tensor)
        dist.destroy_process_group()
        # reset env
        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_destruct_before_terminate_pg(self):


@@ -748,6 +748,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  // both timeout and other errors.
  dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
      (dist_debug_level_ >= DebugLevel::Detail);
  enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
  heartbeat_ = 1ULL;
  monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
  heartbeatTimeoutInSec_ =

@@ -836,6 +837,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
      << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
      << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
      << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
      << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
      << ", PG Name: " << options_->group_name;

  if (options_->global_ranks_in_group.empty()) {

@@ -2424,6 +2426,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
    OpType opType,
    const char* profilingTitle,
    bool avoidRecordStreams) {
  if (enableNanCheck_) {
    checkForNan(input);
  }
  // Environment setting by the user may add onto collective call's option
  avoidRecordStreams |= avoidRecordStreams_;
  c10::cuda::CaptureStatus capture_status =

@@ -2779,6 +2784,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
    PreProcess pre,
    PostProcess post,
    const char* profilingTitle) {
  if (enableNanCheck_) {
    checkForNan(tensor);
  }
  // avoidRecordStreams_ note:
  // send, recv, and irecv should be ok with avoidRecordStreams,
  // However, for isend, I don't think the API requires the user


@@ -100,6 +100,8 @@ static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
    "TORCH_NCCL_COORD_CHECK_MILSEC"};

static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";
constexpr const char* EXCEPTION_DUMP = "exception_dump";

@@ -1024,6 +1026,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // timeout and nccl errors.
  bool dumpOnException_;

  // Whether or not to enable nan check for input tensors to collectives.
  bool enableNanCheck_;

  // Whether or not to create start CUDAEvent and enable timing for start
  // and end events. Note that enableTiming_ is always true if desyncDebug_
  // is set to true.


@@ -0,0 +1,45 @@
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/torch.h>

#include <algorithm>

namespace c10d {

// CUDA kernel to check if data has NAN, device side assert
// is raised if NAN is found
template <typename T>
__global__ void checkForNaN(T* data, size_t size) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;

  for (size_t i = tid; i < size; i += stride) {
    CUDA_KERNEL_ASSERT(!isnan(data[i]));
  }
}

// CHECK if a Tensor contains NAN in any of its element
void checkForNan(const at::Tensor& tensor) {
  // skip check for non float types
  if (!torch::is_floating_point(tensor)) {
    return;
  }
  const size_t maxNumThreadsPerBlock = 512;
  const size_t maxNumBlocks = 24;
  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());
  const size_t numBlocks = std::min<size_t>(
      maxNumBlocks,
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] {
    checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>(
        tensor.data_ptr<scalar_t>(), tensor.numel());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace c10d
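
For comparison only (not part of the change above): an equivalent host-side check,
sketched here in Python, would have to synchronize and copy the result back to the
CPU via .item(), which is exactly the round trip the device-side assert in
checkForNaN avoids:

    # Illustrative sketch; not used anywhere in this PR.
    import torch

    def has_nan(tensor: torch.Tensor) -> bool:
        if not torch.is_floating_point(tensor):
            return False  # mirrors the kernel's skip of non-floating-point dtypes
        # torch.isnan(...).any() runs on the GPU, but .item() forces a
        # device-to-host copy and a synchronization point.
        return torch.isnan(tensor).any().item()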


@@ -612,6 +612,8 @@ using SizeType = uint64_t;
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)

void checkForNan(const at::Tensor& tensor);

namespace tcputil {

// Send and receive