mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Add an option for NAN check on every collective (#125726)
Summary: The NAN CHECK is done through device side assert without copying needed from GPU to CPU Test Plan: Unit test for collectives that should experience run time error (sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (38f5143e)]$ python test/distributed/test_c10d_nccl.py ProcessGroupNCCLTest.test_nan_assert /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)` failed. [rank0]:[E507 17:31:56.885473996 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [0,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [1,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [2,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [3,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [4,0,0] Assertion `!isnan(val)` failed. /home/sqzhang/pytorch/torch/csrc/distributed/c10d/Utils.cu:15: checkForNaN: block: [0,0,0], thread: [5,0,0] Assertion `!isnan(val)` failed. [rank1]:[E507 17:31:56.128961534 Utils.cu:30] CUDA error during checkForNan: device-side assert triggered . ---------------------------------------------------------------------- Ran 1 test in 7.723s OK Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/125726 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e47c7b11b
commit
6db3271007
@ -663,6 +663,7 @@ cu_library(
|
||||
name = "torch_cuda",
|
||||
srcs = [
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
copts = torch_cuda_half_options,
|
||||
@ -830,6 +831,7 @@ cc_library(
|
||||
"torch/csrc/cuda/python_nccl.cpp",
|
||||
"torch/csrc/cuda/nccl.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
)) + torch_sources,
|
||||
|
@ -679,6 +679,7 @@ libtorch_cuda_distributed_extra_sources = [
|
||||
"torch/csrc/distributed/c10d/UCCUtils.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
]
|
||||
|
@ -325,6 +325,26 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
del pg
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("type", [torch.float16, torch.float32, torch.float64])
|
||||
@parametrize("size", [(10, 10), (1000, 1000)])
|
||||
def test_nan_assert(self, type, size):
|
||||
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
device = self.rank_to_GPU[self.rank][0]
|
||||
nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
|
||||
# randomly pick an nan element
|
||||
i = random.randint(0, nan_tensor.size(0) - 1)
|
||||
j = random.randint(0, nan_tensor.size(1) - 1)
|
||||
nan_tensor[i, j] = float("nan")
|
||||
with self.assertRaises(RuntimeError):
|
||||
pg.allreduce(nan_tensor)
|
||||
dist.destroy_process_group()
|
||||
# reset env
|
||||
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_destruct_before_terminate_pg(self):
|
||||
|
@ -748,6 +748,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
// both timeout and other errors.
|
||||
dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
|
||||
heartbeat_ = 1ULL;
|
||||
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
|
||||
heartbeatTimeoutInSec_ =
|
||||
@ -836,6 +837,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
|
||||
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
|
||||
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
|
||||
<< ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
|
||||
<< ", PG Name: " << options_->group_name;
|
||||
|
||||
if (options_->global_ranks_in_group.empty()) {
|
||||
@ -2424,6 +2426,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
OpType opType,
|
||||
const char* profilingTitle,
|
||||
bool avoidRecordStreams) {
|
||||
if (enableNanCheck_) {
|
||||
checkForNan(input);
|
||||
}
|
||||
// Environment setting by the user may add onto collective call's option
|
||||
avoidRecordStreams |= avoidRecordStreams_;
|
||||
c10::cuda::CaptureStatus capture_status =
|
||||
@ -2779,6 +2784,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
PreProcess pre,
|
||||
PostProcess post,
|
||||
const char* profilingTitle) {
|
||||
if (enableNanCheck_) {
|
||||
checkForNan(tensor);
|
||||
}
|
||||
// avoidRecordStreams_ note:
|
||||
// send, recv, and irecv should be ok with avoidRecordStreams,
|
||||
// However, for isend, I don't think the API requires the user
|
||||
|
@ -100,6 +100,8 @@ static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
|
||||
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
|
||||
"TORCH_NCCL_COORD_CHECK_MILSEC"};
|
||||
|
||||
static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};
|
||||
|
||||
constexpr const char* NCCL_BACKEND_NAME = "nccl";
|
||||
|
||||
constexpr const char* EXCEPTION_DUMP = "exception_dump";
|
||||
@ -1024,6 +1026,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// timeout and nccl errors.
|
||||
bool dumpOnException_;
|
||||
|
||||
// Whether or not to enable nan check for input tensors to collectives.
|
||||
bool enableNanCheck_;
|
||||
|
||||
// Whether or not to create start CUDAEvent and enable timing for start
|
||||
// and end events. Note that enableTiming_ is always true if desyncDebug_
|
||||
// is set to true.
|
||||
|
45
torch/csrc/distributed/c10d/Utils.cu
Normal file
45
torch/csrc/distributed/c10d/Utils.cu
Normal file
@ -0,0 +1,45 @@
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
||||
#include <torch/torch.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
// CUDA kernel to check if data has NAN, device side assert
|
||||
// is raised if NAN is found
|
||||
template <typename T>
|
||||
__global__ void checkForNaN(T* data, size_t size) {
|
||||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
size_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
for (size_t i = tid; i < size; i += stride) {
|
||||
CUDA_KERNEL_ASSERT(!isnan(data[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK if a Tensor contains NAN in any of its element
|
||||
void checkForNan(const at::Tensor& tensor) {
|
||||
// skip check for non float types
|
||||
if (!torch::is_floating_point(tensor)) {
|
||||
return;
|
||||
}
|
||||
const size_t maxNumThreadsPerBlock = 512;
|
||||
const size_t maxNumBlocks = 24;
|
||||
const size_t numThreadsPerBlock =
|
||||
std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());
|
||||
|
||||
const size_t numBlocks = std::min<size_t>(
|
||||
maxNumBlocks,
|
||||
(tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.scalar_type(), "checkForNaN", [&] {
|
||||
checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock>>>(
|
||||
tensor.data_ptr<scalar_t>(), tensor.numel());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
} // namespace c10d
|
@ -612,6 +612,8 @@ using SizeType = uint64_t;
|
||||
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
|
||||
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
|
||||
|
||||
void checkForNan(const at::Tensor& tensor);
|
||||
|
||||
namespace tcputil {
|
||||
|
||||
// Send and receive
|
||||
|
Reference in New Issue
Block a user