shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529.

Exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when certain GPUs or nodes need to be excluded from a collective operation, for example in fault-tolerance scenarios or when dynamically adjusting resource utilization. For more info, see [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
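For orientation, here is a minimal sketch of the NCCL-level call that this change wraps, based on the linked NCCL 2.27 documentation. The surviving ranks call it collectively on the parent communicator; the helper function name, the chosen ranks, and the error handling are illustrative assumptions, not the code added by this PR, and the exact parameter names follow the linked docs rather than the PyTorch wrapper.

```cpp
#include <nccl.h>
#include <vector>

// Sketch: build a smaller communicator that excludes ranks 2 and 5.
// All ranks that are NOT excluded call this collectively on `parentComm`.
ncclComm_t shrinkExcludingRanks(ncclComm_t parentComm) {
  std::vector<int> excludeRanks{2, 5};
  ncclComm_t newComm = nullptr;
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER; // default communicator config
  // shrinkFlags = 0 is the default mode; NCCL also documents an "abort"
  // mode intended for fault-tolerance paths where the parent comm is unhealthy.
  ncclResult_t res = ncclCommShrink(
      parentComm,
      excludeRanks.data(),
      static_cast<int>(excludeRanks.size()),
      &newComm,
      &config,
      /*shrinkFlags=*/0);
  return (res == ncclSuccess) ? newComm : nullptr;
}
```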
committed by PyTorch MergeBot
parent 39e0a832c9
commit a032510db3
@@ -90,6 +90,10 @@ static_assert(
 #define NCCL_HAS_NVLS_CTAS
 #endif
 
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
+#define NCCL_HAS_COMM_SHRINK
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do { \
@@ -294,6 +298,14 @@ class NCCLComm {
       ncclConfig_t& config);
 #endif // NCCL_HAS_COMM_SPLIT
 
+#ifdef NCCL_HAS_COMM_SHRINK
+  static std::shared_ptr<NCCLComm> shrink(
+      NCCLComm* source,
+      std::vector<int>& ranks_to_exclude,
+      ncclConfig_t* config,
+      int shrinkFlags = 0);
+#endif // NCCL_HAS_COMM_SHRINK
+
 #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
   std::unordered_map<std::string, std::string> ncclCommDump();
 #endif
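As a caller-side illustration of the wrapper declared above, a minimal sketch under the same `NCCL_HAS_COMM_SHRINK` guard; `parentComm` is a hypothetical `std::shared_ptr<NCCLComm>` owned by the caller and is not part of this diff.

```cpp
#ifdef NCCL_HAS_COMM_SHRINK
// Sketch: derive a new NCCLComm that drops ranks 2 and 5 from `parentComm`
// (a hypothetical std::shared_ptr<NCCLComm> held by the caller).
std::vector<int> ranksToExclude{2, 5};
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
std::shared_ptr<NCCLComm> shrunkComm = NCCLComm::shrink(
    parentComm.get(),
    ranksToExclude,
    &config,
    /*shrinkFlags=*/0); // 0 is the declared default
#endif // NCCL_HAS_COMM_SHRINK
```

Note that `ranks_to_exclude` is taken by non-const reference in the declaration, so the vector must be an lvalue as shown.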