Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 22:25:10 +08:00)

Compare commits: test_quant...whc_flight (39 commits)

Commits in this range (SHA1):
7c507b78c4
0019901601
18be18535b
2729367313
33537aae24
dcdb1337dd
9cf0f2bd59
1d2e877c05
f27b979b0c
f30d6047ad
75311510ef
dbd6094d05
397b9d47e9
36a01a8ab9
ee336cf58a
a9e2e745d7
ab4df89eea
9d02ebe876
b61e01cce9
f7ce61ba53
e7bae15ab1
e71b422908
389940ce60
b2237a7c85
ef5dfe3f3e
e303dc3c08
265efad2de
60f0455905
4898313791
f4da9adf6b
8f7f35273e
44ec9612ed
4d3bea2b29
0bcdddc3c1
28b6220312
210b7b65e2
4da10b5cd3
f09763814f
80923ed5a6
@@ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) {
// Extend the nested class outside the parent class
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
TestDebugInfoWriter() : DebugInfoWriter(0) {}
TestDebugInfoWriter(std::string namePrefix)
: DebugInfoWriter(namePrefix, 0) {}

void write(const std::string& ncclTrace) override {
traces_.assign(ncclTrace.begin(), ncclTrace.end());
@@ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
// The storer here is very similar to the fallback storer.
// The only difference is that we are storing traces also in memory for
// validation.
std::string fileNamePrefix = c10d::getCvarString(
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
std::make_unique<TestDebugInfoWriter>();
std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();
pg.registerDebugInfoWriter(std::move(wrterForTestPtr));
c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr));

// Normal collective case.
auto work = pg.allreduce(tensors_);
@@ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
void SetUp() override {
// TODO (kwen2501)
GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
<< "will rewrite them after refactoring Work queues.";
ProcessGroupNCCLErrorsTest::SetUp();
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
@@ -3542,11 +3542,12 @@ class SparseCollective(MultiProcessTestCase):
class NCCLTraceTestBase(MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'
os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0' # see 'timing_enabled' parametrized tests
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
self.tempdir = tempfile.TemporaryDirectory()
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
self._spawn_processes()

@classmethod
@@ -3617,17 +3618,27 @@ class NCCLTraceTest(NCCLTraceTestBase):

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_short(self):
@parametrize("timing_enabled", [True, False])
def test_short(self, timing_enabled):
if self.rank == self.MAIN_PROCESS_RANK:
return
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
a = torch.full((3, 4), float(self.rank), device=device)
for i in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)

# gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api
time.sleep(1)

t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
ver = t['version']
self.assertEqual(ver, "1.0")
t = t['entries']
self.assertEqual(len(t), 2)
last = t[-1]
self.assertEqual(last['state'], 'completed')
@@ -3636,9 +3647,14 @@ class NCCLTraceTest(NCCLTraceTestBase):
self.assertEqual(last['output_sizes'], ((3, 4),))
self.assertEqual(last['seq_id'], 2)
now = datetime.now()
event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000)
event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000)
before_test = now - timedelta(minutes=1)
self.assertTrue(before_test < event_created_time < now)
if timing_enabled:
# very loose bounds, measured 0.036 ms on devgpu
self.assertTrue(0 < last['duration_ms'] < 100)
else:
self.assertTrue("duration_ms" not in last)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -3698,6 +3714,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
f.wait()
torch.cuda.synchronize(device=device)
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = t['entries']
self.assertEqual(len(t), 10)
first = t[0]
last = t[-1]
@@ -3732,6 +3749,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
pg.allreduce(a).wait()
e.synchronize()
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = t['entries']
if self.rank == 0:
self.assertEqual(t[-1]['seq_id'], 1)
self.assertEqual(t[-1]['state'], 'completed')
@@ -3773,6 +3791,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
# give the other thread some time to fill the cuda buffer
time.sleep(5)
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = t['entries']
if self.rank == 0:
self.assertEqual(t[-1]['seq_id'], 1)
self.assertEqual(t[-1]['state'], 'completed')
@@ -3832,6 +3851,9 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_timeout_dumps(self, timing_enabled):
# We need to completely disable the coordinated timeout dump to avoid rank 0
# also timeout so that we set the check frequency to be very large (25 min).
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1500000'

if self.rank == self.MAIN_PROCESS_RANK:
# wait for rank0 to crash before looking for its output file
@@ -3839,6 +3861,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
self.assertEqual(self._wait_process(0, timeout=90), -6)
with open(self._trace_name(rank=0), 'rb') as f:
t = pickle.load(f)
t = t['entries']
self.assertEqual(len(t), 2)
self.assertEqual(t[0]['seq_id'], 1)
self.assertEqual(t[0]['state'], 'completed')
@@ -3868,7 +3891,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
instantiate_parametrized_tests(NCCLTraceTest)

class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
def _check_return_codes(self, elapsed_time):
# the base test infra assumes processes exit with matching return codes,
# but we want rank0 to abort and rank1 to exit cleanly in this test
@@ -3877,7 +3900,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_timeout_dumps_on_idle_ranks(self):
def test_timeout_dumps_on_stuck_ranks(self):

if self.rank == self.MAIN_PROCESS_RANK:
# wait for both rank0 and 1 to crash before looking for both ranks' output
@@ -3888,20 +3911,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
self.assertTrue(os.path.exists(self._trace_name(rank=0)))
with open(self._trace_name(rank=0), 'rb') as f:
t = pickle.load(f)
t = t['entries']
self.assertEqual(len(t), 2)
with open(self._trace_name(rank=1), 'rb') as f:
t = pickle.load(f)
t = t['entries']
self.assertEqual(len(t), 1)
self.assertEqual(t[0]['seq_id'], 1)
self.assertEqual(t[0]['state'], 'completed')
return

# Set heartbeat timeout to a shorter one (default timeout is 2 min).
os.environ[
"TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"
] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}"
pg = self._create_process_group_nccl()

device = self.local_device
with torch.cuda.device(device):
a = torch.full((3, 4), float(self.rank), device=device)
@@ -3910,12 +3930,14 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
if self.rank == 0:
pg.allreduce(a).wait()

# rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
# rank 0 will get stuck, timeout and then signal a timeout to all ranks.
torch.cuda.synchronize()

# Force rank 1 to idle so that it also gets debug info dump triggered.
if self.rank == 1:
time.sleep(6)
# Force rank 1 to idle so that it will eventually timeout as well after
# getting the global signal to dump the debugging info.
time.sleep(600)


if __name__ == "__main__":
assert (
@@ -71,7 +71,7 @@ class StoreTestBase:
def _create_store(self, i):
raise RuntimeError("not implemented")

def _test_set_get(self, fs):
def _test_set_get_check(self, fs):
fs.add("key", 1)
fs.add("key", 2)
fs.add("key", 3)
@@ -90,14 +90,16 @@ class StoreTestBase:
self.assertEqual(b"value1", fs.get("key1"))
self.assertEqual(b"value2", fs.get("key2"))
self.assertEqual(b"21", fs.get("key3"))
self.assertTrue(fs.check(["key3"]))
self.assertFalse(fs.check(["Randomkey3"]))

fs.set("-key3", "7")
self.assertEqual(b"7", fs.get("-key3"))
fs.delete_key("-key3")
self.assertEqual(fs.num_keys(), self.num_keys_total)

def test_set_get(self):
self._test_set_get(self._create_store())
def test_set_get_check(self):
self._test_set_get_check(self._create_store())

def _test_compare_set(self, store):
missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
@@ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
def num_keys_total(self):
return 6

def test_underlying_non_prefix_store(self):
store = self._create_store()
wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)

class MyPythonStore(dist.Store):
def __init__(self):
super().__init__()
@@ -175,14 +175,29 @@ std::string getNcclErrorDetailStr(
c10::optional<std::string> processGroupFailureReason = c10::nullopt);

// Write NCCL debug info to local disk or any storage users define.
// There are some constrains we set for the debug info writer:
// 1. The writer should only be registered once.
// 2. Once registered, users cannot change it including un-register.
// 3. It is recommended to register the customized writer in the trainer setup,
// If users don't register before calling launchAsyncDebugDump, then users
// lose the chance to register (and the default writer will be
// auto-registered).
class TORCH_API DebugInfoWriter {
 public:
DebugInfoWriter(int rank);
virtual ~DebugInfoWriter();
virtual void write(const std::string& ncclTrace);
static DebugInfoWriter& getWriter(int rank);
static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);

 protected:
DebugInfoWriter(std::string namePrefix, int rank) {
filename_ = c10::str(namePrefix, rank);
}
std::string filename_;

 private:
static std::unique_ptr<DebugInfoWriter> writer_;
static std::atomic<bool> hasWriterRegistered_;
};
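The registration constraints above are easiest to see in code. Below is a minimal sketch of a customized writer built against the interface in this header; the class name, output path, and setup function are illustrative and not part of the diff.

```cpp
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

#include <fstream>
#include <memory>
#include <string>

// Hypothetical custom writer: persists the serialized flight-recorder blob to a
// trainer-chosen location. Only the interface (the protected constructor, write(),
// registerWriter()) comes from the header above; the body is illustrative.
class MyTraceWriter : public c10d::DebugInfoWriter {
 public:
  MyTraceWriter(const std::string& prefix, int rank)
      : c10d::DebugInfoWriter(prefix, rank) {}

  void write(const std::string& ncclTrace) override {
    // filename_ was assembled by the protected base constructor as prefix + rank.
    std::ofstream out(filename_, std::ios::binary);
    out.write(ncclTrace.data(), ncclTrace.size());
  }
};

void setupTraceWriter(int globalRank) {
  // Per the constraints above: register exactly once, early in trainer setup,
  // before any timeout can trigger launchAsyncDebugDump and auto-register the
  // default writer.
  c10d::DebugInfoWriter::registerWriter(
      std::make_unique<MyTraceWriter>("/checkpoints/nccl_trace_rank_", globalRank));
}
```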
// RAII wrapper for NCCL communicator
@@ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
return store_;
}

c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
c10::intrusive_ptr<Store> store = store_;

while (store) {
// Attempt to dynamically cast to PrefixStore
PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
if (asPrefixStore) {
store = asPrefixStore->getUnderlyingStore();
} else {
break; // We've reached a non-PrefixStore
}
}

TORCH_CHECK(
store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
return store;
}

} // namespace c10d
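For context, a small usage sketch of the new accessor, assuming an existing c10d store handle; the prefixes and function name below are made up for illustration, but the calls mirror how ProcessGroupNCCL derives its shared global store from a (possibly nested) PrefixStore.

```cpp
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

// Wrap an existing store twice, then recover the innermost, non-prefixed store.
c10::intrusive_ptr<c10d::Store> unwrapToGlobalStore(
    c10::intrusive_ptr<c10d::Store> baseStore) {
  auto outer = c10::make_intrusive<c10d::PrefixStore>(
      "pg1/", c10::make_intrusive<c10d::PrefixStore>("default_pg/", baseStore));

  // getUnderlyingStore() peels exactly one PrefixStore layer...
  auto oneLayerDown = outer->getUnderlyingStore();  // the "default_pg/" store
  (void)oneLayerDown;

  // ...while getUnderlyingNonPrefixStore() walks all the way down to baseStore,
  // so keys written through it are visible to every process group.
  return outer->getUnderlyingNonPrefixStore();
}
```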
@@ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store {

c10::intrusive_ptr<Store> getUnderlyingStore();

// Recursively to fetch the store before layers of wrapping with PrefixStore.
c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();

 protected:
std::string prefix_;
c10::intrusive_ptr<Store> store_;
@@ -341,6 +341,35 @@ std::string dump_nccl_trace() {
return NCCLTraceBuffer::get()->dump();
}

c10::optional<std::function<std::string()>>& get_cpp_trace_dumper() {
static c10::optional<std::function<std::string()>> dumper(c10::nullopt);
return dumper;
}

gil_checker_t& get_gil_checker() {
static gil_checker_t gil_checker = nullptr;
return gil_checker;
}

std::future<bool> launchAsyncGilCheck() {
std::promise<bool> resultPromise;
std::future<bool> resultFuture = resultPromise.get_future();
TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker");
std::thread workerThread([promise = std::move(resultPromise)]() mutable {
try {
auto& gil_checker = get_gil_checker();
promise.set_value((*gil_checker)());
} catch (...) {
promise.set_exception(std::current_exception());
}
});

// Detach the thread to allow it to run independently
workerThread.detach();

return resultFuture;
}
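launchAsyncGilCheck above follows a common launch-and-bound-the-wait pattern: run a check on a detached worker, hand the result back through a promise, and let the caller cap how long it is willing to wait. A self-contained sketch of that pattern using only the standard library (all names here are local to the sketch, not PyTorch APIs):

```cpp
#include <chrono>
#include <future>
#include <iostream>
#include <thread>

std::future<bool> launchAsyncCheck(bool (*check)()) {
  std::promise<bool> resultPromise;
  std::future<bool> resultFuture = resultPromise.get_future();
  std::thread worker([promise = std::move(resultPromise), check]() mutable {
    try {
      promise.set_value(check());  // deliver the result
    } catch (...) {
      promise.set_exception(std::current_exception());  // or the failure
    }
  });
  worker.detach();  // the caller never joins; the future is the only handle
  return resultFuture;
}

int main() {
  auto fut = launchAsyncCheck(+[] { return true; });
  if (fut.wait_for(std::chrono::milliseconds(300)) == std::future_status::ready) {
    std::cout << "check returned " << fut.get() << "\n";
  } else {
    std::cout << "check still pending after 300 ms, giving up\n";
  }
}
```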
// Return CUDA device with ordinal given by input rank. If we aren't
// bound to a specific device, there is no strict guarantee that this
// heuristic is the correct assignment of ranks to GPUs that Python
@@ -358,7 +387,7 @@ at::Device ProcessGroupNCCL::guessDeviceForRank() const {
}
}

const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;

@@ -466,12 +495,17 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
std::unique_lock<std::mutex> lock(mutex_);
exception_ = exception_ptr;
if (exception_) {
LOG(INFO) << "[Rank " << rank_ << "]"
<< " found async exception when checking for NCCL errors: "
LOG(INFO) << logPrefix()
<< "found async exception when checking for NCCL errors: "
<< getExceptionMsgFromExceptionPtr(exception_);
}
}

const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
static std::string prefix = c10::str("[Rank ", rank_, "] ");
return prefix;
}

void ProcessGroupNCCL::WorkNCCL::setException(
std::exception_ptr exception_ptr) {
std::unique_lock<std::mutex> lock(mutex_);
@@ -527,9 +561,7 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
return true;

std::string exceptionMsg = c10::str(
"[Rank ",
rank_,
"] ",
logPrefix(),
"Watchdog caught collective operation timeout: ",
*this,
" ran for ",
@@ -550,13 +582,13 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
"Some NCCL operations have failed or timed out. Due to the ",
"asynchronous nature of CUDA kernels, subsequent GPU operations ",
"might run on corrupted/incomplete data.");
LOG(ERROR) << exceptionMsg;
LOG(ERROR) << logPrefix() << exceptionMsg;
C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException");

if (SHOULD_TEAR_DOWN(errorHandling)) {
auto tearDownMsg = c10::str(
"To avoid data inconsistency, we are taking the entire process down.");
LOG(ERROR) << tearDownMsg;
LOG(ERROR) << logPrefix() << tearDownMsg;
std::rethrow_exception(exception_);
}
}
@@ -597,9 +629,8 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
// can not run new events successfully.
if (timedOut) {
std::string exceptionMsg = c10::str(
"[Rank ",
rank_,
"] Work ",
logPrefix(),
"Work ",
(*this),
" timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
LOG(ERROR) << exceptionMsg;
@@ -713,6 +744,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
ValueError,
at::cuda::getNumGPUs() != 0,
"ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
@@ -723,8 +755,18 @@ ProcessGroupNCCL::ProcessGroupNCCL(
heartbeat_ = 1ULL;
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
heartbeatTimeoutInSec_ =
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 2 /*2 Mins*/);
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/);
waitTimeoutDumpInMilSec_ =
getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 2000);
coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
// store_ usually is wrapped with PrefixStore and the prefix is different
// across different ProcessGroupNCCL(PG) instances. We need to get the
// underlying non-PrefixStore for sharing global information shared across
// different PGs.
PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
globalStore_ =
prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
#ifdef ENABLE_NCCL_ERROR_CHECKING
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@@ -737,15 +779,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
expandable_segments()) {
useTensorRegisterAllocatorHook_ = false;
LOG(INFO)
<< "[Rank " << rank_
<< "] disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
<< logPrefix()
<< "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
}
#endif

if (blockingWait_) {
if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
LOG(INFO)
<< "[Rank " << rank_ << "] TORCH_NCCL_BLOCKING_WAIT and "
<< logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and "
<< "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG"
<< "should not both be enabled. "
<< "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process.";
@@ -755,8 +797,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
} else {
if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
LOG(INFO)
<< "[Rank " << rank_
<< "] TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
<< logPrefix()
<< "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
<< "must both be enabled. "
<< "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING.";
asyncErrorHandling_ = SkipCleanUp;
@@ -781,12 +823,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
const std::string OFF = "OFF";
std::string torch_distributed_debug =
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
std::string nccl_debug = getCvarString({"NCCL_DEBUG"}, OFF.c_str());
LOG(INFO) << "[Rank " << rank_
<< "] ProcessGroupNCCL initialization options: "
<< "NCCL version: " << getNcclVersion()
LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
<< "NCCL version: " << getNcclVersion() << ", size: " << size
<< ", global rank: " << globalRank()
<< ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeout_
<< ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
<< waitTimeoutDumpInMilSec_
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
@@ -804,7 +847,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< monitorThreadEnabled_.load()
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
<< ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID();
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
<< ", ID=" << this->getID();

RECORD_PARAM_COMMS(
0, // seq
@@ -835,7 +879,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
std::vector<at::Device> rankDevices = {device};
const auto key = getKeyFromDevices(rankDevices);
LOG(INFO) << "Eagerly connecting nccl backend with device " << device;
LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device "
<< device;
getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
}

@@ -846,8 +891,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#ifdef NCCL_HAS_COMM_SPLIT
std::vector<at::Device> rankDevices = {device};
const auto key = getKeyFromDevices(rankDevices);
LOG(INFO) << "Performing nocolor split on backend device " << device
<< ", key " << key << ", i am " << this;
LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
<< device << ", key " << key << ", i am " << this;
auto comm = getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
TORCH_CHECK_WITH(
DistBackendError,
@@ -897,8 +942,7 @@ void ProcessGroupNCCL::runHealthCheck() {
// We don't need to join the thread, just need to verify health check via the
// CV. Hence we detach the thread here.
t.detach(); // NOLINT
LOG(INFO) << "[Rank " << rank_ << "]"
<< " will wait up to " << options_->timeout.count()
LOG(INFO) << logPrefix() << "will wait up to " << options_->timeout.count()
<< " msec for NCCL health check to complete.";
std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
healthCheckData.healthCheckCv.wait_for(
@@ -1002,10 +1046,49 @@ std::future<bool> ProcessGroupNCCL::launchAsyncDebugDump() {
return resultFuture;
}

void abortCommsFromMap(
std::chrono::time_point<std::chrono::steady_clock> getWakeupTime(
int intervalInMilSec) {
return std::chrono::steady_clock::now() +
std::chrono::milliseconds(intervalInMilSec);
}

void ProcessGroupNCCL::waitForDumpOrTimeout(
std::future<bool>& fut,
const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
size_t timeout_sec) {
TORCH_CHECK(fut.valid(), "Expected a valid future");

auto futStatus = fut.wait_for(std::chrono::seconds(timeout_sec));
TORCH_CHECK(
futStatus != std::future_status::deferred, "Expected eager launch.");
if (futStatus == std::future_status::ready) {
// Calling .get() will re-raise any exception from the future, and we don't
// care about the retval
try {
fut.get();
std::this_thread::sleep_until(wakeUpTime);
} catch (const std::exception& e) {
LOG(ERROR) << logPrefix()
<< "Caught exception during async debug dump: \"" << e.what()
<< "\"\n";
} catch (...) {
LOG(ERROR) << logPrefix()
<< "Caught unknown exception during async debug dump.";
}
} else {
LOG(INFO)
<< logPrefix() << "Debug dump timed out and is being abandoned."
<< " This may be due to slow ADDR2LINE performance processing stacktraces."
<< " Try TORCH_DISABLE_ADDR2LINE=1 and TORCH_NCCL_TRACE_CPP_STACK=0 to work around.";
}
// Ensure we sleep at least until wakeUpTime regardless of future execution
// time
std::this_thread::sleep_until(wakeUpTime);
}

void ProcessGroupNCCL::abortCommsFromMap(
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
ncclCommsMap,
const int rank,
c10::optional<std::string> abortReason) {
// The process may control multiple devices, loop through the communicators on
// each device
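Restated as a standalone sketch, the essential behavior of waitForDumpOrTimeout is: wait for the dump future with a cap, treat dump failures as best-effort, and in every path hold until a precomputed wake-up time so that peers signaled at the same moment get their own dump window. All names below are local to the sketch.

```cpp
#include <chrono>
#include <future>
#include <thread>

void waitWithFloor(std::future<bool>& dumpFuture,
                   std::chrono::steady_clock::time_point wakeUpTime,
                   std::chrono::seconds timeout) {
  if (dumpFuture.wait_for(timeout) == std::future_status::ready) {
    try {
      dumpFuture.get();  // re-raises any exception thrown while dumping
    } catch (...) {
      // log and continue; the dump is best-effort
    }
  }
  // Whether the dump finished, failed, or timed out, hold until wakeUpTime.
  std::this_thread::sleep_until(wakeUpTime);
}
```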
@@ -1026,8 +1109,17 @@ void abortCommsFromMap(
// their responsibility to destroy the process group and recreate
// it to recover from errors.

LOG(INFO) << "[Rank " << rank << "] Destroyed " << ncclComms.size()
<< "communicators on CUDA device " << devName;
c10::StreamId streamId = -1;
if (ncclStreams_.find(devName) != ncclStreams_.end()) {
auto streams = ncclStreams_.at(devName);
if (streams.size() > 0) {
streamId = streams[0].id();
}
}

LOG(INFO) << logPrefix() << "] Destroyed " << ncclComms.size()
<< "communicators on CUDA device: " << devName
<< " with stream: " << streamId;
}
}

@@ -1048,8 +1140,8 @@ void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
ncclCommDevIdxMapMutex.unlock();

std::lock_guard<std::mutex> lock(mutex_);
abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
abortCommsFromMap(devNCCLCommMap_, abortReason);
abortCommsFromMap(inInitializationCommMap_, abortReason);
}

void ProcessGroupNCCL::shutdown() {
@@ -1067,6 +1159,7 @@ void ProcessGroupNCCL::shutdown() {
}

ProcessGroupNCCL::~ProcessGroupNCCL() {
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
terminateProcessGroup_.store(true);
workMetaListCV_.notify_one();

@@ -1074,6 +1167,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
if (ncclCommWatchdogThread_.joinable()) {
ncclCommWatchdogThread_.join();
}
LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
#endif

if (onCompletionHookThread_.joinable())
@@ -1083,6 +1177,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
// threads dying due to aborted communicator and raising a SIGABRT
std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
abort(abortReason);
LOG(INFO) << logPrefix() << "ProcessGroupNCCL abort finished.";

// We need to wait for abort to finish before we can safely shut down
// heartbeat monitoring thread.
@@ -1095,33 +1190,20 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
#endif
}

void ProcessGroupNCCL::registerDebugInfoWriter(
std::unique_ptr<DebugInfoWriter> writer) {
TORCH_CHECK_WITH(
DistBackendError,
debugInfoWriter_ == nullptr,
"ProcessGroupNCCL debugInfoWriter already registered");
debugInfoWriter_ = std::move(writer);
}

bool ProcessGroupNCCL::dumpDebuggingInfo() {
// Serialize all calls to this function to avoid corrupting data, but allow
// multiple calls in one runtime. User is responsible for preserving the
// output file from an earlier call before a later call overwrites it.
std::lock_guard<std::mutex> lock(writeDebugInfoMutex_);
LOG(ERROR) << "ProcessGroupNCCL preparing to dump debug info.";
static std::mutex writeDebugInfoMutex;
std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
if (ncclTraceBufferSize_ > 0) {
// We dump nccl trace into local disk by default and users can register
// their customized writer by inheriting `DebugInfoWriter` via
// `registerDebugInfoWriter`.
auto ncclTrace = dump_nccl_trace();
if (debugInfoWriter_ == nullptr) {
// Dump the trace blob into local disk as a fallback.
std::unique_ptr<DebugInfoWriter> debugInfoWriterPtr =
std::make_unique<DebugInfoWriter>(rank_);
registerDebugInfoWriter(std::move(debugInfoWriterPtr));
}
debugInfoWriter_->write(ncclTrace);
DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank());
writer.write(ncclTrace);
return true;
}
return false;
@@ -1130,48 +1212,133 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
void ProcessGroupNCCL::terminateProcess(std::string errMsg) {
// Logging with `FATAL`, after errMsg printed, it calls `std::abort()`
// to terminate the program execution.
LOG(FATAL) << errMsg;
LOG(FATAL) << logPrefix() << errMsg;
}

int computeDeltaMS(
std::chrono::time_point<std::chrono::steady_clock> start,
std::chrono::time_point<std::chrono::steady_clock> end) {
return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count();
}

void ProcessGroupNCCL::heartbeatMonitor() {
uint64_t heartBeatCounter = 0ULL;
std::string errorMsg;
std::string exitMsg;
bool checkTimeoutSignal = (dumpOnTimeout_ && uid_ == 0);
int monitorPollInterval = checkTimeoutSignal ? coordCheckIntervalMilSec_
: heartbeatTimeoutInSec_ * 1000;
auto lastTimePollStore = std::chrono::steady_clock::now();
auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
std::future<bool> asyncDebugDump;
while (true) {
// This won't have any lock since this lock is only used here.
// Please be aware that mutex `monitorMutex_` should not be used
// somewhere else to avoid the deadlock.
std::unique_lock<std::mutex> lock(monitorMutex_);
if (monitorWakeUpCV_.wait_for(
lock, std::chrono::seconds(heartbeatTimeoutInSec_), [&] {
lock, std::chrono::milliseconds(monitorPollInterval), [&] {
return terminateHeartbeatMonitorThread_.load();
})) {
// For the normal complete or user interception, monitorWakeUpCV_
// will get notified, we early return and exit heartbeatMonitor.
return;
}
auto currentTime = std::chrono::steady_clock::now();

// Check the heart beat of watchdog thread.
auto heartbeat = heartbeat_;
if (heartbeat != heartBeatCounter) {
heartBeatCounter = heartbeat;
} else {
// No heartbeat increase detected and timeout.
break;
// We put extra functionality in the thread for the default PG (aka, uid_=0)
// because the signal is same across different PGs. We only need to run
// once per process to avoid duplicate things performed in too many separate
// threads. For example, we check a global flag on the TCPStore periodically
// to see if any PG on any rank observed a timeout and signaled peers to
// dump debugging info, and we avoid hammering the TCPStore from all PGs on
// the same rank.
if (checkTimeoutSignal) {
// We poll store to see if some ranks have flagged a timeout when
// we haven't polled for `heartbeat_timeout` seconds and there haven't
// any work added or removed for `watchdog_timeout` seconds.
if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >=
kWatchdogThreadSleepMillis &&
computeDeltaMS(lastTimePollStore, currentTime) >=
coordCheckIntervalMilSec_) {
lastTimePollStore = currentTime;
if (globalStore_->check({std::string(TIMEOUT_DUMP)})) {
errorMsg = c10::str(
logPrefix(),
"Received a global timeout from another rank and will ",
"start to dump the debug info.");
exitMsg = c10::str(
"ProcessGroupNCCL's watchdog detected a collective timeout and notified current rank. ",
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
"sizes used across ranks, the order of collectives is not same for all ranks ",
"or the scheduled collective, for some reason, didn't run. Additionally, ",
"this can be caused by GIL deadlock or other reasons such as network errors or ",
"bugs in the communications library (e.g. NCCL), etc. We tried our best to ",
"dump the debug info into the storage to help you debug the issue.");
break;
}
}
}

if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
heartbeatTimeoutInSec_ * 1000) {
// Check the heart beat of watchdog thread.
lastTimeHeartBeatCheck = currentTime;
auto heartbeat = heartbeat_.load();
if (heartbeat != heartBeatCounter) {
heartBeatCounter = heartbeat;
} else {
// No heartbeat increase detected and timeout.
errorMsg = c10::str(
logPrefix(),
"Heartbeat monitor timed out! Process will be terminated after dumping debug info.",
" workMetaList_.size()=",
workMetaList_.size());
exitMsg = c10::str(
"ProcessGroupNCCL's watchdog got stuck for ",
heartbeatTimeoutInSec_,
" seconds without making progress in monitoring enqueued collectives. ",
"This typically indicates a NCCL/CUDA API hang blocking the watchdog, ",
"and could be triggered by another thread holding the GIL inside a ",
"CUDA api, or other deadlock-prone behaviors.",
"If you suspect the watchdog is not actually stuck and a longer timeout would help, ",
"you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value "
"or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)."
"If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout "
"or false positive abort; otherwise, please attempt to debug the hang.");
break;
}
}
}
LOG(ERROR) << errorMsg;

auto& cpp_dumper = get_cpp_trace_dumper();
if (cpp_dumper.has_value()) {
LOG(INFO) << "Dumping c++ stacktraces: " << cpp_dumper.value()();
}

auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
// Store debug info to storage if no other thread does it. (By default to
// local disk)
std::future<bool> asyncDebugDump = launchAsyncDebugDump();
asyncDebugDump = launchAsyncDebugDump();

// Create a error message reported from MonitorThread, so
// we throw exception and make the whole process to be killed.
const auto exitMsg = c10::str(
"[Rank ",
rank_,
"] NCCL monitor thread timeout. Basically, this could ",
"be due to CUDA or NCCL calls being unexpectedly blocking, ",
"especially when your program enters a deadlock state in watchdog "
"or destructors. If you see this error, please file a bug to PyTorch.");
if (get_gil_checker() != nullptr) {
auto fut = launchAsyncGilCheck();
auto kGilCheckTimeout = std::chrono::milliseconds(300);
auto futStatus = fut.wait_for(kGilCheckTimeout);
if (futStatus != std::future_status::ready) {
TORCH_CHECK(
futStatus != std::future_status::deferred,
"Expected the future to have been launched eagerly.");
LOG(ERROR)
<< "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
}
LOG(INFO) << "Could acquire GIL on exit";
} else {
LOG(INFO)
<< "GIL checker was not registered, perhaps this is a no-python build?";
}

// There are two possible cases for the watchdog thread exit:
// Case one: desync report runs quickly, and it follows the step:
@@ -1197,46 +1364,40 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// We already log completion inside the thread, so it may not be necessary to
// check the return value here. We mainly use a future so we can exit early
// if done.
asyncDebugDump.wait_for(std::chrono::seconds(heartbeatTimeoutInSec_));
waitForDumpOrTimeout(asyncDebugDump, wakeUpTime);

if (!terminateHeartbeatMonitorThread_.load()) {
const auto logMsg = c10::str(
"[Rank ",
rank_,
"] monitoring thread detects no heartbeat and will finally kill the process!",
" terminateProcessGroup_",
terminateProcessGroup_,
" collectiveDebugInfoMode_",
collectiveDebugInfoMode_);
LOG(ERROR) << logMsg;
terminateProcess(exitMsg);
// Create a error message reported from MonitorThread, so
// we throw exception and make the whole process to be killed.
// TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
// url here.
const auto finalExitMsg = c10::str(logPrefix(), exitMsg);
terminateProcess(finalExitMsg);
}
}

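To make the monitor logic above easier to follow, here is a toy model of the two exit conditions it watches for, using plain atomics instead of the real condition variable and TCPStore: (1) the watchdog's heartbeat counter stops advancing, or (2) some rank has set a global "timeout dump" flag. It is a self-contained sketch only, not the PyTorch implementation.

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<uint64_t> heartbeat{1};        // bumped by the watchdog loop
std::atomic<bool> timeoutDumpFlag{false};  // stand-in for globalStore_->check({TIMEOUT_DUMP})

void monitor(std::chrono::milliseconds pollInterval,
             std::chrono::milliseconds heartbeatTimeout) {
  uint64_t lastSeen = 0;
  auto lastHeartbeatCheck = std::chrono::steady_clock::now();
  while (true) {
    std::this_thread::sleep_for(pollInterval);
    // Signal 1: a peer rank flagged a collective timeout on the shared store.
    if (timeoutDumpFlag.load()) {
      std::cout << "peer flagged a timeout: dump debug info, then tear down\n";
      return;
    }
    // Signal 2: the watchdog heartbeat has not advanced for a full timeout window.
    auto now = std::chrono::steady_clock::now();
    if (now - lastHeartbeatCheck >= heartbeatTimeout) {
      lastHeartbeatCheck = now;
      uint64_t current = heartbeat.load();
      if (current == lastSeen) {
        std::cout << "watchdog heartbeat stalled: dump debug info, then abort\n";
        return;
      }
      lastSeen = current;
    }
  }
}

int main() {
  std::thread t(monitor, std::chrono::milliseconds(100), std::chrono::milliseconds(500));
  // Watchdog side: bump the heartbeat for a while, then stop to simulate a hang.
  for (int i = 0; i < 5; ++i) {
    heartbeat++;
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  t.join();  // the monitor fires once the heartbeat stops advancing
}
```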
void ProcessGroupNCCL::ncclCommWatchdog() {
try {
VLOG(2) << "[Rank " << rank_ << "] NCCL watchdog thread started!";
VLOG(2) << logPrefix() << "NCCL watchdog thread started!";
if (monitorThreadEnabled_.load()) {
ncclHeartbeatMonitorThread_ =
std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
}
watchdogHandler();
VLOG(2) << "[Rank " << rank_
<< "] NCCL watchdog thread terminated normally";
VLOG(2) << logPrefix() << "NCCL watchdog thread terminated normally";
} catch (std::exception& e) {
if (std::string(e.what()).find("driver shutting down") !=
std::string::npos) {
LOG(INFO)
<< "[Rank " << rank_
<< "] main process destroyed cuda before watchdog loop exited, terminating watchdog."
<< logPrefix()
<< "main process destroyed cuda before watchdog loop exited, terminating watchdog."
<< " (Watchdog caught exception: " << e.what();

} else {
// Append error message reported from watchdogHandler
const auto exitMsg = c10::str(
"[Rank ",
rank_,
"] NCCL watchdog thread terminated with exception: ",
logPrefix(),
"NCCL watchdog thread terminated with exception: ",
e.what());
LOG(ERROR) << exitMsg;
// TODO(whc) clean up the rethrow - why is it stored in a class var and
@@ -1247,9 +1408,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
} catch (...) {
const auto exitMsg = c10::str(
"[Rank ",
rank_,
"] NCCL watchdog thread terminated with exception: unknown");
logPrefix(), "NCCL watchdog thread terminated with exception: unknown");
LOG(ERROR) << exitMsg;
watchDogException_ =
std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
@@ -1288,12 +1447,13 @@ std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {

#if defined(__linux__)
struct DumpPipe {
DumpPipe(bool enabled, int rank) {
if (!enabled) {
DumpPipe(int rank) {
std::string fileStem =
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
if (fileStem.empty() ||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
return;
}
std::string fileStem = getCvarString(
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
std::string filename = c10::str(fileStem, rank, ".pipe");
TORCH_CHECK(
@@ -1305,6 +1465,8 @@ struct DumpPipe {
"Error creating named pipe ",
filename);
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
LOG(INFO) << "Pipe file " << filename
<< " has been opened, write to it to trigger NCCL Debug Dump.";
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
}
bool shouldDump() {
@@ -1329,22 +1491,40 @@ struct DumpPipe {
};
#else
struct DumpPipe {
DumpPipe(bool enabled, int rank) {}
DumpPipe(int rank) {}
bool shouldDump() {
return false;
}
};
#endif

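The trigger side of DumpPipe is not part of this diff; the sketch below shows one plausible way an operator-side helper could request a dump by writing a byte into the pipe, which the watchdog's next shouldDump() poll picks up. The path stem is an example value for TORCH_NCCL_DEBUG_INFO_PIPE_FILE, not a default.

```cpp
#include <fcntl.h>
#include <unistd.h>

#include <cstdio>
#include <string>

// Ask rank `rank` of the local trainer process to dump its flight recorder.
int requestDump(int rank) {
  // The watchdog names the pipe <TORCH_NCCL_DEBUG_INFO_PIPE_FILE><rank>.pipe;
  // the stem below is a hypothetical setting.
  std::string pipeName = "/tmp/nccl_debug_pipe_" + std::to_string(rank) + ".pipe";

  // O_NONBLOCK so we fail fast instead of hanging if no reader holds the pipe open.
  int fd = open(pipeName.c_str(), O_WRONLY | O_NONBLOCK);
  if (fd < 0) {
    std::perror("open dump pipe");
    return -1;
  }
  char byte = 1;
  ssize_t written = write(fd, &byte, 1);  // any write is treated as a dump request
  close(fd);
  return written == 1 ? 0 : -1;
}
```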
std::string ProcessGroupNCCL::createLogPrefix() const {
return c10::str("[PG ", uid_, " Rank ", rank_, "] ");
}

const std::string& ProcessGroupNCCL::logPrefix() const {
return logPrefix_;
}

const int& ProcessGroupNCCL::globalRank() const {
static int globalRank = rank_;
return globalRank;
}

void ProcessGroupNCCL::watchdogHandler() {
bool done = false;
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
auto lastTimePollStore = std::chrono::steady_clock::now();
c10::optional<std::future<bool>> optAsyncDebugDump;

std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;

DumpPipe dumpPipe(dumpOnTimeout_, rank_);
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
if (uid_ == 0) {
// DumpPipe is one per-trainer process, and its convenient to name them
// after 'global' ranks in the system, So we assume processgroup (uid)==0 is
// the global PG and has globally unique rank ids across trainers.
dumpPipe.emplace(rank_);
}
while (!done || !terminateProcessGroup_.load()) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
@@ -1356,37 +1536,6 @@ void ProcessGroupNCCL::watchdogHandler() {
// Bump up heart beat by one.
heartbeat_++;

// poll store to see if some ranks have flagged a timeout when
// we haven't polled for `heartbeat_timeout` seconds and there haven't
// any work added or removed for `watchdog_timeout` seconds.
if (dumpOnTimeout_) {
auto currentTime = std::chrono::steady_clock::now();
auto timeSinceLastWorkListUpdate =
std::chrono::duration_cast<std::chrono::milliseconds>(
(currentTime - lastWorkListUpdateTime_))
.count();
auto timeSinceLastPollStore =
std::chrono::duration_cast<std::chrono::milliseconds>(
(currentTime - lastTimePollStore))
.count();
if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis &&
timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) {
lastTimePollStore = currentTime;
if (store_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) {
optAsyncDebugDump = launchAsyncDebugDump();
optAsyncDebugDump->wait_for(
std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30));
const auto exitMsg = c10::str(
"Some other rank's watchdog thread detected a timeout and notified ",
"all other ranks, so we're dumping debug info and aborting [Rank ",
rank_,
"] as well.");
LOG(ERROR) << exitMsg;
C10_THROW_ERROR(DistBackendError, exitMsg);
}
}
}

for (auto it = workMetaList_.begin(); it != workMetaList_.end();
/* no increment */) {
auto& work = *it;
@@ -1411,9 +1560,10 @@ void ProcessGroupNCCL::watchdogHandler() {
// abort process immediately.
collectiveDebugInfoMode_.store(true);
std::vector<uint8_t> vec(1);
store_->set(std::string(TIMEOUT_DUMP), vec);
globalStore_->set(std::string(TIMEOUT_DUMP), vec);
}

auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
if (dumpOnTimeout_ && !optAsyncDebugDump) {
// Store debug info to storage. (By default to local disk)
optAsyncDebugDump = launchAsyncDebugDump();
@@ -1421,20 +1571,21 @@ void ProcessGroupNCCL::watchdogHandler() {

if (desyncDebug_) {
auto desyncMsg = getNCCLWatchdogDebugInfo();
LOG(ERROR) << desyncMsg;
LOG(ERROR) << logPrefix() << desyncMsg;
}

if (dumpOnTimeout_) {
// Store debug info to storage. (By default to local disk)
optAsyncDebugDump->wait_for(
std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30));
waitForDumpOrTimeout(*optAsyncDebugDump, wakeUpTime);
}

} catch (const std::exception& e) {
LOG(ERROR) << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
LOG(ERROR) << logPrefix()
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
<< " Please file an issue. Error: " << e.what();
} catch (...) {
LOG(ERROR)
<< logPrefix()
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
<< " Please file an issue.";
}
@@ -1455,7 +1606,7 @@ void ProcessGroupNCCL::watchdogHandler() {

// Clean up completed work
if (work.isCompleted()) {
NCCLTraceBuffer::get()->retire_id(work.trace_id_);
NCCLTraceBuffer::get()->retire_id(work.trace_id_, true);
if (onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@@ -1475,9 +1626,14 @@ void ProcessGroupNCCL::watchdogHandler() {
// completed.
++it;
}
// Increment heartbeat after each work processed,
// in case processing is slowed down (but not hung) by cuda api contention
heartbeat_++;
}
// process a request to dump the trace
if (dumpPipe.shouldDump()) {
// process a request to dump the trace. only PG uid 0 will respond to dump
// requests, but this is fine since all PG's feed into the same flight
// recorder and dump.
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
launchAsyncDebugDump();
}
done = workMetaList_.empty();
@@ -1523,8 +1679,8 @@ void ProcessGroupNCCL::runHookLoop() {
if (std::string(e.what()).find("driver shutting down") !=
std::string::npos) {
LOG(INFO)
<< "[Rank " << rank_
<< "] main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
<< logPrefix()
<< "main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
<< " (runHookLoop caught exception: " << e.what();

} else {
@@ -1698,7 +1854,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
if (bound_device_id_) {
for (const auto& device : devices) {
if (*bound_device_id_ != device) {
LOG(ERROR) << "Tensor found on device " << device
LOG(ERROR) << logPrefix() << "Tensor found on device " << device
<< " but backend constrained to " << *bound_device_id_;
C10_THROW_ERROR(
DistBackendError,
@@ -1850,9 +2006,8 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(

// At this point NCCL should have been initialized, hence we can accurately
// get the env value even if NCCL sets it by reading from nccl.conf file
if (getRank() == 0) {
LOG(INFO) << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");
}
LOG(INFO) << logPrefix()
<< "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");

// See [Group Start/End Note]
for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
@@ -2148,18 +2303,16 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
}

float ProcessGroupNCCL::WorkNCCL::getDuration() const {
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
TORCH_CHECK(
ncclStartEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
ncclStartEvents_,
"getDuration only works if ncclStartEvents_ is populated, true if timing enabled");
TORCH_CHECK(
ncclEndEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
(*ncclEndEvents_)[0].query(),
"getDuration can only be called after work is succeeded.")
return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
ncclEndEvents_,
"getDuration only works if ncclEndEvents_ is populated, which should always be true");
return getDurationFromFirstEvent(*ncclStartEvents_, *ncclEndEvents_);
}

uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
return seq_;
}
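Underneath getDuration, the measurement is plain CUDA event timing: an event recorded before the collective and one after it on the same stream, with cudaEventElapsedTime giving the kernel-side duration in milliseconds. A reduced sketch using the raw CUDA runtime API (the helper name and callback are illustrative):

```cpp
#include <cuda_runtime.h>

// ProcessGroupNCCL only records the start event when timing is enabled
// (TORCH_NCCL_ENABLE_TIMING), which is why getDuration checks timingEnabled_ first.
float timedLaunch(cudaStream_t stream, void (*enqueueWork)(cudaStream_t)) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  enqueueWork(stream);          // e.g. the NCCL collective kernels
  cudaEventRecord(stop, stream);

  cudaEventSynchronize(stop);   // getDuration instead requires the end event to query() true
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
```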
@@ -3420,14 +3573,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
// ensure that each process is on a different GPU
auto numGPUs = at::cuda::getNumGPUs();
int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs);
LOG(INFO) << c10::str(
"Rank ",
this->getRank(),
" using GPU ",
deviceIdx,
" to perform barrier as devices used by this process are currently unknown. ",
"This can potentially cause a hang if this rank to GPU mapping is incorrect.",
"Specify device_ids in barrier() to force use of a particular device.");
LOG(INFO)
<< logPrefix()
<< c10::str(
" using GPU ",
deviceIdx,
" to perform barrier as devices used by this process are currently unknown. ",
"This can potentially cause a hang if this rank to GPU mapping is incorrect.",
"Specify device_ids in barrier() to force use of a particular device.");
devices.emplace_back(guessDeviceForRank());
} else {
for (auto usedDeviceIdx : usedDeviceIdxs_) {
@@ -2,6 +2,7 @@

#ifdef USE_C10D_NCCL

#include <atomic>
#include <chrono>
#include <future>
#include <iostream>
@@ -12,6 +13,7 @@

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <ATen/DynamicLibrary.h>
@@ -26,49 +28,71 @@
#include <torch/custom_class.h>

namespace c10d {
// Environment variable which controls whether we perform a NCCL healt check
// Control whether we perform a NCCL health check or not
// which ensures communicators are healthy at the beginning of init.
static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = {
"TORCH_ENABLE_NCCL_HEALTH_CHECK",
"ENABLE_NCCL_HEALTH_CHECK"};

// Environment variable which controls whether or not wait() is blocking or
// non-blocking.
// Control whether or not wait() is blocking or non-blocking.
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
"TORCH_NCCL_BLOCKING_WAIT",
"NCCL_BLOCKING_WAIT"};

// Environment variable which controls whether or not we perform Async Error
// Handling with NCCL.
// Control whether or not we perform Async Error Handling with NCCL.
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
"TORCH_NCCL_ASYNC_ERROR_HANDLING",
"NCCL_ASYNC_ERROR_HANDLING"};

// Environment Variable to control whether dumping debug info on watchdog
// Control whether dumping debug info on watchdog
// timeout is enabled. This variable must be set together with
// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
"TORCH_NCCL_DUMP_ON_TIMEOUT"};

// Environment Variable to control whether Desync Debug is enabled.
// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
// Control whether Desync Debug is enabled. This variable must be set
// together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
"TORCH_NCCL_DESYNC_DEBUG",
"NCCL_DESYNC_DEBUG"};

// Enable recording start-events for all ProcessGroupNCCL collectives, and
// compute accurate collective timing per-collective. (Note: end-events are
// recorded by default. Turn on this flag can increase chances of a watchdog
// hang due to performing a CUDA event query which eventually calls
// cudaEventElapsedTime() API.
static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
"TORCH_NCCL_ENABLE_TIMING",
"NCCL_ENABLE_TIMING"};

// Enable monitoring thread which aborts the process when the ProcessGroupNCCL
// Watchdog thread gets stuck and no heartbeat is detected after
// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged
// time than necessary tying up cluster resources.
static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
"TORCH_NCCL_ENABLE_MONITORING"};

// Control the watchdog heartbeat timeout period after which the monitoring
// thread will abort the process.
static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
"TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};

// The maximum number of events we store in the flight recorder's ring buffer.
// (One event could be the start or end of a collective, for example).
static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
"TORCH_NCCL_TRACE_BUFFER_SIZE"};

// Control how much extra time we will wait for dumping the debugging info
// before we exit and throws timeout exception.
static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
"TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};

// Control the interval inside the watchdog thread to check the coordinated
// signal from other ranks, e.g. to dump the debugging information.
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
"TORCH_NCCL_COORD_CHECK_MILSEC"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";

constexpr const char* TIMEOUT_DUMP = "timeout_dump";
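Putting the knobs above together, one hypothetical way a launcher could enable the flight recorder and timeout dumps before constructing any process group; the specific values are examples only, not recommended defaults, and follow the constraints documented above (TORCH_NCCL_DUMP_ON_TIMEOUT needs monitoring enabled and a non-zero trace buffer).

```cpp
#include <cstdlib>

void enableNcclFlightRecorder() {
  setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "2000", /*overwrite=*/1);  // ring buffer entries (example size)
  setenv("TORCH_NCCL_DUMP_ON_TIMEOUT", "1", 1);                     // dump when a collective times out
  setenv("TORCH_NCCL_ENABLE_MONITORING", "1", 1);                   // heartbeat monitor thread
  setenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "600", 1);             // matches the new 10 min default
  setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", "/tmp/nccl_trace_rank_", 1);  // default dump file stem
  setenv("TORCH_NCCL_DEBUG_INFO_PIPE_FILE", "/tmp/nccl_trace_rank_", 1);  // opt-in on-demand dump pipe
}
```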
@ -205,6 +229,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
uint64_t getSequencenumber() const override;
|
||||
|
||||
const std::string& logPrefix() const;
|
||||
|
||||
// Helper function that sets an exception_ptr on the WorkNCCL object.
|
||||
void setException(std::exception_ptr exception_ptr);
|
||||
|
||||
@ -540,8 +566,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
void enableCollectivesTiming() override;
|
||||
|
||||
// Provide an API for users to define their own ways to store NCCL debug info.
|
||||
void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
|
||||
// Helper function for iteratively aborting communicators in the provided map
|
||||
void abortCommsFromMap(
|
||||
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
|
||||
ncclCommsMap,
|
||||
c10::optional<std::string> abortReason);
|
||||
|
||||
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
|
||||
// instead of relying on ProcessGroupNCCL destructor.
|
||||
@ -694,6 +723,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Desync debug helper
|
||||
void logWorkEnd(WorkNCCL& work);
|
||||
|
||||
// Generates a prefix that is unique to this process group and rank, for
|
||||
// disambiguating logs
|
||||
std::string createLogPrefix() const;
|
||||
|
||||
// Returns the unique prefix created in createLogPrefix
|
||||
const std::string& logPrefix() const;
|
||||
|
||||
// Returns the global rank of the device. This function assumes that users
|
||||
// always create a default global process group(PG) which includes all
|
||||
// devices. It is called in the constructor of ProcessGroupNCCL, so it always
|
||||
// return the rank_ of the the very first PG created, aka, default global PG.
|
||||
const int& globalRank() const;
|
||||
|
||||
protected:
// Function that runs as part of a separate thread aside from watchdog
// thread because we need to check the heartbeat from watchdog thread
@ -712,6 +754,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// for dump completion.
std::future<bool> launchAsyncDebugDump();

// Helper to wait up to the specified timeout and then abandon the dump.
// Logs on timeout, and asserts the future's status is as expected.
void waitForDumpOrTimeout(
std::future<bool>& fut,
const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
size_t timeout_sec = 30);

// When watchdog timeout, this function will be called and return debug info
// for users. For now we only get information from retrieveDesyncReport.
// We are working on enabling more useful debug information for watchdog
@ -720,9 +769,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {

static const int64_t kWatchdogThreadSleepMillis;

// The store is used to broadcast the NCCL unique ID of rank 0.
// The store is used to broadcast the NCCL unique ID of rank 0. This store
// comes with prefix and it is different across ProcessGroup NCCL instances
// (aka, different ProcessGroups).
c10::intrusive_ptr<Store> store_;

// Reference to the store without prefix so that keys are same across all
// ProcessGroup NCCL instances and (key, value) pairs written to the store are
// global.
c10::intrusive_ptr<Store> globalStore_;

bool storeError_{false};

const c10::intrusive_ptr<Options> options_;
@ -781,11 +837,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::mutex mutex_;

// Heartbeat of watchdog thread.
uint64_t heartbeat_;
std::atomic_uint64_t heartbeat_;

// The time interval used for deciding whether there is no watchdog heartbeat.
int heartbeatTimeoutInSec_;

// Extra time of sleep when waiting for timeout dump to finish.
int waitTimeoutDumpInMilSec_;

// Interval of check coordinated signals in ProcessGroupNCCL from other ranks
// e.g., trigger the dump of the debugging info for timeout when notified.
int coordCheckIntervalMilSec_;

// Size of ring buffer where we store NCCL Traces for debugging.
int ncclTraceBufferSize_;

@ -823,9 +886,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {

bool writeDebugInfo_ = false;

// Mutex to Guard the check of writeDebugInfo_
std::mutex writeDebugInfoMutex_;

// Condition Variable for watchdog thread sleep
std::condition_variable workMetaListCV_;

@ -929,14 +989,25 @@ class TORCH_API ProcessGroupNCCL : public Backend {

std::exception_ptr watchDogException_ = nullptr;

// The callback function to store NCCL debug info.
std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;

size_t uid_;

std::string logPrefix_;
};

TORCH_API std::string dump_nccl_trace();

// Gets a mutable reference to a global optional function. Heartbeat Monitor
// will query this function and if available, call it to dump traces. Inside
// fbcode, we store a function here that uses an internal tool for process
// tracing
TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper();

// Similar to get_cpp_trace_dumper, this stores a function defined in
// torch-python layer that lets us check whether the GIL can be acquired,
// helpful for instrumenting in cases where a hang was observed.
typedef bool (*gil_checker_t)();

TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d

#endif // USE_C10D_NCCL

@ -269,10 +269,20 @@ inline std::string retrieveDesyncReport(

#ifdef USE_C10D_NCCL

DebugInfoWriter::DebugInfoWriter(int rank) {
std::string fileName = getCvarString(
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
filename_ = c10::str(fileName, rank);
/* Helper used by work::getDuration() and nccl flight recorder */
float getDurationFromFirstEvent(
const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
TORCH_CHECK(
ncclStartEvents.size() == 1,
"getDuration only works for single device per ProcessGroup, but found multiple start events.");
TORCH_CHECK(
ncclEndEvents.size() == 1,
"getDuration only works for single device per ProcessGroup, but found multiple end events.");
TORCH_CHECK(
ncclEndEvents[0].query(),
"getDuration can only be called after work is succeeded.")
return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
}
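
The duration this helper returns is simply the elapsed time between the recorded start and end CUDA events. For readers who know the Python side better, the sketch below shows the same start/end-event pattern using torch.cuda.Event directly; it illustrates the mechanism, it is not the API this diff adds.

```python
# Sketch: timing a GPU operation with paired CUDA events, the same pattern
# getDurationFromFirstEvent() relies on. Illustrative only.
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")
start.record()
y = x @ x                 # stand-in for the NCCL kernel being timed
end.record()

torch.cuda.synchronize()  # the end event must have completed before querying
print(f"duration: {start.elapsed_time(end):.3f} ms")
```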

DebugInfoWriter::~DebugInfoWriter() = default;
@ -293,6 +303,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) {
LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
}

DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
if (writer_ == nullptr) {
std::string fileNamePrefix = getCvarString(
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
// Using std::unique_ptr here to auto-delete the writer object
// when the pointer itself is destroyed.
std::unique_ptr<DebugInfoWriter> writerPtr(
new DebugInfoWriter(fileNamePrefix, rank));
DebugInfoWriter::registerWriter(std::move(writerPtr));
}
return *writer_;
}

void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
TORCH_CHECK_WITH(
DistBackendError,
hasWriterRegistered_.load() == false,
"debugInfoWriter already registered");
hasWriterRegistered_.store(true);
writer_ = std::move(writer);
}

std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);

inline std::string pickle_str(const c10::IValue& v) {
std::vector<char> result;
{
@ -340,7 +375,7 @@ struct NCCLTraceBuffer {
const char* profiling_name_;

std::shared_ptr<torch::CapturedTraceback> traceback_;
// we borrow pointser to start_ and end_ so we can query the state
// we borrow pointers to start_ and end_ so we can query the state
// on reporting. However, once the event is completed, the call
// to `complete` will clear these.
EventList *start_, *end_;
@ -348,6 +383,7 @@ struct NCCLTraceBuffer {
// timestamp when the entry was created, likely close to the time the work
// was 'enqueued'- not necessarily started
c10::time_t time_created_;
c10::optional<float> duration_;

const char* state_ = "scheduled";

@ -386,7 +422,7 @@ struct NCCLTraceBuffer {
id_,
pg_id,
seq_id,
profiling_name,
profiling_name == nullptr ? "" : profiling_name,
std::move(traceback),
std::move(start),
std::move(end),
@ -456,35 +492,90 @@ struct NCCLTraceBuffer {
return result;
}

void retire_id(c10::optional<size_t> id) {
/*
Mark an Event as completed and free its events.

This is called by the watchdog thread, and is asynchronous from the
perspective of the main thread.

compute_duration defaults to true since retire_id is only called in the
watchdog thread, which is currently a place we call cuda APIs which may hang,
but care should be taken to avoid computing duration in any function that must
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
void retire_id(c10::optional<size_t> id, bool compute_duration = true) {
if (!enabled_ || !id) {
return;
}
std::lock_guard<std::mutex> guard(mutex_);

bool can_compute_duration = false;
EventList* startEvents = nullptr;
EventList* endEvents = nullptr;
c10::optional<float> duration = c10::nullopt;

std::unique_lock<std::mutex> guard(mutex_);

auto& entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
update_state(entry);
entry.retired_ = true;
entry.start_ = entry.end_ = nullptr;

if (compute_duration) {
can_compute_duration = strcmp(entry.state_, "completed") == 0 &&
entry.start_ && entry.end_;
startEvents = entry.start_;
endEvents = entry.end_;
}
}

if (can_compute_duration) {
// Compute duration without holding the lock, because
// cudaEventDuration() can hang, and we need to acquire the lock before we
// can dump(), which we never want to block.
guard.unlock();
duration = getDurationFromFirstEvent(*startEvents, *endEvents);
guard.lock();

// Refresh the entry ref, see if it has been overwritten
entry = entries_.at(*id % max_entries_);
if (entry.id_ != *id) {
LOG(INFO)
<< "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
}
if (duration.has_value()) {
entry.duration_ = duration.value();
}
}

entry.retired_ = true;
entry.start_ = entry.end_ = nullptr;
}
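
The comment above captures the key design point: the potentially hanging CUDA duration query runs with the buffer lock released, and the entry is re-validated after the lock is re-acquired because the ring buffer may have wrapped in the meantime. Below is a generic Python sketch of that drop-lock / compute / revalidate pattern, purely as an illustration of the idea; the names are invented and nothing here is part of this diff.

```python
# Sketch: do slow work outside the lock, then revalidate before publishing.
# Generic illustration of the pattern used by retire_id(); names are made up.
import threading

lock = threading.Lock()
ring = [{"id": i, "duration": None} for i in range(8)]  # toy ring buffer

def slow_measurement():
    # stand-in for a call that may block for a long time (e.g. a CUDA event query)
    return 1.23

def retire(entry_id: int):
    with lock:
        slot = ring[entry_id % len(ring)]
        wants_duration = slot["id"] == entry_id

    duration = slow_measurement() if wants_duration else None  # lock NOT held here

    with lock:
        slot = ring[entry_id % len(ring)]
        if slot["id"] != entry_id:
            return  # slot was overwritten while we were measuring; abandon
        if duration is not None:
            slot["duration"] = duration
```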

std::string dump() {
auto result = dump_entries();
auto entries = new_list();
c10::IValue pg_id_s = "pg_id";
c10::IValue seq_id_s = "seq_id";
c10::IValue profiling_name_s = "profiling_name";
c10::IValue input_sizes_s = "input_sizes";
c10::IValue output_sizes_s = "output_sizes";
c10::IValue time_created_s = "time_created_us";
c10::IValue entries_key = "entries";
c10::IValue version_key = "version";
// Update whenever changing contents or formatting of the dump
// (minor when adding fields, major when changing existing fields)
c10::IValue version_val = "1.0";

c10::IValue frames_s = "frames";
c10::IValue state_s = "state";
c10::IValue line_s = "line";
c10::IValue name_s = "name";
c10::IValue filename_s = "filename";
c10::IValue retired_s = "retired";
c10::IValue pg_id_key = "pg_id";
c10::IValue seq_id_key = "seq_id";
c10::IValue profiling_name_key = "profiling_name";
c10::IValue input_sizes_key = "input_sizes";
c10::IValue output_sizes_key = "output_sizes";
c10::IValue time_created_key = "time_created_ns";
c10::IValue duration_key = "duration_ms";

c10::IValue frames_key = "frames";
c10::IValue state_key = "state";
c10::IValue line_key = "line";
c10::IValue name_key = "name";
c10::IValue filename_key = "filename";
c10::IValue retired_key = "retired";

std::vector<torch::CapturedTraceback*> tracebacks;
for (auto& e : result) {
@ -494,9 +585,9 @@ struct NCCLTraceBuffer {
std::vector<c10::IValue> all_frames;
for (const auto& f : stracebacks.all_frames) {
auto d = new_dict();
d.insert(name_s, f.funcname);
d.insert(filename_s, f.filename);
d.insert(line_s, int64_t(f.lineno));
d.insert(name_key, f.funcname);
d.insert(filename_key, f.filename);
d.insert(line_key, int64_t(f.lineno));
all_frames.emplace_back(std::move(d));
}

@ -504,10 +595,13 @@ struct NCCLTraceBuffer {
auto& e = result.at(i);
auto& tb = stracebacks.tracebacks.at(i);
auto dict = new_dict();
dict.insert(pg_id_s, int64_t(e.pg_id_));
dict.insert(seq_id_s, int64_t(e.seq_id_));
dict.insert(profiling_name_s, e.profiling_name_);
dict.insert(time_created_s, int64_t(e.time_created_ / 1000));
dict.insert(pg_id_key, int64_t(e.pg_id_));
dict.insert(seq_id_key, int64_t(e.seq_id_));
dict.insert(profiling_name_key, e.profiling_name_);
dict.insert(time_created_key, int64_t(e.time_created_));
if (e.duration_) {
dict.insert(duration_key, *e.duration_);
}

auto it = e.sizes_.begin();
auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
@ -523,19 +617,24 @@ struct NCCLTraceBuffer {
return sizes;
};

dict.insert(input_sizes_s, read_sizes(e.input_dims_));
dict.insert(output_sizes_s, read_sizes(e.output_dims_));
dict.insert(state_s, e.state_);
dict.insert(retired_s, e.retired_);
dict.insert(input_sizes_key, read_sizes(e.input_dims_));
dict.insert(output_sizes_key, read_sizes(e.output_dims_));
dict.insert(state_key, e.state_);
dict.insert(retired_key, e.retired_);

auto frames = new_list();
for (int64_t frame : tb) {
frames.push_back(all_frames.at(frame));
}
dict.insert(frames_s, frames);
dict.insert(frames_key, frames);
entries.push_back(dict);
}
return pickle_str(entries);

auto dict = new_dict();
dict.insert(entries_key, entries);
dict.insert(version_key, version_val);

return pickle_str(dict);
}
};
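
When TORCH_NCCL_DUMP_ON_TIMEOUT fires, this pickled dictionary is what the registered DebugInfoWriter receives, and the default writer stores it under the TORCH_NCCL_DEBUG_INFO_TEMP_FILE prefix plus the rank (/tmp/nccl_trace_rank_<rank> by default, per the code above). A rough post-mortem reader, assuming that default location, the version-1.0 layout shown in this hunk, and that the payload unpickles with Python's standard pickle module, might look like:

```python
# Sketch: inspect a flight-recorder dump written by the default DebugInfoWriter.
# Assumes the default /tmp/nccl_trace_rank_<rank> location and the "1.0" layout.
import pickle
import sys

rank = sys.argv[1] if len(sys.argv) > 1 else "0"
with open(f"/tmp/nccl_trace_rank_{rank}", "rb") as f:
    dump = pickle.loads(f.read())

print("dump version:", dump["version"])
for e in dump["entries"]:
    dur = e.get("duration_ms")
    print(
        f'pg={e["pg_id"]} seq={e["seq_id"]} {e["profiling_name"]} '
        f'state={e["state"]} retired={e["retired"]} '
        f'duration_ms={dur if dur is not None else "n/a"}'
    )
    # Each entry also carries "frames" (the captured Python stack as a list of
    # {"name", "filename", "line"} dicts) plus input_sizes / output_sizes.
```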

@ -50,6 +50,35 @@

namespace {

#ifdef USE_C10D_NCCL

bool acquire_gil() {
// basically if this function can acquire the gil, it will return quickly.
// if not, it will hang forever. The idea is to call this from a thread
// wrapped in a future, and then check the future after a timeout, to
// determine whether we're facing gil contention.
if (Py_IsInitialized()) {
pybind11::gil_scoped_acquire gil;
return true;
}

// If we end up here, it's probably still a "pass" from the perspective of
// checking whether python is stuck. but currently we don't check the return
// value of this function anyway, just check whether it returned quickly vs
// timing out. Taking a long time is the main sign of trouble. Fast return
// with true or with false is both OK from the perspective of debugging python
// hangs.
return false;
}

bool registerGilChecker() {
c10d::get_gil_checker() = &acquire_gil;
return true;
}

static bool registered = registerGilChecker();
#endif // USE_C10D_NCCL
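
The comments above spell out the intended use: run acquire_gil() in a side thread wrapped in a future, and judge health by whether the future completes within a timeout rather than by its return value. The Python sketch below illustrates that probe-plus-timeout pattern in general terms; it is not the C++ checker this diff wires up, and the names in it are invented.

```python
# Sketch: run a probe in a worker thread and judge health by how fast it returns.
# Illustrates the future-plus-timeout idea behind the GIL checker; names invented.
import concurrent.futures
import threading

some_lock = threading.Lock()  # stand-in for the GIL or any contended resource

def probe() -> bool:
    # Returns quickly when the resource is available, blocks when it is not.
    with some_lock:
        return True

pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
fut = pool.submit(probe)
try:
    fut.result(timeout=5)   # fast return: probably healthy
    print("probe returned quickly")
except concurrent.futures.TimeoutError:
    print("probe timed out: likely contention or a hang")
finally:
    pool.shutdown(wait=False)
```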

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
@ -1033,6 +1062,29 @@ Example::
>>> store.add("first_key", 6)
>>> # Should return 7
>>> store.get("first_key")
)")
.def(
"check",
&::c10d::Store::check,
py::call_guard<py::gil_scoped_release>(),
R"(
The call to check whether a given list of ``keys`` have value stored in
the store. This call immediately returns in normal cases but still suffers
from some edge deadlock cases, e.g., calling check after TCPStore has been destroyed.
Calling :meth:`~torch.distributed.store.check` with a list of keys that
one wants to check whether stored in the store or not.

Arguments:
keys (list[str]): The keys to query whether stored in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> # Should return True
>>> store.check(["first_key"])
)")
.def(
"delete_key",
@ -1404,7 +1456,11 @@ Arguments:
.def_property_readonly(
"underlying_store",
&::c10d::PrefixStore::getUnderlyingStore,
R"(Gets the underlying store object that PrefixStore wraps around.)");
R"(Gets the underlying store object that PrefixStore wraps around.)")
.def_property_readonly(
"_underlying_non_prefix_store",
&::c10d::PrefixStore::getUnderlyingNonPrefixStore,
R"(Recursively to get the store before layers of wrapping with PrefixStore.)");

auto processGroup =
py::class_<