mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PGNCCL] Use non-blocking mode by default in eager init (#138527)
### Why use non-blocking mode in eager init? For overlapping comm init and model init, etc.  ### Why can we set non-blocking as default? If the setting is dangling -- i.e. not passed in by user nor set via env -- `ProcessGroupNCCL` can have some preferred logic. And torch-level API semantics does not change whether the NCCL comm is blocking or non-blocking (handled within `ProcessGroupNCCL`). ### Why not make non-blocking default for lazy mode as well? PR https://github.com/pytorch/pytorch/pull/137544 tried it. Two reasons why that's not preferred today: 1. It is hard -- too big a blast. 2. There is no gain by doing lazy init in non-blocking mode, because the right next CPU call is a collective, and we will block there waiting for comm to be ready, so same effect as blocked init, no "opening" compared to eager mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138527 Approved by: https://github.com/wconstab ghstack dependencies: #138860
This commit is contained in:
@ -236,7 +236,6 @@ DEFINE_CONSTANT(started_state, "started");
|
||||
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
|
||||
TORCH_API std::string getNcclVersion();
|
||||
TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
|
||||
bool nccl_use_nonblocking();
|
||||
int nccl_nonblocking_timeout();
|
||||
|
||||
// Provides additional detail into NCCL error codes based on when these are
|
||||
@ -308,6 +307,8 @@ class NCCLComm {
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = true;
|
||||
// Old style comm is always blocking.
|
||||
comm->nonBlocking_ = false;
|
||||
return comm;
|
||||
}
|
||||
|
||||
@ -318,26 +319,19 @@ class NCCLComm {
|
||||
ncclUniqueId commId,
|
||||
ncclConfig_t& config) {
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
bool isInitialized = false;
|
||||
if (nccl_use_nonblocking()) {
|
||||
config.blocking = 0;
|
||||
LOG(INFO) << "Rank " << rank
|
||||
<< ": creating NCCL communicator in nonblocking mode";
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
} else {
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
// under blocking mode, comm is initialized after NCCL CHECK
|
||||
isInitialized = true;
|
||||
}
|
||||
comm->nonBlocking_ = config.blocking == 0;
|
||||
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
|
||||
<< (comm->nonBlocking_ ? "nonblocking" : "blocking");
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = isInitialized;
|
||||
// Under blocking mode, comm is initialized immediately after NCCL init
|
||||
// returns; Under nonblocking mode, we check whether comm is initialized the
|
||||
// *next* time ncclComm_ is accessed.
|
||||
comm->initialized_ = !comm->nonBlocking_;
|
||||
return comm;
|
||||
}
|
||||
|
||||
@ -382,6 +376,7 @@ class NCCLComm {
|
||||
std::swap(aborted_, other.aborted_);
|
||||
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
||||
std::swap(initialized_, other.initialized_);
|
||||
std::swap(nonBlocking_, other.nonBlocking_);
|
||||
}
|
||||
|
||||
ncclComm_t getNcclComm();
|
||||
@ -550,6 +545,10 @@ class NCCLComm {
|
||||
// better error messaging.
|
||||
std::optional<std::string> commFailureReason_{};
|
||||
bool initialized_{false};
|
||||
// Whether this communicator is using nonblocking mode. Recorded during comm
|
||||
// creation or split. For safety, we give a default value of true (more
|
||||
// protection).
|
||||
bool nonBlocking_{true};
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
// Stores handlers for tensors registered by NCCL
|
||||
std::unordered_map<void*, void*> registeredSegmentHandles_;
|
||||
|
Reference in New Issue
Block a user