mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PGNCCL] Use non-blocking mode by default in eager init (#138527)
### Why use non-blocking mode in eager init? For overlapping comm init and model init, etc.  ### Why can we set non-blocking as default? If the setting is dangling -- i.e. not passed in by user nor set via env -- `ProcessGroupNCCL` can have some preferred logic. And torch-level API semantics does not change whether the NCCL comm is blocking or non-blocking (handled within `ProcessGroupNCCL`). ### Why not make non-blocking default for lazy mode as well? PR https://github.com/pytorch/pytorch/pull/137544 tried it. Two reasons why that's not preferred today: 1. It is hard -- too big a blast. 2. There is no gain by doing lazy init in non-blocking mode, because the right next CPU call is a collective, and we will block there waiting for comm to be ready, so same effect as blocked init, no "opening" compared to eager mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138527 Approved by: https://github.com/wconstab ghstack dependencies: #138860
This commit is contained in:
@ -321,25 +321,30 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_close_pg(self):
|
||||
@parametrize("eager_init", [True, False])
|
||||
def test_close_pg(self, eager_init: bool):
|
||||
# Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
|
||||
# abort the process group.
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
device = self.rank_to_GPU[self.rank][0]
|
||||
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
device_id=device if eager_init else None,
|
||||
)
|
||||
|
||||
t = torch.rand(10, 10, device=device)
|
||||
# First allreduce to initialize state.
|
||||
pg.allreduce(t)
|
||||
dist.all_reduce(t)
|
||||
|
||||
# Destroy pg and validate pg is no longer valid
|
||||
dist.destroy_process_group()
|
||||
with self.assertRaises(dist.DistBackendError):
|
||||
pg.allreduce([t])
|
||||
|
||||
del pg
|
||||
with self.assertRaises(ValueError):
|
||||
dist.all_reduce(t)
|
||||
|
||||
CUDA_12_AND_ABOVE = torch.cuda.is_available() and (
|
||||
torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12
|
||||
@ -803,27 +808,24 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_comm_lazy_init_split(self):
|
||||
@parametrize("eager_init", [True, False])
|
||||
def test_new_group(self, eager_init: bool):
|
||||
# Test the optimization of new groups that contain all world
|
||||
# ranks use the "transparent" `ncclCommSplit` optimization.
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
|
||||
# Test lazy splitting behavior across each per-device backend.
|
||||
for device in self.rank_to_GPU[self.rank]:
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
|
||||
# split doesn't happen unless the original process group has lazily
|
||||
# created communicators, so first verify we haven't split even when
|
||||
# making the new group and running an operation on the original pg.
|
||||
ng = c10d.new_group()
|
||||
tensor = torch.tensor([self.rank]).cuda(device)
|
||||
pg.broadcast(tensor, 0)
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
|
||||
# The new group will not force a split because it is a lazy init.
|
||||
ng.broadcast(tensor, 0)
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
device_id=device if eager_init else None,
|
||||
)
|
||||
ng = c10d.new_group()
|
||||
tensor = torch.tensor([self.rank], device=device)
|
||||
dist.broadcast(tensor, 0)
|
||||
dist.broadcast(tensor, 0, group=ng)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@ -863,15 +865,11 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
self.assertEqual(backend._is_initialized(), False)
|
||||
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
# create a subgroup eagerly
|
||||
new_group = c10d.new_group([0, 1], device_id=device)
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
|
||||
new_backend = new_group._get_backend(torch.device(device))
|
||||
self.assertEqual(new_backend._is_initialized(), True)
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
dist.broadcast(tensor, 0, group=new_group)
|
||||
self.assertEqual(new_backend.comm_split_count(), 0)
|
||||
# the default group should stay lazy
|
||||
self.assertEqual(backend._is_initialized(), False)
|
||||
torch.cuda.synchronize()
|
||||
dist.destroy_process_group()
|
||||
|
@ -159,7 +159,6 @@ static inline void NCCL_CHECK(ncclResult_t result) {
|
||||
}
|
||||
|
||||
// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
|
||||
// Default value: on
|
||||
bool nccl_use_nonblocking() {
|
||||
static bool nccl_use_nonblocking_ =
|
||||
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
|
||||
@ -194,7 +193,8 @@ static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
|
||||
currentTimepoint - startTimepoint)
|
||||
.count();
|
||||
if (timeElapsed > nccl_nonblocking_timeout()) {
|
||||
throw std::runtime_error("NCCL timeout.");
|
||||
throw std::runtime_error(
|
||||
"NCCL timeout when waiting for nonblocking call to become successful.");
|
||||
}
|
||||
sched_yield(); // yield to other threads
|
||||
ncclCommGetAsyncError(to_nccl_comm(comm), &result);
|
||||
@ -226,7 +226,8 @@ static inline void NCCL_CHECK_TIMEOUT(
|
||||
currentTimepoint - startTimepoint)
|
||||
.count();
|
||||
if (timeElapsed > nccl_nonblocking_timeout()) {
|
||||
throw std::runtime_error("NCCL timeout.");
|
||||
throw std::runtime_error(
|
||||
"NCCL timeout when waiting for nonblocking call to become successful.");
|
||||
}
|
||||
sched_yield(); // yield to other threads
|
||||
ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
|
||||
|
@ -31,7 +31,7 @@ ncclComm_t NCCLComm::getNcclComm() {
|
||||
commFailureMsg));
|
||||
}
|
||||
// In non-blocking mode, ensure comm is ready.
|
||||
if (nccl_use_nonblocking()) {
|
||||
if (nonBlocking_) {
|
||||
// If timeout is reached, throw an exception.
|
||||
C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt);
|
||||
// ncclComm_ should be initialized by now
|
||||
@ -101,6 +101,7 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
|
||||
#endif
|
||||
++source->ncclCommSplitCounter_;
|
||||
comm->rank_ = rank;
|
||||
comm->nonBlocking_ = config.blocking == 0;
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
|
||||
<< comm->repr() << " with color_id " << color_id;
|
||||
return comm;
|
||||
@ -163,15 +164,6 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
|
||||
}
|
||||
#endif
|
||||
|
||||
bool nccl_use_nonblocking() {
|
||||
static bool nccl_use_nonblocking_ =
|
||||
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
|
||||
if (nccl_use_nonblocking_) {
|
||||
TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator.");
|
||||
}
|
||||
return nccl_use_nonblocking_;
|
||||
}
|
||||
|
||||
// Default value: 30 minutes
|
||||
int nccl_nonblocking_timeout() {
|
||||
static int timeout = -2; // -2 means not initialized
|
||||
|
@ -236,7 +236,6 @@ DEFINE_CONSTANT(started_state, "started");
|
||||
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
|
||||
TORCH_API std::string getNcclVersion();
|
||||
TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
|
||||
bool nccl_use_nonblocking();
|
||||
int nccl_nonblocking_timeout();
|
||||
|
||||
// Provides additional detail into NCCL error codes based on when these are
|
||||
@ -311,6 +310,8 @@ class NCCLComm {
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = true;
|
||||
// Old style comm is always blocking.
|
||||
comm->nonBlocking_ = false;
|
||||
return comm;
|
||||
}
|
||||
|
||||
@ -321,26 +322,19 @@ class NCCLComm {
|
||||
ncclUniqueId commId,
|
||||
ncclConfig_t& config) {
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
bool isInitialized = false;
|
||||
if (nccl_use_nonblocking()) {
|
||||
config.blocking = 0;
|
||||
LOG(INFO) << "Rank " << rank
|
||||
<< ": creating NCCL communicator in nonblocking mode";
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
} else {
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
// under blocking mode, comm is initialized after NCCL CHECK
|
||||
isInitialized = true;
|
||||
}
|
||||
comm->nonBlocking_ = config.blocking == 0;
|
||||
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
|
||||
<< (comm->nonBlocking_ ? "nonblocking" : "blocking");
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = isInitialized;
|
||||
// Under blocking mode, comm is initialized immediately after NCCL init
|
||||
// returns; Under nonblocking mode, we check whether comm is initialized the
|
||||
// *next* time ncclComm_ is accessed.
|
||||
comm->initialized_ = !comm->nonBlocking_;
|
||||
return comm;
|
||||
}
|
||||
|
||||
@ -385,6 +379,7 @@ class NCCLComm {
|
||||
std::swap(aborted_, other.aborted_);
|
||||
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
||||
std::swap(initialized_, other.initialized_);
|
||||
std::swap(nonBlocking_, other.nonBlocking_);
|
||||
}
|
||||
|
||||
ncclComm_t getNcclComm();
|
||||
@ -553,6 +548,10 @@ class NCCLComm {
|
||||
// better error messaging.
|
||||
std::optional<std::string> commFailureReason_{};
|
||||
bool initialized_{false};
|
||||
// Whether this communicator is using nonblocking mode. Recorded during comm
|
||||
// creation or split. For safety, we give a default value of true (more
|
||||
// protection).
|
||||
bool nonBlocking_{true};
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
// Stores handlers for tensors registered by NCCL
|
||||
std::unordered_map<void*, void*> registeredSegmentHandles_;
|
||||
|
@ -987,7 +987,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
|
||||
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
|
||||
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
|
||||
<< ", TORCH_NCCL_USE_COMM_NONBLOCKING: " << nccl_use_nonblocking()
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
<< ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
|
||||
<< useTensorRegisterAllocatorHook_
|
||||
@ -1059,6 +1058,39 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
|
||||
getNCCLComm(key, device, OpType::ALLREDUCE);
|
||||
}
|
||||
|
||||
bool ProcessGroupNCCL::useNonblocking() {
|
||||
#ifndef NCCL_HAS_COMM_NONBLOCKING
|
||||
return false;
|
||||
#endif
|
||||
// Already parsed, return the cached value
|
||||
if (useNonblocking_.has_value()) {
|
||||
return useNonblocking_.value();
|
||||
}
|
||||
// Get environment variable.
|
||||
auto nbEnv = c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING");
|
||||
|
||||
// 1st priority: Respect the user's setting
|
||||
if (options_->config.blocking != NCCL_CONFIG_UNDEF_INT) {
|
||||
useNonblocking_ = options_->config.blocking == 0;
|
||||
}
|
||||
// 2nd priority: Respect the environment variable
|
||||
else if (nbEnv.has_value()) {
|
||||
useNonblocking_ = nbEnv.value();
|
||||
}
|
||||
// 3rd priority: automatically use nonblocking if we are in eager init mode
|
||||
else if (getBoundDeviceId()) {
|
||||
useNonblocking_ = true;
|
||||
}
|
||||
// 4th priority: otherwise, nonblocking = false to preserve old behavior
|
||||
else {
|
||||
useNonblocking_ = false;
|
||||
}
|
||||
|
||||
LOG(INFO) << logPrefix()
|
||||
<< "Using non-blocking mode: " << useNonblocking_.value();
|
||||
return useNonblocking_.value();
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
|
||||
// If our backend doesn't support splitting, this is a no-op for
|
||||
// ranks not in the new subgroup (and ranks that would be in it will
|
||||
@ -1067,6 +1099,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
|
||||
const auto key = getKeyFromDevice(device);
|
||||
LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
|
||||
<< device << ", key " << key << ", i am " << this;
|
||||
bool useNb = useNonblocking();
|
||||
options_->config.blocking = useNb ? 0 : 1;
|
||||
auto comm = getNCCLComm(key, device, OpType::ALLREDUCE);
|
||||
NCCLComm::split(
|
||||
comm.get(),
|
||||
@ -2357,6 +2391,11 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
|
||||
rank = p2pRank;
|
||||
}
|
||||
|
||||
#ifdef NCCL_HAS_COMM_NONBLOCKING
|
||||
bool useNb = useNonblocking();
|
||||
options_->config.blocking = useNb ? 0 : 1;
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
if (options_->split_from) {
|
||||
// Find a valid, healthy communicator to split from if possible.
|
||||
@ -2773,7 +2812,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
|
||||
work->ncclStartEvent_->record(ncclStream);
|
||||
}
|
||||
|
||||
if (nccl_use_nonblocking()) {
|
||||
if (useNonblocking()) {
|
||||
groupEndNonblocking(comm);
|
||||
} else {
|
||||
groupEnd();
|
||||
@ -3093,8 +3132,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
#endif
|
||||
|
||||
{
|
||||
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
|
||||
comm, nccl_use_nonblocking());
|
||||
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
// Both `inputs' and `outputs' are created on a worker stream and used in
|
||||
// different ncclStreams. Hence, both must record the ncclStream to
|
||||
@ -4662,7 +4700,7 @@ void ProcessGroupNCCL::groupEndNonblocking(
|
||||
#ifndef NCCL_HAS_COMM_NONBLOCKING
|
||||
C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
|
||||
#else
|
||||
if (!nccl_use_nonblocking()) {
|
||||
if (!useNonblocking()) {
|
||||
C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt);
|
||||
} else {
|
||||
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt);
|
||||
|
@ -778,6 +778,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Abort all communicators on this rank.
|
||||
bool abortComms(const std::optional<std::string>& abortReason = std::nullopt);
|
||||
|
||||
// A helper function to check if nonblocking API mode should be used.
|
||||
// Use this helper instead of directly checking `useNonblocking_` variable.
|
||||
bool useNonblocking();
|
||||
|
||||
private:
|
||||
int globalRankStart;
|
||||
int globalRankStride;
|
||||
@ -1237,6 +1241,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
std::shared_ptr<ProcessGroupStatus> pgStatus_ =
|
||||
std::make_shared<ProcessGroupStatus>();
|
||||
|
||||
// Internal cached value: use NCCL non-blocking API mode or not.
|
||||
// Use `useNonblocking()` method instead of accessing this variable directly.
|
||||
std::optional<bool> useNonblocking_{std::nullopt};
|
||||
};
|
||||
|
||||
// Dumps the NCCL comm traces and additional information about the Process
|
||||
|
Reference in New Issue
Block a user