[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style (#119421)

Part 2 and last part of #118674:
Introduce actual "single-device" code change to ProcessGroupNCCL.

assert size == 1 and test refactor have been done in #119099.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119421
Approved by: https://github.com/shuqiangzhang
This commit is contained in:
Ke Wen
2024-02-09 20:23:16 +00:00
committed by PyTorch MergeBot
parent 0597dab523
commit f3e7d80993
8 changed files with 831 additions and 1129 deletions

View File

@ -126,37 +126,32 @@
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
ncclResult_t state = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
if (state == ncclInProgress) { \
for (const auto i : c10::irange(comms_.size())) { \
do { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
currentTimepoint - startTimepoint) \
.count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
":" + std::to_string(__LINE__) + ", " + \
ncclGetErrorWithVersion(state) + "\n" + \
getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} \
ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
} while (state == ncclInProgress); \
if (state != ncclSuccess) { \
break; /* fall through to failed case */ \
} \
} \
} \
if (state != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
ncclResult_t state = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
if (state == ncclInProgress) { \
do { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
currentTimepoint - startTimepoint) \
.count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
":" + std::to_string(__LINE__) + ", " + \
ncclGetErrorWithVersion(state) + "\n" + \
getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} \
ncclCommGetAsyncError(comm->getNcclComm(), &state); \
} while (state == ncclInProgress); \
} \
if (state != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
// Macro to print and abort on a non-successful NCCL return value.