pytorch/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
Will Constable f85d3a022c [C10D] Fix pointToPoint op Flight Recording (#120270)
Fix and test issues with both coalesced and individual send/recv ops

Considered an alternate approach and then ditched it
 - alternate approach: #119757
 - reason ditched: prefer recording individual collective events inside
   the coalescing region rather than just the one event at the end of the
   region, which would also lack tensor sizes and op names without adding
   extra state variables

Another approach also ditched
- record events on workEnqueue instead of initWork
- reason ditched: too messy to tag input/output shapes onto the
  recording when recording in workEnqueue. Adding the info onto the
  Work obj would be possible, but it adds to the overhead of copying Works,
  which we do on every collective. We can get the info off the input/output
  tensors directly in initWork, but we don't want to keep refs to those
  tensors alive while the work is enqueued, so we'd have to specifically
  copy size lists or something.

This PR instead avoids creating a Work inside pointToPoint while
coalescing is active; only at endCoalescing() is a Work finally
initialized and enqueued. Instead of creating a Work during coalescing,
pointToPoint() makes a record() call, and that record() call picks up
tensor shapes and op names.

It ALSO changes initWork to accept a 'record' argument. This defaults to
false and should only be set to true if the caller ensures the work
will be enqueued by workEnqueue, which guarantees its CUDA events are
live when used by the flight recorder's update_state().
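
Roughly, the intended flow looks like the sketch below. This is a
simplified illustration, not the actual ProcessGroupNCCL code: the names
(Recorder, Group, Work) and the hard-coded shapes are stand-ins invented
for the example; only the control flow mirrors the description above.

// Simplified sketch of the coalescing/record flow described above.
// All names here are illustrative stand-ins, not actual PyTorch APIs.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Recorder {
  // Flight-recorder stand-in: remembers per-op names and tensor shapes.
  void record(const std::string& opName, const std::vector<int64_t>& shape) {
    std::cout << "record " << opName << " dims=" << shape.size() << "\n";
  }
};

struct Work {
  std::string opName;
};

struct Group {
  Recorder recorder;
  bool coalescing = false;
  std::vector<std::unique_ptr<Work>> enqueued;

  // Analogue of initWork(..., bool record): only record when the caller
  // guarantees the Work will be enqueued, so its events stay live for the
  // flight recorder. (Shapes would come from the op's tensors in real code.)
  std::unique_ptr<Work> initWork(const std::string& opName, bool record) {
    if (record) {
      recorder.record(opName, {3, 3});
    }
    auto w = std::make_unique<Work>();
    w->opName = opName;
    return w;
  }

  void startCoalescing() {
    coalescing = true;
  }

  void pointToPoint(const std::string& opName,
                    const std::vector<int64_t>& shape) {
    if (coalescing) {
      // No Work is created here; just record the individual op so the
      // flight recorder still sees per-op names and shapes.
      recorder.record(opName, shape);
      return;
    }
    // Non-coalesced path: create a Work that records, and enqueue it.
    enqueued.push_back(initWork(opName, /*record=*/true));
  }

  void endCoalescing() {
    coalescing = false;
    // A single Work is initialized and enqueued for the whole region.
    // (Whether the region-level Work records again is this sketch's choice.)
    enqueued.push_back(initWork("coalesced", /*record=*/true));
  }
};

int main() {
  Group pg;
  pg.startCoalescing();
  pg.pointToPoint("send", {3, 3}); // recorded, no Work yet
  pg.pointToPoint("recv", {3, 3}); // recorded, no Work yet
  pg.endCoalescing();              // one Work enqueued for the region
  pg.pointToPoint("send", {4, 4}); // individual op: Work created directly
  std::cout << "enqueued works: " << pg.enqueued.size() << "\n";
  return 0;
}

The key invariant in the sketch is the one stated above: record is only
passed as true on paths where the resulting Work is guaranteed to be
enqueued, so the flight recorder never reads CUDA events from a Work that
was dropped.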

The testing uncovers some odd pre-existing behaviors and leaves them
alone for now. We could change some of these later:
- seq starts off at 1, not 0, for the first op (but this is inconsistent)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120270
Approved by: https://github.com/shuqiangzhang
ghstack dependencies: #120724
2024-02-29 01:03:31 +00:00


#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

// NCCL error handling (async error queries / comm abort) requires NCCL >= 2.4.
constexpr int kNcclErrorHandlingVersion = 2400;
// A WorkNCCL subclass that can pretend an NCCL error occurred on this work.
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      at::Device& device,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors() override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
  }

 private:
  bool simulateError_;
};
// A ProcessGroupNCCL subclass whose error check can be toggled to report a
// simulated NCCL error.
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

  std::exception_ptr checkForNCCLErrors(
      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  // The extra initWork arguments (inputs, outputs, record) are accepted to
  // match the base signature but are unused by this test override.
  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        device, simulateError_, rank, opType, seq_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulateError() {
    simulateError_ = true;
  }

  void resetError() {
    simulateError_ = false;
  }

 private:
  bool simulateError_;
};
// A WorkNCCL subclass that can be made to never report completion, so that
// waiting on it times out.
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      at::Device& device,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(device, rank, opType, seq),
        setTimedoutError_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (setTimedoutError_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

 private:
  bool setTimedoutError_;
};
// A process group whose works can be made to never complete, so that the
// watchdog observes a timeout.
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        watchDogDebugInfoFinished_(false),
        setTimedoutError_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        device, setTimedoutError_, rank, opType, seq_);
  }

  void setTimedoutError() {
    setTimedoutError_ = true;
  }

  void resetTimedoutError() {
    setTimedoutError_ = false;
  }

  bool getWatchDogDebugInfoFinishedFlag() {
    return watchDogDebugInfoFinished_;
  }
  // In the constructor of ProcessGroupNCCL, we don't allow the watchdog thread
  // to run any handling or desync report while the main thread is blocking on
  // wait. Even if users enable handling and turn on the desyncDebug flag, the
  // settings get reset. For ease of unit testing we want the main thread to
  // block on wait, so we have this hack to manually set the desync debug flag
  // after PG creation.
  void forceSetDesyncDebugFlag() {
    desyncDebug_ = true;
  }

 protected:
  std::string getNCCLWatchdogDebugInfo() override {
    LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called";
    watchDogDebugInfoFinished_ = true;
    return "";
  }

  bool watchDogDebugInfoFinished_;

 private:
  bool setTimedoutError_;
};
// A process group used to check that the heartbeat monitor thread detects a
// hung watchdog and surfaces the error.
class ProcessGroupNCCLNoHeartbeatCaught
    : public ProcessGroupNCCLTimedOutErrors {
 public:
  ProcessGroupNCCLNoHeartbeatCaught(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
        hasMonitorThreadCaughtError_(false) {}

  std::mutex& getWatchdogMutex() {
    return workMetaListMutex_;
  }

  bool getErrorCaughtFlag() {
    return hasMonitorThreadCaughtError_;
  }

  void forceTryWriteDebugInfo() {
    std::future<bool> asyncDebugDump = std::async(
        std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
    asyncDebugDump.wait();
  }
 protected:
  // Override the heartbeat monitor function so that we capture the exception
  // thrown in the monitor thread (we cannot try-catch it from the main thread)
  // and set a flag for the main thread to check.
  void heartbeatMonitor() override {
    try {
      c10d::ProcessGroupNCCL::heartbeatMonitor();
    } catch (std::runtime_error& e) {
      hasMonitorThreadCaughtError_ = true;
    }
  }

  // It's really hard to unit test std::abort, so we override terminateProcess
  // to throw instead. (Without this override, we do see the process abort with
  // a core dump.)
  void terminateProcess(std::string errMsg) override {
    throw std::runtime_error(errMsg);
  }

  bool hasMonitorThreadCaughtError_;
};
class ProcessGroupNCCLDebugInfoStuck
    : public ProcessGroupNCCLNoHeartbeatCaught {
 public:
  ProcessGroupNCCLDebugInfoStuck(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {}

 protected:
  // Override the debug-info collection used by the heartbeat monitor with a
  // long sleep, to mimic getting stuck while gathering debug info.
  std::string getNCCLWatchdogDebugInfo() override {
    std::this_thread::sleep_for(
        std::chrono::seconds(heartbeatTimeoutInSec_ * 20));
    watchDogDebugInfoFinished_ = true;
    return "";
  }
};
// Common test fixture: sets up a single-rank, FileStore-backed environment and
// one CUDA tensor per rank.
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    // Enable LOG(INFO) messages.
    c10::initLogging();
    // Need this check in SetUp to make sure we only run the test -- including
    // the init -- when there are GPUs available.
    if (skipTest()) {
      GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system "
                   << "requirement is not met (no CUDA or GPU).";
    }

    size_t numDevices = 1; // One device per rank (thread)
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    tensors_.resize(numDevices);
    tensors_[0] = at::empty({3, 3}, at::kCUDA);
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here, further operations would fail.
}

TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.setTimedoutError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::DistBackendError);

  // Communicators might be aborted here, further operations would fail.
}

TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  EXPECT_TRUE(work->isCompleted());
  // Communicators might be aborted here, further operations would fail.
}
// Function to read back what we wrote to local disk, for validation.
std::string readTraceFromFile(const std::string& filename, size_t size) {
  std::ifstream file(filename, std::ios::binary);
  // Read the strings from the file.
  if (file) { // the file stream is in a good state
    std::string str(size, '\0');
    file.read(&str[0], size);
    if (file) {
      return str;
    }
  }
  return "";
}
// A DebugInfoWriter subclass that, in addition to writing to disk, keeps the
// traces in memory so tests can validate them.
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  TestDebugInfoWriter(std::string namePrefix)
      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
    c10d::DebugInfoWriter::write(ncclTrace);
  }

  std::vector<uint8_t>& getTraces() {
    return traces_;
  }

 private:
  std::vector<uint8_t> traces_;
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  int heartBeatIntervalInSec = 2;
  std::string timeInterval = std::to_string(heartBeatIntervalInSec);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  ASSERT_TRUE(
      setenv(
          c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
          timeInterval.c_str(),
          1) == 0);
  ASSERT_TRUE(
      setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
  auto tempFilename = c10::str(
      std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
  ASSERT_TRUE(
      setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
  // Enable the nccl flight recorder.
  ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0);

  auto options = c10d::ProcessGroupNCCL::Options::create();
  // Set a long watchdog timeout so that we have enough time to lock the
  // watchdog and let the heartbeat monitor thread kick in.
  options->timeout = std::chrono::milliseconds(30000);
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options);

  // The writer here is very similar to the default writer; the only difference
  // is that we also keep the traces in memory for validation.
  std::string fileNamePrefix = c10d::getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> writerForTestPtr =
      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = writerForTestPtr->getTraces();
  c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
  work->wait();

  work = pg.allreduce(tensors_);
  {
    // Now run all reduce with errors.
    std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
    LOG(INFO) << "Lock watchdog thread.";
    // Wait long enough for the monitor thread to throw its exception.
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * 3));
    // Check that the monitoring thread launched and the exception was thrown.
    EXPECT_TRUE(pg.getErrorCaughtFlag());
  }
  work->wait();

  EXPECT_TRUE(traces.size() > 0);
  auto filename = c10::str(tempFilename, 0);
  auto traceFromStorage = readTraceFromFile(filename, traces.size());
  // Check that the traces read from storage match the original nccl trace.
  EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end()));
  std::filesystem::remove(filename);
}
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
    // TODO (kwen2501)
    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(
        setenv(
            c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
            timeInterval.c_str(),
            1) == 0);
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0);
    // We cannot capture the exception thrown in the watchdog thread without
    // making lots of changes to the code, so we don't let the watchdog throw
    // an exception.
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0);
    options_ = c10d::ProcessGroupNCCL::Options::create();
    // Set a super short watchdog timeout.
    options_->timeout = std::chrono::milliseconds(100);
  }
  void watchdogTimeoutTestCommon(
      ProcessGroupNCCLNoHeartbeatCaught& pg,
      int multiplier) {
    pg.forceSetDesyncDebugFlag();
    pg.setTimedoutError();
    auto work = pg.allreduce(tensors_);
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * multiplier));
    EXPECT_THROW(work->wait(), c10::DistBackendError);
  }

  const int heartBeatIntervalInSec = 2;
  c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_;
};
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
  // Writing debug info leads the watchdog thread to wait for 30 seconds, and
  // this is hard to override, so we just call it beforehand. Otherwise we
  // would need to set a long heartbeat timeout, which would make the test much
  // slower.
  pg.forceTryWriteDebugInfo();
  watchdogTimeoutTestCommon(pg, 2);

  // This flag being true shows that the heartbeat monitor thread does not kill
  // the watchdog thread while it is collecting debug info such as desync debug
  // info.
  EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being false shows that the heartbeat monitor thread does not
  // trigger a process abort if collecting debug info and destroying the PG are
  // fast.
  EXPECT_FALSE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here, further operations would fail.
}
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) {
  ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_);
  // Keep the main thread sleeping longer so that the heartbeat monitor thread
  // can finish the extra wait and flip the flag.
  watchdogTimeoutTestCommon(pg, 4);
  // This flag being false shows that the watchdog thread got stuck while
  // collecting debug info such as desync debug info.
  EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being true shows that the heartbeat monitor thread does trigger
  // a process abort when collecting debug info gets stuck.
  EXPECT_TRUE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here, further operations would fail.
}