[PGNCCL] Add default value for nccl_nonblocking_timeout (#138374)

- Added default value for `nccl_nonblocking_timeout` (30 mins, previous: -1).
- Reuse C10D_CHECK_TIMEOUT in other CHECK macros

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138374
Approved by: https://github.com/eqy
ghstack dependencies: #137855, #138488
This commit is contained in:
Ke Wen
2024-10-21 11:15:31 -07:00
committed by PyTorch MergeBot
parent 03c72976a5
commit 6b29d40e9b
5 changed files with 73 additions and 100 deletions

View File

@ -11,6 +11,7 @@
#include <nccl.h>
#include <sched.h>
#include <limits>
#include <sstream>
#include <type_traits>
@ -146,6 +147,7 @@ static inline void NCCL_CHECK(ncclResult_t result) {
}
// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
// Default value: on
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
@ -155,40 +157,34 @@ bool nccl_use_nonblocking() {
return nccl_use_nonblocking_;
}
static int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
// Default value: 30 minutes
static int nccl_nonblocking_timeout() {
static int timeout = -2; // -2 means not initialized
if (timeout == -2) {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
if (val && strlen(val) > 0) {
timeout = strtol(val, nullptr, 0);
} else {
// Default value consistent with kBackendDefaultTimeout
timeout = 30 * 60;
}
}
return timeout;
}
static int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclResult_t result = to_nccl_result(status);
auto startTimepoint = std::chrono::steady_clock::now();
while (result == ncclInProgress) {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
sched_yield(); // yield to other threads
ncclCommGetAsyncError(to_nccl_comm(comm), &result);
}
if (result != ncclSuccess) {
@ -213,15 +209,14 @@ static inline void NCCL_CHECK_TIMEOUT(
if (result == ncclInProgress) {
for (const auto i : c10::irange(comms.size())) {
do {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
sched_yield(); // yield to other threads
ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
} while (result == ncclInProgress);
if (result != ncclSuccess) {

View File

@ -95,8 +95,7 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
// comm ptr is valid. Therefore we add a manual wait here for safety.
// TODO: remove this wait after NCCL fix the semantics.
auto startTime = std::chrono::steady_clock::now();
auto timeout =
nccl_nonblocking_timeout() > 0 ? nccl_nonblocking_timeout() : 30 * 60;
auto timeout = nccl_nonblocking_timeout();
while (!comm->ncclComm_) {
C10D_CHECK_TIMEOUT(startTime, timeout);
C10D_SCHED_SLEEP();
@ -179,26 +178,21 @@ bool nccl_use_nonblocking() {
return nccl_use_nonblocking_;
}
int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
// Default value: 30 minutes
int nccl_nonblocking_timeout() {
static int timeout = -2; // -2 means not initialized
if (timeout == -2) {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
if (val && strlen(val) > 0) {
timeout = strtol(val, nullptr, 0);
} else {
// Default value consistent with kBackendDefaultTimeout
timeout = 30 * 60;
}
}
return timeout;
}
int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
std::string ncclGetErrorWithVersion(ncclResult_t error) {
return std::string(ncclGetErrorString(error)) + ", NCCL version " +
getNcclVersion();

View File

@ -120,32 +120,24 @@ constexpr int64_t kCommInitBusyWaitMillis = 2;
// Macro to throw on a non-successful NCCL return value, non-blocking.
#define C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, yield_fn) \
ncclResult_t result = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
while (result == ncclInProgress) { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
currentTimepoint - startTimepoint) \
.count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + \
ncclGetErrorWithVersion(result) + "\n" + \
getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
do { \
ncclResult_t result = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
auto timeout = nccl_nonblocking_timeout(); \
while (result == ncclInProgress) { \
C10D_CHECK_TIMEOUT(startTimepoint, timeout); \
yield_fn; \
ncclCommGetAsyncError(comm, &result); \
} \
yield_fn; \
ncclCommGetAsyncError(comm, &result); \
} \
if (result != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
if (result != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} while (0)
// Sleep for kCommInitBusyWaitMillis milliseconds.
#define C10D_SCHED_SLEEP() \
std::this_thread::sleep_for( \
std::chrono::milliseconds(kCommInitBusyWaitMillis))
@ -165,33 +157,24 @@ constexpr int64_t kCommInitBusyWaitMillis = 2;
C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, C10D_SCHED_SLEEP())
#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
ncclResult_t state = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
if (state == ncclInProgress) { \
do { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
currentTimepoint - startTimepoint) \
.count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
":" + std::to_string(__LINE__) + ", " + \
ncclGetErrorWithVersion(state) + "\n" + \
getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} \
sched_yield(); \
ncclCommGetAsyncError(comm->getNcclComm(), &state); \
} while (state == ncclInProgress); \
} \
if (state != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
do { \
ncclResult_t state = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
auto timeout = nccl_nonblocking_timeout(); \
if (state == ncclInProgress) { \
do { \
C10D_CHECK_TIMEOUT(startTimepoint, timeout); \
sched_yield(); \
ncclCommGetAsyncError(comm->getNcclComm(), &state); \
} while (state == ncclInProgress); \
} \
if (state != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} while (0)
// Macro to print and abort on a non-successful NCCL return value.
#define C10D_NCCL_ASSERT(cmd) \

View File

@ -990,6 +990,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
<< ", TORCH_NCCL_USE_COMM_NONBLOCKING: " << nccl_use_nonblocking()
#ifdef NCCL_HAS_COMM_REGISTER
<< ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
<< useTensorRegisterAllocatorHook_

View File

@ -499,7 +499,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// * NCCL_SPLIT_NOCOLOR (-1): not in group;
// * NCCL_SPLIT_NOCOLOR - 1: uninitialized.
// [Note 1]: the type must be `int` instead of `int64_t` because NCCL API
// accepts int. Otherwise, an imlicit conversion may happen at the API call
// accepts int. Otherwise, an implicit conversion may happen at the API call
// and the value may become negative.
// [Note 2]: this member is pybinded to Python, the value passed from Python
// must be within the numerical range of C++ int. Otherwise, Python will