pytorch/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
Ke Wen f89af60183 Rewrite NCCL watchdog to more reliably throw timeout (#97066)
Fixes #97191

This PR aims to propagate collective exceptions (async errors or timeouts) up to the program, so as to avoid silently stuck jobs.

### Previous output in #97191
```
Rank 0 is the problematic rank
Rank 4 completed
Rank 5 completed
Rank 3 completed
Rank 6 completed
Rank 2 completed
Rank 7 completed
Rank 1 completed
[E ProcessGroupNCCL.cpp:464] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10917 milliseconds before timing out.
Rank 0 completed
[E ProcessGroupNCCL.cpp:478] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:483] To avoid data inconsistency, we are taking the entire process down.
```
Although the log says the process is being taken down, it sometimes fails to do so.

### New output after this PR:
```
...
[E ProcessGroupNCCL.cpp:459] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:473] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:479] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:818] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLREDUCE, Timeout(ms)=10000) ran for 10599 milliseconds before timing out.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 194470) of binary: /data/home/kw2501/repos/pytorch-dev-env/bin/python
Traceback (most recent call last):
  File "/pytorch-dev-env/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
  File "/pytorch-dev/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/pytorch-dev/torch/distributed/run.py", line 794, in main
    run(args)
  File "/pytorch-dev/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/pytorch-dev/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/pytorch-dev/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
hang.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-03-20_22:00:42
  host      : node0
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 194470)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 194470
============================================================
```

The log suggests that the TorchX monitor is triggered and the job is torn down.
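
The `exitcode: -6` reported by torchrun corresponds to SIGABRT (signal 6): once the exception escapes the watchdog thread, `std::terminate()` runs and the process aborts. The snippet below is only an illustration of that generic C++ mechanism, not PyTorch code:

```cpp
// Illustration of why the job exits with signal 6 (SIGABRT): an exception
// that escapes a std::thread triggers std::terminate(), which calls abort().
#include <stdexcept>
#include <thread>

int main() {
  std::thread watchdog([] {
    throw std::runtime_error("Watchdog caught collective operation timeout");
  });
  watchdog.join(); // the process aborts with SIGABRT once the exception escapes
  return 0;
}
```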

### Major changes in this PR:
1. Merge the ncclWatchDog thread and the workCleanupLoop thread into one, so that the watch action and the throw action are streamlined (see the sketch after this list).
Previously, ncclWatchDog was responsible for watching for communicator errors and timeouts, while workCleanupLoop was responsible for watching for Work item errors and throwing exceptions. This two-thread design was not streamlined, raising the chance of missing the throw, and it duplicated the watching at multiple levels.
2. Rethrow the exception from the watchdog thread.
3. Clean up a bunch of duplicated functions, e.g. `checkAndThrowException` and `handleNcclException`.
4. Turn on `ASYNC_ERROR_HANDLING` by default.
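
As an illustration of change 1, here is a minimal, hypothetical sketch of the merged watch-and-throw loop. It is not the actual ProcessGroupNCCL implementation; `FakeWork`, its fields, and the timing values are invented for the example:

```cpp
// Hypothetical sketch of the merged watchdog: one loop both detects a
// timed-out work item and rethrows, so no hand-off to a separate
// cleanup/throw thread can be missed.
#include <chrono>
#include <stdexcept>
#include <thread>
#include <vector>

struct FakeWork {
  std::chrono::steady_clock::time_point start{std::chrono::steady_clock::now()};
  std::chrono::milliseconds timeout{100};
  bool completed{false}; // never flipped, to simulate a hung collective
  bool timedOut() const {
    return !completed && std::chrono::steady_clock::now() - start > timeout;
  }
};

int main() {
  std::vector<FakeWork> outstanding(1);
  std::thread watchdog([&] {
    while (true) {
      for (auto& work : outstanding) {
        if (work.timedOut()) {
          // Throwing on the watchdog thread itself takes the whole process
          // down (std::terminate -> SIGABRT) instead of relying on another
          // thread to notice and act.
          throw std::runtime_error("collective operation timeout");
        }
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  });
  watchdog.join(); // not reached in practice: the throw aborts the process
  return 0;
}
```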
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97066
Approved by: https://github.com/rohan-varma
2023-03-25 04:30:20 +00:00


#include <chrono>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

constexpr int kNcclErrorHandlingVersion = 2400;
// WorkNCCL subclass whose error check can be made to report a simulated error.
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      const std::vector<at::Device>& devices,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(devices, rank, opType, seq), simulate_error_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
      const override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
  }

 private:
  bool simulate_error_;
};
// ProcessGroupNCCL subclass that creates WorkNCCLSimulateErrors work items and
// lets tests toggle the simulated error on and off.
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs =
          c10::nullopt) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        devices, simulate_error_, rank, opType, seq_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulate_error() {
    simulate_error_ = true;
  }

  void reset_error() {
    simulate_error_ = false;
  }

 private:
  bool simulate_error_;
};
// WorkNCCL subclass that can pretend it never completes, to exercise the
// timeout path.
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      const std::vector<at::Device>& devices,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(devices, rank, opType, seq),
        set_timedout_error_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (set_timedout_error_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

 private:
  bool set_timedout_error_;
};
// ProcessGroupNCCL subclass that creates WorkNCCLTimedoutErrors work items and
// lets tests toggle the simulated timeout on and off.
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        set_timedout_error_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs =
          c10::nullopt) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        devices, set_timedout_error_, rank, opType, seq_);
  }

  void set_timedout_error() {
    set_timedout_error_ = true;
  }

  void reset_timedout_error() {
    set_timedout_error_ = false;
  }

 private:
  bool set_timedout_error_;
};
// Test fixture: allocates one tensor per visible CUDA device and a FileStore,
// and resets NCCL_BLOCKING_WAIT after each test.
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    size_t numDevices = cudaNumDevices();
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    at::cuda::OptionalCUDAGuard deviceGuard;
    tensors_.resize(numDevices);
    for (const auto i : c10::irange(numDevices)) {
      deviceGuard.set_index(i);
      tensors_[i] = at::ones({3, 3}, at::kCUDA);
    }
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here, further operations would fail.
}
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.set_timedout_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here, further operations would fail.
}
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  if (skipTest()) {
    return;
  }

  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());

  // Communicators might be aborted here, further operations would fail.
}