diff --git a/BUILD.bazel b/BUILD.bazel
index d3084d9ebd44..3f7e6327452c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -663,6 +663,7 @@ cu_library(
     name = "torch_cuda",
     srcs = [
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
     copts = torch_cuda_half_options,
@@ -830,6 +831,7 @@ cc_library(
         "torch/csrc/cuda/python_nccl.cpp",
         "torch/csrc/cuda/nccl.cpp",
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
 )) + torch_sources,
diff --git a/build_variables.bzl b/build_variables.bzl
index f28131023cf6..b4cf73a6b17e 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -679,6 +679,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/intra_node_comm.cpp",
     "torch/csrc/distributed/c10d/intra_node_comm.cu",
+    "torch/csrc/distributed/c10d/Utils.cu",
    "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
 ]
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index ebc56588ed57..11bb1868f905 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -325,6 +325,26 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
 
         del pg
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @parametrize("type", [torch.float16, torch.float32, torch.float64])
+    @parametrize("size", [(10, 10), (1000, 1000)])
+    def test_nan_assert(self, type, size):
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        device = self.rank_to_GPU[self.rank][0]
+        nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
+        # randomly pick a nan element
+        i = random.randint(0, nan_tensor.size(0) - 1)
+        j = random.randint(0, nan_tensor.size(1) - 1)
+        nan_tensor[i, j] = float("nan")
+        with self.assertRaises(RuntimeError):
+            pg.allreduce(nan_tensor)
+        dist.destroy_process_group()
+        # reset env
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):
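The test above exercises the check end to end across ranks. For reference, a minimal standalone sketch of how a user would opt in (the script name, tensor sizes, and NaN position are illustrative and not part of this patch; assumes a host with 2+ GPUs and `torchrun`):

```python
# nan_check_demo.py -- illustrative only, not part of this patch.
import os

# ProcessGroupNCCL reads TORCH_NCCL_NAN_CHECK once at construction,
# so set it before the process group is created.
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = torch.ones(1000, 1000, device="cuda")
t[0, 0] = float("nan")

try:
    # Expected to fail: the pre-collective NaN check trips a
    # device-side assert, which surfaces as a RuntimeError.
    dist.all_reduce(t)
finally:
    dist.destroy_process_group()

# Launch with: torchrun --nproc_per_node=2 nan_check_demo.py
```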
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 6cca50daff6c..26c31691a400 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -748,6 +748,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   // both timeout and other errors.
   dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
       (dist_debug_level_ >= DebugLevel::Detail);
+  enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
   heartbeat_ = 1ULL;
   monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
   heartbeatTimeoutInSec_ =
@@ -836,6 +837,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
       << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
       << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
+      << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
       << ", PG Name: " << options_->group_name;
 
   if (options_->global_ranks_in_group.empty()) {
@@ -2424,6 +2426,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     OpType opType,
     const char* profilingTitle,
     bool avoidRecordStreams) {
+  if (enableNanCheck_) {
+    checkForNan(input);
+  }
   // Environment setting by the user may add onto collective call's option
   avoidRecordStreams |= avoidRecordStreams_;
   c10::cuda::CaptureStatus capture_status =
@@ -2779,6 +2784,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
     PreProcess pre,
     PostProcess post,
     const char* profilingTitle) {
+  if (enableNanCheck_) {
+    checkForNan(tensor);
+  }
   // avoidRecordStreams_ note:
   // send, recv, and irecv should be ok with avoidRecordStreams,
   // However, for isend, I don't think the API requires the user
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index fac9b6f38204..e688ee0f0c67 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -100,6 +100,8 @@ static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
 static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
     "TORCH_NCCL_COORD_CHECK_MILSEC"};
 
+static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};
+
 constexpr const char* NCCL_BACKEND_NAME = "nccl";
 
 constexpr const char* EXCEPTION_DUMP = "exception_dump";
@@ -1024,6 +1026,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // timeout and nccl errors.
   bool dumpOnException_;
 
+  // Whether or not to enable nan check for input tensors to collectives.
+  bool enableNanCheck_;
+
   // Whether or not to create start CUDAEvent and enable timing for start
   // and end events. Note that enableTiming_ is always true if desyncDebug_
   // is set to true.
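The gate added to `collective()` and `pointToPoint()` is intentionally cheap: the environment variable is read once in the constructor and cached in `enableNanCheck_`, so each collective pays only a boolean test when the feature is off. A host-side Python sketch of the equivalent gating (the function name `maybe_check_for_nan` is hypothetical, and the `== "1"` comparison simplifies `getCvarBool`, which accepts several truthy spellings; the real check runs as a CUDA kernel, shown in the next file):

```python
import os
import torch

# Read once at startup, mirroring how ProcessGroupNCCL caches
# enableNanCheck_ in its constructor rather than per call.
ENABLE_NAN_CHECK = os.getenv("TORCH_NCCL_NAN_CHECK", "0") == "1"

def maybe_check_for_nan(tensor: torch.Tensor) -> None:
    """Host-side analogue of the enableNanCheck_ gate (illustrative)."""
    if not ENABLE_NAN_CHECK:
        return
    if tensor.is_floating_point() and torch.isnan(tensor).any():
        raise RuntimeError("NaN detected in collective input")
```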
diff --git a/torch/csrc/distributed/c10d/Utils.cu b/torch/csrc/distributed/c10d/Utils.cu
new file mode 100644
index 000000000000..6bd36efec86d
--- /dev/null
+++ b/torch/csrc/distributed/c10d/Utils.cu
@@ -0,0 +1,45 @@
+#include <ATen/Dispatch.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/macros/Macros.h>
+#include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/torch.h>
+#include <algorithm>
+
+namespace c10d {
+
+// CUDA kernel to check if data has NaN; a device-side assert
+// is raised if a NaN is found
+template <typename T>
+__global__ void checkForNaN(T* data, size_t size) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < size; i += stride) {
+    CUDA_KERNEL_ASSERT(!isnan(data[i]));
+  }
+}
+
+// Check if a tensor contains NaN in any of its elements
+void checkForNan(const at::Tensor& tensor) {
+  // skip check for non-float types
+  if (!torch::is_floating_point(tensor)) {
+    return;
+  }
+  const size_t maxNumThreadsPerBlock = 512;
+  const size_t maxNumBlocks = 24;
+  const size_t numThreadsPerBlock =
+      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());
+
+  const size_t numBlocks = std::min<size_t>(
+      maxNumBlocks,
+      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] {
+    checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>(
+        tensor.data_ptr<scalar_t>(), tensor.numel());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
+
+}
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp
index 8427b63e38e8..b193c8971b57 100644
--- a/torch/csrc/distributed/c10d/Utils.hpp
+++ b/torch/csrc/distributed/c10d/Utils.hpp
@@ -612,6 +612,8 @@ using SizeType = uint64_t;
 // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
 
+void checkForNan(const at::Tensor& tensor);
+
 namespace tcputil {
 
 // Send and receive
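Because the kernel uses a grid-stride loop, the launch is capped at 24 blocks of 512 threads regardless of tensor size; each thread then walks the array with stride `blockDim.x * gridDim.x` until every element is covered. A quick sanity check of that launch-geometry arithmetic in Python, using the same sizes as the test above (helper name and values are illustrative):

```python
# Mirror checkForNan's launch-geometry arithmetic (illustrative).
def launch_geometry(numel: int, max_threads: int = 512, max_blocks: int = 24):
    threads = min(max_threads, numel)
    blocks = min(max_blocks, (numel + threads - 1) // threads)
    return blocks, threads

assert launch_geometry(10 * 10) == (1, 100)       # small tensor: one block
assert launch_geometry(1000 * 1000) == (24, 512)  # capped at 24 blocks

# With 24 * 512 = 12288 threads total, the grid-stride loop covers all
# 1_000_000 elements in ceil(1_000_000 / 12288) = 82 iterations per thread.
```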