#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

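// These tests presumably rely on asynchronous error detection, which is
// assumed to be available starting with NCCL 2.4.0; older NCCL versions are
// skipped below.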
constexpr int kNcclErrorHandlingVersion = 2400;

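// Test helper: a WorkNCCL whose checkForNCCLErrors() can be made to report a
// simulated error.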
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      at::Device& device,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        simulateError_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors() override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
  }

 private:
  bool simulateError_;
};

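// Test helper: a ProcessGroupNCCL that hands out WorkNCCLSimulateErrors work
// objects and lets tests toggle the simulated error on and off.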
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

  std::exception_ptr checkForNCCLErrors(
      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        device, simulateError_, rank, opType, seqCollective_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulateError() {
    simulateError_ = true;
  }

  void resetError() {
    simulateError_ = false;
  }

 private:
  bool simulateError_;
};

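// Test helper: a WorkNCCL that can be forced to never complete, mimicking a
// timed-out collective.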
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      at::Device& device,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        setTimedoutError_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (setTimedoutError_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

 private:
  bool setTimedoutError_;
};

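// Test helper: a ProcessGroupNCCL that issues WorkNCCLTimedoutErrors work
// objects and exposes hooks for inspecting the watchdog's debug-info path.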
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        watchDogDebugInfoFinished_(false),
        setTimedoutError_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        device, setTimedoutError_, rank, opType, seqCollective_);
  }

  void setTimedoutError() {
    setTimedoutError_ = true;
  }

  void resetTimedoutError() {
    setTimedoutError_ = false;
  }

  bool getWatchDogDebugInfoFinishedFlag() {
    return watchDogDebugInfoFinished_;
  }

  // The constructor of ProcessGroupNCCL does not allow the watchdog thread to
  // run any error handling or desync report while the main thread is in a
  // blocking wait: even if users enable error handling and turn on the
  // desyncDebug flag, those settings get reset. For ease of unit testing we
  // want the main thread to block-wait, so we use this hack to manually set
  // the desync debug flag after PG creation.
  void forceSetDesyncDebugFlag() {
    desyncDebug_ = true;
  }

 protected:
  std::string getNCCLWatchdogDebugInfo() override {
    LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called";
    watchDogDebugInfoFinished_ = true;
    return "";
  }
  bool watchDogDebugInfoFinished_;

 private:
  bool setTimedoutError_;
};

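// Test helper: catches the exception thrown by the heartbeat monitor thread
// and records it in a flag, and turns process termination into a thrown
// exception so it can be observed from the test.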
class ProcessGroupNCCLNoHeartbeatCaught
    : public ProcessGroupNCCLTimedOutErrors {
 public:
  ProcessGroupNCCLNoHeartbeatCaught(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
        hasMonitorThreadCaughtError_(false) {}

  std::mutex& getWatchdogMutex() {
    return workMetaListMutex_;
  }

  bool getErrorCaughtFlag() {
    return hasMonitorThreadCaughtError_;
  }

  void forceTryWriteDebugInfo() {
    std::future<bool> asyncDebugDump = std::async(
        std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
    asyncDebugDump.wait();
  }

 protected:
  // Override the heartbeat monitor function so that the exception thrown in
  // the monitor thread is captured here; we cannot try-catch it in the main
  // thread, so we set a flag for the main thread to check instead.
  void heartbeatMonitor() override {
    try {
      c10d::ProcessGroupNCCL::heartbeatMonitor();
    } catch (std::runtime_error& e) {
      hasMonitorThreadCaughtError_ = true;
    }
  }

  // std::abort is very hard to unit test, so we override terminateProcess to
  // throw instead. (With this override commented out, we do see the process
  // abort with a core dump.)
  void terminateProcess(std::string errMsg) override {
    throw std::runtime_error(errMsg);
  }

  bool hasMonitorThreadCaughtError_;
};

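// Test helper: simulates the watchdog getting stuck while collecting debug
// info.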
class ProcessGroupNCCLDebugInfoStuck
    : public ProcessGroupNCCLNoHeartbeatCaught {
 public:
  ProcessGroupNCCLDebugInfoStuck(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {}

 protected:
  // Override the debug-info collection to sleep well past the heartbeat
  // timeout, mimicking getting stuck while gathering debug info.
  std::string getNCCLWatchdogDebugInfo() override {
    std::this_thread::sleep_for(
        std::chrono::seconds(heartbeatTimeoutInSec_ * 20));
    watchDogDebugInfoFinished_ = true;
    return "";
  }
};

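// Common fixture: skips when no CUDA device is available or NCCL is too old,
// and sets up a FileStore and a single CUDA tensor for the single-rank tests.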
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    // Enable LOG(INFO) messages.
    c10::initLogging();
    // This check needs to happen in SetUp to make sure we only run the test
    // -- including the init -- when there are GPUs available.
    if (skipTest()) {
      GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system "
                   << "requirement is not met (no CUDA or GPU).";
    }

    size_t numDevices = 1; // One device per rank (thread)
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    tensors_.resize(numDevices);
    tensors_[0] = at::empty({3, 3}, at::kCUDA);
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};

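// Blocking-wait mode: a simulated NCCL error should surface as an exception
// from wait().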
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here; further operations would fail.
}

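// Blocking-wait mode: a work object that never completes should time out and
// surface a DistBackendError from wait().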
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.setTimedoutError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::DistBackendError);

  // Communicators might be aborted here; further operations would fail.
}

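// Non-blocking mode: a simulated NCCL error should not make wait() throw.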
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  EXPECT_TRUE(work->isCompleted());
  // Communicators might be aborted here; further operations would fail.
}

// Function to read what we wrote to the local disk for validation.
std::string readTraceFromFile(const std::string& filename, size_t size) {
  std::ifstream file(filename, std::ios::binary);
  // Read the strings from the file.
  if (file) { // While the file stream is in a good state.
    std::string str(size, '\0');
    file.read(&str[0], size);
    if (file) {
      return str;
    }
  }
  return "";
}

// Subclass DebugInfoWriter, outside the parent class, so that the written
// traces are also kept in memory for validation.
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  TestDebugInfoWriter(std::string namePrefix)
      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
    c10d::DebugInfoWriter::write(ncclTrace);
  }

  std::vector<uint8_t>& getTraces() {
    return traces_;
  }

 private:
  std::vector<uint8_t> traces_;
};

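// Exercise the heartbeat monitor: hold the watchdog lock long enough for the
// monitor thread to fire, then verify that the flight-recorder trace dumped
// to disk matches the in-memory copy.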
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  int heartBeatIntervalInSec = 2;
  std::string timeInterval = std::to_string(heartBeatIntervalInSec);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  ASSERT_TRUE(
      setenv(
          c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
          timeInterval.c_str(),
          1) == 0);
  ASSERT_TRUE(
      setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
  auto tempFilename = c10::str(
      std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
  ASSERT_TRUE(
      setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
  // Enable the NCCL flight recorder.
  ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DUMP_ON_TIMEOUT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  // Set a long watchdog timeout so that we have enough time to lock the
  // watchdog and let the heartbeat monitor thread kick in.
  options->timeout = std::chrono::milliseconds(30000);
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options);
  // The writer here is very similar to the default (fallback) writer. The
  // only difference is that the traces are also stored in memory for
  // validation.
  std::string fileNamePrefix = c10d::getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> writerForTestPtr =
      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = writerForTestPtr->getTraces();
  c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
  work->wait();

  work = pg.allreduce(tensors_);
  {
    // Now run allreduce with errors.
    std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
    LOG(INFO) << "Lock watchdog thread.";
    // Wait long enough for the monitor thread to throw its exception.
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * 3));
    // Check that the monitoring thread launched and the exception was thrown.
    EXPECT_TRUE(pg.getErrorCaughtFlag());
  }
  work->wait();
  EXPECT_TRUE(traces.size() > 0);
  auto filename = c10::str(tempFilename, 0);
  auto traceFromStorage = readTraceFromFile(filename, traces.size());
  // Check that the traces read from storage match the original NCCL trace.
  EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end()));
  std::filesystem::remove(filename);
}

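// Fixture for watchdog-timeout tests: configures heartbeat monitoring and a
// very short watchdog timeout. Currently skipped (see the TODO in SetUp).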
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
    // TODO (kwen2501)
    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(
        setenv(
            c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
            timeInterval.c_str(),
            1) == 0);
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0);
    // We cannot capture the exception thrown in the watchdog thread without
    // making lots of changes to the code, so we do not let the watchdog throw
    // an exception.
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0);
    options_ = c10d::ProcessGroupNCCL::Options::create();
    // Set a very short watchdog timeout.
    options_->timeout = std::chrono::milliseconds(100);
  }

  void watchdogTimeoutTestCommon(
      ProcessGroupNCCLNoHeartbeatCaught& pg,
      int multiplier) {
    pg.forceSetDesyncDebugFlag();
    pg.setTimedoutError();
    auto work = pg.allreduce(tensors_);
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * multiplier));
    EXPECT_THROW(work->wait(), c10::DistBackendError);
  }

  const int heartBeatIntervalInSec = 2;
  c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_;
};

TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
  // Writing debug info makes the watchdog thread wait for 30 seconds, and
  // that is hard to override, so we call it beforehand. Otherwise we would
  // need to set a long heartbeat timeout, which would make the test much
  // slower.
  pg.forceTryWriteDebugInfo();
  watchdogTimeoutTestCommon(pg, 2);

  // This flag being true shows that the heartbeat monitor thread did not kill
  // the watchdog thread while it was collecting debug info (such as desync
  // debug info).
  EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being false shows that the heartbeat monitor thread did not
  // trigger a process abort, since collecting debug info and destroying the
  // PG finished quickly.
  EXPECT_FALSE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here; further operations would fail.
}

TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) {
  ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_);
  // Keep the main thread asleep longer so that the heartbeat monitor thread
  // can finish its extra wait and flip the flag.
  watchdogTimeoutTestCommon(pg, 4);
  // This flag being false shows that the watchdog thread got stuck while
  // collecting debug info (such as desync debug info).
  EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being true shows that the heartbeat monitor thread did trigger
  // a process abort because collecting debug info got stuck.
  EXPECT_TRUE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here; further operations would fail.
}