Revert "[PGNCCL] Use non-blocking mode by default in eager init (#138527)"

This reverts commit 8fbf866904661b16cba4c799af81121557ba9da8.

Reverted https://github.com/pytorch/pytorch/pull/138527 on behalf of https://github.com/jeanschmidt because it seems to have introduced regressions on main in pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 2, 3, linux.g4dn.12xlarge.nvidia.gpu); reverting to check whether that fixes it ([comment](https://github.com/pytorch/pytorch/pull/138527#issuecomment-2432479338))
Author: PyTorch MergeBot
Date:   2024-10-23 14:49:49 +00:00
Parent: 2f007e5de5
Commit: cdfe1bffd1

4 changed files with 34 additions and 73 deletions

torch/csrc/distributed/c10d/NCCLUtils.cpp

@@ -31,7 +31,7 @@ ncclComm_t NCCLComm::getNcclComm() {
         commFailureMsg));
   }
   // In non-blocking mode, ensure comm is ready.
-  if (nonBlocking_) {
+  if (nccl_use_nonblocking()) {
     // If timeout is reached, throw an exception.
     C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt);
     // ncclComm_ should be initialized by now
@@ -101,7 +101,6 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
 #endif
   ++source->ncclCommSplitCounter_;
   comm->rank_ = rank;
-  comm->nonBlocking_ = config.blocking == 0;
   LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
             << comm->repr() << " with color_id " << color_id;
   return comm;
@@ -164,6 +163,15 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
 }
 #endif
 
+bool nccl_use_nonblocking() {
+  static bool nccl_use_nonblocking_ =
+      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
+  if (nccl_use_nonblocking_) {
+    TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator.");
+  }
+  return nccl_use_nonblocking_;
+}
+
 // Default value: 30 minutes
 int nccl_nonblocking_timeout() {
   static int timeout = -2; // -2 means not initialized
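
Note on the restored nccl_use_nonblocking() above: the environment lookup is memoized in a function-local static, so TORCH_NCCL_USE_COMM_NONBLOCKING is parsed once per process and the experimental-mode warning fires at most once. Below is a minimal standalone sketch of that pattern; it uses std::getenv with a plain "1" check in place of PyTorch's c10::utils::check_env (which accepts several spellings), and the function name is illustrative.

#include <cstdlib>
#include <cstring>
#include <iostream>

bool use_nonblocking() {
  // The lambda runs exactly once per process; later calls return the cache.
  static bool cached = [] {
    const char* v = std::getenv("TORCH_NCCL_USE_COMM_NONBLOCKING");
    bool on = v != nullptr && std::strcmp(v, "1") == 0;
    if (on) {
      // Stands in for TORCH_WARN_ONCE in the real code.
      std::cerr << "Using experimental non-blocking NCCL communicator.\n";
    }
    return on;
  }();
  return cached;
}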

torch/csrc/distributed/c10d/NCCLUtils.hpp

@@ -237,6 +237,7 @@ DEFINE_CONSTANT(started_state, "started");
 TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
 TORCH_API std::string getNcclVersion();
 TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
+bool nccl_use_nonblocking();
 int nccl_nonblocking_timeout();
 
 // Provides additional detail into NCCL error codes based on when these are
@@ -313,8 +314,6 @@ class NCCLComm {
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     comm->initialized_ = true;
-    // Old style comm is always blocking.
-    comm->nonBlocking_ = false;
     return comm;
   }
 
@@ -325,19 +324,26 @@
       ncclUniqueId commId,
       ncclConfig_t& config) {
     auto comm = std::make_shared<NCCLComm>();
-    comm->nonBlocking_ = config.blocking == 0;
-    LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
-              << (comm->nonBlocking_ ? "nonblocking" : "blocking");
-    C10D_NCCL_CHECK_NONBLOCKING(
-        ncclCommInitRankConfig(
-            &(comm->ncclComm_), numRanks, commId, rank, &config),
-        std::nullopt);
+    bool isInitialized = false;
+    if (nccl_use_nonblocking()) {
+      config.blocking = 0;
+      LOG(INFO) << "Rank " << rank
+                << ": creating NCCL communicator in nonblocking mode";
+      C10D_NCCL_CHECK_NONBLOCKING(
+          ncclCommInitRankConfig(
+              &(comm->ncclComm_), numRanks, commId, rank, &config),
+          std::nullopt);
+    } else {
+      C10D_NCCL_CHECK(
+          ncclCommInitRankConfig(
+              &(comm->ncclComm_), numRanks, commId, rank, &config),
+          std::nullopt);
+      // under blocking mode, comm is initialized after NCCL CHECK
+      isInitialized = true;
+    }
     comm->ncclId_ = commId;
     comm->rank_ = rank;
-    // Under blocking mode, comm is initialized immediately after NCCL init
-    // returns; Under nonblocking mode, we check whether comm is initialized the
-    // *next* time ncclComm_ is accessed.
-    comm->initialized_ = !comm->nonBlocking_;
+    comm->initialized_ = isInitialized;
     return comm;
   }
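
For readers unfamiliar with NCCL's non-blocking mode, the restored branches above map onto the raw NCCL API roughly as sketched below, assuming NCCL >= 2.14 (which provides ncclConfig_t and ncclCommGetAsyncError). The initNonblocking name and the assert-based error handling are illustrative; the C10D_NCCL_CHECK_* macros in the real code add timeouts and detailed error messages.

#include <cassert>
#include <nccl.h>

ncclComm_t initNonblocking(int numRanks, int rank, ncclUniqueId id) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0; // request a non-blocking communicator

  ncclComm_t comm = nullptr;
  ncclResult_t res =
      ncclCommInitRankConfig(&comm, numRanks, id, rank, &config);
  // In non-blocking mode the call may return ncclInProgress instead of
  // ncclSuccess; initialization then completes asynchronously.
  assert(res == ncclSuccess || res == ncclInProgress);

  // Poll the async state until it leaves ncclInProgress. Production code
  // bounds this loop with a timeout (cf. C10D_NCCL_CHECK_TIMEOUT_SLEEP).
  ncclResult_t state = ncclInProgress;
  do {
    ncclCommGetAsyncError(comm, &state);
  } while (state == ncclInProgress);
  assert(state == ncclSuccess);
  return comm;
}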
@@ -381,7 +387,6 @@ class NCCLComm {
     std::swap(aborted_, other.aborted_);
     std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
     std::swap(initialized_, other.initialized_);
-    std::swap(nonBlocking_, other.nonBlocking_);
   }
 
   ncclComm_t getNcclComm();
@@ -545,10 +550,6 @@
   // better error messaging.
   std::optional<std::string> commFailureReason_;
   bool initialized_{false};
-  // Whether this communicator is using nonblocking mode. Recorded during comm
-  // creation or split. For safety, we give a default value of true (more
-  // protection).
-  bool nonBlocking_{true};
 #ifdef NCCL_HAS_COMM_REGISTER
   // Stores handlers for tensors registered by NCCL
   std::unordered_map<void*, void*> registeredSegmentHandles_;

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

@@ -990,6 +990,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
             << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
             << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
+            << ", TORCH_NCCL_USE_COMM_NONBLOCKING: " << nccl_use_nonblocking()
 #ifdef NCCL_HAS_COMM_REGISTER
             << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
             << useTensorRegisterAllocatorHook_
@@ -1061,41 +1062,6 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
   getNCCLComm(key, device, OpType::ALLREDUCE);
 }
 
-bool ProcessGroupNCCL::useNonblocking() {
-#ifndef NCCL_HAS_COMM_NONBLOCKING
-  return false;
-#endif
-  // Already parsed, return the cached value
-  if (useNonblocking_.has_value()) {
-    return useNonblocking_.value();
-  }
-  // Get environment variable.
-  auto nbEnv = c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING");
-  // 1st priority: Respect the user's setting
-  if (options_->config.blocking != NCCL_CONFIG_UNDEF_INT) {
-    useNonblocking_ = options_->config.blocking == 0;
-    goto print_and_return;
-  }
-  // 2nd priority: Respect the environment variable
-  if (nbEnv.has_value()) {
-    useNonblocking_ = nbEnv.value();
-    goto print_and_return;
-  }
-  // 3rd priority: automatically use nonblocking if we are in eager init mode
-  if (getBoundDeviceId()) {
-    useNonblocking_ = true;
-    goto print_and_return;
-  }
-  // 4th priority: otherwise, nonblocking = false to preserve old behavior
-  useNonblocking_ = false;
-
-print_and_return:
-  LOG(INFO) << logPrefix()
-            << "Using non-blocking mode: " << useNonblocking_.value();
-  return useNonblocking_.value();
-}
-
 void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
   // If our backend doesn't support splitting, this is a no-op for
   // ranks not in the new subgroup (and ranks that would be in it will
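
The helper removed above resolves the communicator mode from four sources in strict priority order, using goto so every branch exits through a single LOG statement. For clarity, here is an equivalent goto-free restatement of the ladder as a standalone sketch; the function and parameter names are illustrative, and undefSentinel stands in for NCCL_CONFIG_UNDEF_INT.

#include <optional>

bool resolveNonblocking(
    int configBlocking, // options_->config.blocking in the real code
    std::optional<bool> envFlag, // TORCH_NCCL_USE_COMM_NONBLOCKING, if set
    bool hasBoundDeviceId, // true when eager init has bound a device
    int undefSentinel) {
  if (configBlocking != undefSentinel) {
    return configBlocking == 0; // 1st priority: explicit user config
  }
  if (envFlag.has_value()) {
    return envFlag.value(); // 2nd priority: environment variable
  }
  if (hasBoundDeviceId) {
    return true; // 3rd priority: eager init defaults to non-blocking
  }
  return false; // 4th priority: preserve the old blocking behavior
}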
@@ -1104,8 +1070,6 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
   const auto key = getKeyFromDevice(device);
   LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
             << device << ", key " << key << ", i am " << this;
-  bool useNb = useNonblocking();
-  options_->config.blocking = useNb ? 0 : 1;
   auto comm = getNCCLComm(key, device, OpType::ALLREDUCE);
   NCCLComm::split(
       comm.get(),
@@ -2361,11 +2325,6 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
     rank = p2pRank;
   }
 
-#ifdef NCCL_HAS_COMM_NONBLOCKING
-  bool useNb = useNonblocking();
-  options_->config.blocking = useNb ? 0 : 1;
-#endif
-
 #ifdef NCCL_HAS_COMM_SPLIT
   if (options_->split_from) {
     // Find a valid, healthy communicator to split from if possible.
@@ -2782,7 +2741,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
     work->ncclStartEvent_->record(ncclStream);
   }
 
-  if (useNonblocking()) {
+  if (nccl_use_nonblocking()) {
     groupEndNonblocking(comm);
   } else {
     groupEnd();
@@ -3102,7 +3061,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
 #endif
   {
-    torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
+    torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
+        comm, nccl_use_nonblocking());
     for (const auto i : c10::irange(inputs.size())) {
       // Both `inputs' and `outputs' are created on a worker stream and used in
       // different ncclStreams. Hence, both must record the ncclStream to
@@ -4654,7 +4614,7 @@ void ProcessGroupNCCL::groupEndNonblocking(
 #ifndef NCCL_HAS_COMM_NONBLOCKING
   C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
 #else
-  if (!useNonblocking()) {
+  if (!nccl_use_nonblocking()) {
     C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
   } else {
     C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt);
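
A note on the group-end changes above: with a non-blocking communicator, ncclGroupEnd() may return ncclInProgress, meaning the grouped operations were only enqueued and completion must be polled per communicator. The sketch below reduces what C10D_NCCL_CHECK_TIMEOUT_GROUPEND handles, assuming NCCL >= 2.14 and omitting the timeout bookkeeping; groupEndAndWait is an illustrative name.

#include <cassert>
#include <nccl.h>

void groupEndAndWait(ncclComm_t comm) {
  ncclResult_t res = ncclGroupEnd();
  if (res == ncclInProgress) {
    // Wait for the asynchronously enqueued work to be accepted; the real
    // macro bounds this loop with the nonblocking timeout.
    ncclResult_t state = ncclInProgress;
    do {
      ncclCommGetAsyncError(comm, &state);
    } while (state == ncclInProgress);
    res = state;
  }
  assert(res == ncclSuccess);
}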

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

@@ -782,10 +782,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Abort all communicators on this rank.
   bool abortComms(std::optional<std::string> abortReason = std::nullopt);
 
-  // A helper function to check if nonblocking API mode should be used.
-  // Use this helper instead of directly checking `useNonblocking_` variable.
-  bool useNonblocking();
-
  private:
   int globalRankStart;
   int globalRankStride;
@@ -1242,10 +1238,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::shared_ptr<ProcessGroupStatus> pgStatus_ =
       std::make_shared<ProcessGroupStatus>();
 
-  // Internal cached value: use NCCL non-blocking API mode or not.
-  // Use `useNonblocking()` method instead of accessing this variable directly.
-  std::optional<bool> useNonblocking_{std::nullopt};
-
 };
 
 // Dumps the NCCL comm traces and additional information about the Process
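
The member removed above paired with the useNonblocking() accessor as a compute-once cache: std::optional<bool> distinguishes "not yet resolved" from a resolved false, so the logging resolution runs at most once per ProcessGroup instance. A generic sketch of that pattern, with illustrative names:

#include <optional>

class ModeCache {
 public:
  bool useNonblocking() {
    if (!useNonblocking_.has_value()) {
      useNonblocking_ = resolve(); // resolved and logged only once
    }
    return useNonblocking_.value();
  }

 private:
  bool resolve() const {
    // Placeholder for the priority ladder shown earlier.
    return false;
  }
  std::optional<bool> useNonblocking_;
};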