Enables configuration of NCCL communicators (#97394)
NCCL 2.17+ introduces user-configurable parameters for NCCL communicators via the [ncclConfig_t](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#c.ncclConfig_t) datatype and [ncclCommInitRankConfig](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrankconfig). This PR enables that feature. A user can tune the parameters as follows:

```python
import torch.distributed as dist

nccl_options = dist.ProcessGroupNCCL.Options()
nccl_options.config.max_ctas = 32
nccl_options.config.min_ctas = 8
nccl_options.config.cga_cluster_size = 2
dist.init_process_group(backend='nccl', init_method='env://', pg_options=nccl_options)
my_group = dist.new_group(pg_options=nccl_options)
```

The default values of these parameters are those set by `NCCL_CONFIG_INITIALIZER`. For DistributedDataParallel only, this PR sets the default value of `cga_cluster_size` to 2, a heuristic that works especially well for DDP workloads. Tuning these parameters can improve end-to-end performance, since they affect the communication-computation overlap of NCCL kernels.

CC: @ptrblck @kwen2501

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97394
Approved by: https://github.com/kwen2501
Committed by: PyTorch MergeBot
Parent: 3cae6d2493
Commit: 870880236b
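For context, the Python options in the description above map onto NCCL's C API roughly as follows. This is a minimal illustrative sketch, not code from this PR; it assumes NCCL >= 2.17 headers, and the `initCommWithConfig` helper and hard-coded values are made up for illustration.

```cpp
// Minimal sketch (not from this PR): what the Python-side options map to in
// NCCL's C API. Requires NCCL >= 2.17; error handling is omitted for brevity.
#include <nccl.h>

ncclComm_t initCommWithConfig(int numRanks, int rank, ncclUniqueId commId) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER; // start from NCCL's defaults
  config.maxCTAs = 32;       // corresponds to nccl_options.config.max_ctas
  config.minCTAs = 8;        // corresponds to nccl_options.config.min_ctas
  config.cgaClusterSize = 2; // corresponds to nccl_options.config.cga_cluster_size
  ncclComm_t comm = nullptr;
  // ncclCommInitRankConfig is the config-aware counterpart of ncclCommInitRank.
  ncclCommInitRankConfig(&comm, numRanks, commId, rank, &config);
  return comm;
}
```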
@@ -52,6 +52,12 @@

```cpp
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif

#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
#define NCCL_HAS_COMM_CTA_CGA
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_CTA_CGA
#endif

// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
  do { \
```
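The new `NCCL_HAS_COMM_CTA_CGA` macro lets callers touch the CTA/CGA fields of `ncclConfig_t` only when the NCCL headers being compiled against actually define them. A hedged sketch of how such a guard could be consumed (the `makeDefaultConfig` helper is illustrative, not part of the diff):

```cpp
// Illustrative sketch: build a config, setting the CTA/CGA fields only when
// NCCL_HAS_COMM_CTA_CGA says the NCCL headers (>= 2.17) expose them.
ncclConfig_t makeDefaultConfig() {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#ifdef NCCL_HAS_COMM_CTA_CGA
  // This ncclConfig_t field only exists from NCCL 2.17 onwards.
  config.cgaClusterSize = 2; // the DDP-oriented default from the description
#endif
  return config;
}
```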
@@ -179,22 +185,34 @@ class NCCLComm {

```cpp
      int rank,
      ncclUniqueId commId) {
    auto comm = std::make_shared<NCCLComm>();
#ifndef NCCL_HAS_COMM_NONBLOCKING
    C10D_NCCL_CHECK(
        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
#else
    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
    if (nccl_use_nonblocking()) {
      config.blocking = 0;
    }
    C10D_NCCL_CHECK_TIMEOUT(
        ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
#endif
    comm->ncclId_ = commId;
    comm->rank_ = rank;
    return comm;
  }

#ifdef NCCL_HAS_COMM_NONBLOCKING
  static std::shared_ptr<NCCLComm> create(
      int numRanks,
      int rank,
      ncclUniqueId commId,
      ncclConfig_t& config) {
    auto comm = std::make_shared<NCCLComm>();
    if (nccl_use_nonblocking()) {
      config.blocking = 0;
      C10D_NCCL_CHECK_TIMEOUT(
          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
    } else {
      C10D_NCCL_CHECK(
          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
    }
    comm->ncclId_ = commId;
    comm->rank_ = rank;
    return comm;
  }
#endif

  ncclUniqueId getNcclId() {
    return ncclId_;
  }
```
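Putting the two pieces together, the new overload might be called like this. This is a hedged sketch rather than code from the PR: `numRanks`, `rank`, and `commId` are assumed to come from the process group as usual, the `createTunedComm` helper is hypothetical, and `NCCL_HAS_COMM_NONBLOCKING` must be defined for the overload to exist.

```cpp
// Illustrative sketch: creating an NCCLComm with a user-tuned ncclConfig_t
// via the new config-taking create() overload (guarded by
// NCCL_HAS_COMM_NONBLOCKING in the header above).
std::shared_ptr<NCCLComm> createTunedComm(
    int numRanks, int rank, ncclUniqueId commId) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.cgaClusterSize = 2; // e.g. the DDP default mentioned in the description
  return NCCLComm::create(numRanks, rank, commId, config);
}
```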