[PGNCCL] Use non-blocking mode by default in eager init (#138527)

### Why use non-blocking mode in eager init?
It lets communicator initialization overlap with model initialization and similar setup work.
![image](https://github.com/user-attachments/assets/9b0bf7a9-be26-4d16-827b-dbe861f083cd)
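
A minimal sketch of the pattern this enables, assuming a `torchrun`-style launch (so `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`, and `LOCAL_RANK` are set) and a placeholder `build_model`:

```python
import os

import torch
import torch.distributed as dist


def build_model():
    # Placeholder for real (CPU-side) model construction work.
    return torch.nn.Linear(1024, 1024)


local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

# Passing device_id triggers eager comm init. With non-blocking init, this
# call can return before the NCCL communicator is fully ready.
dist.init_process_group("nccl", device_id=device)

# CPU-side model construction overlaps with the in-flight comm init.
model = build_model().to(device)

# The first collective waits inside ProcessGroupNCCL until the comm is ready.
t = torch.ones(1, device=device)
dist.all_reduce(t)

dist.destroy_process_group()
```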

### Why can we set non-blocking as default?
If the setting is left dangling -- i.e. neither passed in by the user nor set via the environment -- `ProcessGroupNCCL` is free to apply its own preferred logic. The torch-level API semantics do not change regardless of whether the NCCL comm is blocking or non-blocking; that difference is handled entirely within `ProcessGroupNCCL`.
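
From the Python side, the resulting preference order looks roughly like this (a sketch under the same launch assumptions as above; `ProcessGroupNCCL.Options().config.blocking` is assumed to be the existing `ncclConfig_t` binding):

```python
import os

import torch
import torch.distributed as dist

device = torch.device("cuda:0")

# 1st priority: an explicit user setting in the NCCL config always wins,
# e.g. forcing blocking init even though device_id (eager init) is passed.
opts = dist.ProcessGroupNCCL.Options()
opts.config.blocking = 1

# 2nd priority: the env var, consulted only when config.blocking is undefined.
# os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"

# 3rd priority: with neither of the above set, eager init (device_id given)
# now defaults to non-blocking. 4th priority: lazy init stays blocking.
dist.init_process_group("nccl", device_id=device, pg_options=opts)
```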

### Why not make non-blocking default for lazy mode as well?
PR https://github.com/pytorch/pytorch/pull/137544 tried it.
Two reasons why that's not preferred today:
1. It is hard -- the blast radius is too big.
2. There is no gain from doing lazy init in non-blocking mode: the very next CPU call is a collective, and we would block there waiting for the comm to become ready. The effect is the same as blocking init -- there is no "opening" to overlap with, unlike in eager mode (see the lazy-init sketch below).
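
For contrast, a lazy-init sketch (same launch assumptions as above) showing why non-blocking init would not help there:

```python
import torch
import torch.distributed as dist

# Lazy init: no device_id, so no NCCL communicator is created yet.
dist.init_process_group("nccl")

device = torch.device("cuda:0")
t = torch.ones(1, device=device)

# The communicator is created here, on the first collective. Even if that
# creation were non-blocking, this same call would immediately have to wait
# for the comm to become ready, so there is nothing to overlap with.
dist.all_reduce(t)

dist.destroy_process_group()
```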

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138527
Approved by: https://github.com/wconstab
ghstack dependencies: #138860
Author: Ke Wen
Date: 2024-10-25 16:59:20 -07:00
Committed by: PyTorch MergeBot
Parent: fed37dbfbc
Commit: ee11e2da1e
6 changed files with 106 additions and 70 deletions


@@ -321,25 +321,30 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_close_pg(self):
@parametrize("eager_init", [True, False])
def test_close_pg(self, eager_init: bool):
# Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
# abort the process group.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
device = self.rank_to_GPU[self.rank][0]
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
c10d.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
device_id=device if eager_init else None,
)
t = torch.rand(10, 10, device=device)
# First allreduce to initialize state.
pg.allreduce(t)
dist.all_reduce(t)
# Destroy pg and validate pg is no longer valid
dist.destroy_process_group()
with self.assertRaises(dist.DistBackendError):
pg.allreduce([t])
del pg
with self.assertRaises(ValueError):
dist.all_reduce(t)
CUDA_12_AND_ABOVE = torch.cuda.is_available() and (
torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12
@@ -803,27 +808,24 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_comm_lazy_init_split(self):
@parametrize("eager_init", [True, False])
def test_new_group(self, eager_init: bool):
# Test that new groups containing all world ranks use the
# "transparent" `ncclCommSplit` optimization.
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
# Test lazy splitting behavior across each per-device backend.
for device in self.rank_to_GPU[self.rank]:
backend = pg._get_backend(torch.device(device))
# split doesn't happen unless the original process group has lazily
# created communicators, so first verify we haven't split even when
# making the new group and running an operation on the original pg.
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
c10d.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
device_id=device if eager_init else None,
)
ng = c10d.new_group()
tensor = torch.tensor([self.rank]).cuda(device)
pg.broadcast(tensor, 0)
self.assertEqual(backend.comm_split_count(), 0)
# The new group will not force a split because it is a lazy init.
ng.broadcast(tensor, 0)
self.assertEqual(backend.comm_split_count(), 0)
tensor = torch.tensor([self.rank], device=device)
dist.broadcast(tensor, 0)
dist.broadcast(tensor, 0, group=ng)
dist.destroy_process_group()
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@@ -863,15 +865,11 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
pg = self._create_process_group_nccl(store, self.opts())
backend = pg._get_backend(torch.device(device))
self.assertEqual(backend._is_initialized(), False)
tensor = torch.full((1,), self.rank).cuda(device)
# create a subgroup eagerly
new_group = c10d.new_group([0, 1], device_id=device)
self.assertEqual(backend.comm_split_count(), 0)
new_backend = new_group._get_backend(torch.device(device))
self.assertEqual(new_backend._is_initialized(), True)
tensor = torch.full((1,), self.rank).cuda(device)
dist.broadcast(tensor, 0, group=new_group)
self.assertEqual(new_backend.comm_split_count(), 0)
# the default group should stay lazy
self.assertEqual(backend._is_initialized(), False)
torch.cuda.synchronize()
dist.destroy_process_group()


@@ -159,7 +159,6 @@ static inline void NCCL_CHECK(ncclResult_t result) {
}
// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
// Default value: on
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
@@ -194,7 +193,8 @@ static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
throw std::runtime_error(
"NCCL timeout when waiting for nonblocking call to become successful.");
}
sched_yield(); // yield to other threads
ncclCommGetAsyncError(to_nccl_comm(comm), &result);
@@ -226,7 +226,8 @@ static inline void NCCL_CHECK_TIMEOUT(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
throw std::runtime_error(
"NCCL timeout when waiting for nonblocking call to become successful.");
}
sched_yield(); // yield to other threads
ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);


@@ -31,7 +31,7 @@ ncclComm_t NCCLComm::getNcclComm() {
commFailureMsg));
}
// In non-blocking mode, ensure comm is ready.
if (nccl_use_nonblocking()) {
if (nonBlocking_) {
// If timeout is reached, throw an exception.
C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt);
// ncclComm_ should be initialized by now
@@ -101,6 +101,7 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
#endif
++source->ncclCommSplitCounter_;
comm->rank_ = rank;
comm->nonBlocking_ = config.blocking == 0;
LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
<< comm->repr() << " with color_id " << color_id;
return comm;
@@ -163,15 +164,6 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
}
#endif
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator.");
}
return nccl_use_nonblocking_;
}
// Default value: 30 minutes
int nccl_nonblocking_timeout() {
static int timeout = -2; // -2 means not initialized


@@ -236,7 +236,6 @@ DEFINE_CONSTANT(started_state, "started");
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
TORCH_API std::string getNcclVersion();
TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
bool nccl_use_nonblocking();
int nccl_nonblocking_timeout();
// Provides additional detail into NCCL error codes based on when these are
@@ -311,6 +310,8 @@ class NCCLComm {
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->initialized_ = true;
// Old style comm is always blocking.
comm->nonBlocking_ = false;
return comm;
}
@@ -321,26 +322,19 @@ class NCCLComm {
ncclUniqueId commId,
ncclConfig_t& config) {
auto comm = std::make_shared<NCCLComm>();
bool isInitialized = false;
if (nccl_use_nonblocking()) {
config.blocking = 0;
LOG(INFO) << "Rank " << rank
<< ": creating NCCL communicator in nonblocking mode";
comm->nonBlocking_ = config.blocking == 0;
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
<< (comm->nonBlocking_ ? "nonblocking" : "blocking");
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
std::nullopt);
} else {
C10D_NCCL_CHECK(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
std::nullopt);
// under blocking mode, comm is initialized after NCCL CHECK
isInitialized = true;
}
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->initialized_ = isInitialized;
// Under blocking mode, comm is initialized immediately after NCCL init
// returns; Under nonblocking mode, we check whether comm is initialized the
// *next* time ncclComm_ is accessed.
comm->initialized_ = !comm->nonBlocking_;
return comm;
}
@@ -385,6 +379,7 @@ class NCCLComm {
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
std::swap(initialized_, other.initialized_);
std::swap(nonBlocking_, other.nonBlocking_);
}
ncclComm_t getNcclComm();
@@ -553,6 +548,10 @@ class NCCLComm {
// better error messaging.
std::optional<std::string> commFailureReason_{};
bool initialized_{false};
// Whether this communicator is using nonblocking mode. Recorded during comm
// creation or split. For safety, we give a default value of true (more
// protection).
bool nonBlocking_{true};
#ifdef NCCL_HAS_COMM_REGISTER
// Stores handlers for tensors registered by NCCL
std::unordered_map<void*, void*> registeredSegmentHandles_;


@@ -987,7 +987,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
<< ", TORCH_NCCL_USE_COMM_NONBLOCKING: " << nccl_use_nonblocking()
#ifdef NCCL_HAS_COMM_REGISTER
<< ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
<< useTensorRegisterAllocatorHook_
@@ -1059,6 +1058,39 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
getNCCLComm(key, device, OpType::ALLREDUCE);
}
bool ProcessGroupNCCL::useNonblocking() {
#ifndef NCCL_HAS_COMM_NONBLOCKING
return false;
#endif
// Already parsed, return the cached value
if (useNonblocking_.has_value()) {
return useNonblocking_.value();
}
// Get environment variable.
auto nbEnv = c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING");
// 1st priority: Respect the user's setting
if (options_->config.blocking != NCCL_CONFIG_UNDEF_INT) {
useNonblocking_ = options_->config.blocking == 0;
}
// 2nd priority: Respect the environment variable
else if (nbEnv.has_value()) {
useNonblocking_ = nbEnv.value();
}
// 3rd priority: automatically use nonblocking if we are in eager init mode
else if (getBoundDeviceId()) {
useNonblocking_ = true;
}
// 4th priority: otherwise, nonblocking = false to preserve old behavior
else {
useNonblocking_ = false;
}
LOG(INFO) << logPrefix()
<< "Using non-blocking mode: " << useNonblocking_.value();
return useNonblocking_.value();
}
void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
// If our backend doesn't support splitting, this is a no-op for
// ranks not in the new subgroup (and ranks that would be in it will
@@ -1067,6 +1099,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
const auto key = getKeyFromDevice(device);
LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
<< device << ", key " << key << ", i am " << this;
bool useNb = useNonblocking();
options_->config.blocking = useNb ? 0 : 1;
auto comm = getNCCLComm(key, device, OpType::ALLREDUCE);
NCCLComm::split(
comm.get(),
@@ -2357,6 +2391,11 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
rank = p2pRank;
}
#ifdef NCCL_HAS_COMM_NONBLOCKING
bool useNb = useNonblocking();
options_->config.blocking = useNb ? 0 : 1;
#endif
#ifdef NCCL_HAS_COMM_SPLIT
if (options_->split_from) {
// Find a valid, healthy communicator to split from if possible.
@@ -2773,7 +2812,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
work->ncclStartEvent_->record(ncclStream);
}
if (nccl_use_nonblocking()) {
if (useNonblocking()) {
groupEndNonblocking(comm);
} else {
groupEnd();
@@ -3093,8 +3132,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
#endif
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
comm, nccl_use_nonblocking());
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
for (const auto i : c10::irange(inputs.size())) {
// Both `inputs' and `outputs' are created on a worker stream and used in
// different ncclStreams. Hence, both must record the ncclStream to
@@ -4662,7 +4700,7 @@ void ProcessGroupNCCL::groupEndNonblocking(
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
#else
if (!nccl_use_nonblocking()) {
if (!useNonblocking()) {
C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
} else {
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt);


@@ -778,6 +778,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Abort all communicators on this rank.
bool abortComms(const std::optional<std::string>& abortReason = std::nullopt);
// A helper function to check if nonblocking API mode should be used.
// Use this helper instead of directly checking `useNonblocking_` variable.
bool useNonblocking();
private:
int globalRankStart;
int globalRankStride;
@@ -1237,6 +1241,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::shared_ptr<ProcessGroupStatus> pgStatus_ =
std::make_shared<ProcessGroupStatus>();
// Internal cached value: use NCCL non-blocking API mode or not.
// Use `useNonblocking()` method instead of accessing this variable directly.
std::optional<bool> useNonblocking_{std::nullopt};
};
// Dumps the NCCL comm traces and additional information about the Process