shrink_group implementation to expose ncclCommShrink API (#164518)

Closes #164529

This PR exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when certain GPUs or nodes need to be excluded from an existing communicator, for example in fault-tolerance scenarios or when dynamically adjusting resource utilization.

For more info, see [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) in the NCCL user guide.
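
For context, a minimal sketch of the underlying NCCL call this wraps. The `ncclCommShrink` signature and the `NCCL_SHRINK_DEFAULT`/`NCCL_SHRINK_ABORT` flags follow the NCCL 2.27 documentation linked above and should be checked against your installed `nccl.h`; the function name `shrink_excluding` is purely illustrative.

```cpp
// Sketch only: shrink an existing NCCL communicator by excluding ranks.
#include <nccl.h>
#include <vector>

ncclComm_t shrink_excluding(ncclComm_t comm, std::vector<int>& exclude_ranks) {
  ncclComm_t new_comm = nullptr;
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  // All surviving ranks call this collectively; the excluded ranks are
  // simply left out of the new, smaller communicator. NCCL_SHRINK_ABORT
  // is the variant intended for shrinking away failed ranks.
  ncclResult_t res = ncclCommShrink(
      comm,
      exclude_ranks.data(),
      static_cast<int>(exclude_ranks.size()),
      &new_comm,
      &config,
      NCCL_SHRINK_DEFAULT);
  return (res == ncclSuccess) ? new_comm : nullptr;
}
```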

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
Author: Bruce Chang
Date: 2025-10-17 17:55:00 +00:00
Committed by: PyTorch MergeBot
Parent: 39e0a832c9
Commit: a032510db3
11 changed files with 1503 additions and 2 deletions


@@ -90,6 +90,10 @@ static_assert(
#define NCCL_HAS_NVLS_CTAS
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif
// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
@@ -294,6 +298,14 @@ class NCCLComm {
      ncclConfig_t& config);
#endif // NCCL_HAS_COMM_SPLIT
#ifdef NCCL_HAS_COMM_SHRINK
  static std::shared_ptr<NCCLComm> shrink(
      NCCLComm* source,
      std::vector<int>& ranks_to_exclude,
      ncclConfig_t* config,
      int shrinkFlags = 0);
#endif // NCCL_HAS_COMM_SHRINK
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
  std::unordered_map<std::string, std::string> ncclCommDump();
#endif
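
For illustration only, a simplified sketch of how the `NCCLComm::shrink` wrapper declared above might forward to `ncclCommShrink` in the corresponding .cpp file. Member names such as `ncclComm_` are assumptions made for this sketch, not the actual PyTorch implementation, which additionally handles error propagation, abort flags, and bookkeeping on the returned communicator.

```cpp
#ifdef NCCL_HAS_COMM_SHRINK
// Illustrative sketch; assumes NCCLComm stores its raw handle in ncclComm_.
std::shared_ptr<NCCLComm> NCCLComm::shrink(
    NCCLComm* source,
    std::vector<int>& ranks_to_exclude,
    ncclConfig_t* config,
    int shrinkFlags) {
  auto comm = std::make_shared<NCCLComm>();
  C10D_NCCL_CHECK(
      ncclCommShrink(
          source->ncclComm_,
          ranks_to_exclude.data(),
          static_cast<int>(ranks_to_exclude.size()),
          &(comm->ncclComm_),
          config,
          shrinkFlags),
      std::nullopt);
  return comm;
}
#endif // NCCL_HAS_COMM_SHRINK
```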