Add python bindings for NCCL CTA policies (#164309)

`NCCLConfig` can now be configured with non-default [CTA policies][1] from Python:

```python
import torch
from torch.distributed import ProcessGroupNCCL as nccl

config = nccl.NCCLConfig()
config.cta_policy = nccl.NCCL_CTA_POLICY_ZERO  # NCCL version >= 2.28
```

[1]: https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2283/user-guide/docs/api/flags.html#nccl-communicator-cta-policy-flags

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164309
Approved by: https://github.com/eqy
This commit is contained in:
Lakshay Garg
2025-10-07 18:16:20 +00:00
committed by PyTorch MergeBot
parent 078d475d3b
commit 9ecd092bd9

View File

@ -3358,6 +3358,20 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
return ::c10d::getNcclVersionTuple();
});
// Expose NCCL's CTA policy flag values as read-only static properties on the
// Python ProcessGroupNCCL class, so Python code can assign them to
// NCCLConfig.cta_policy. This whole section is compiled only when
// NCCL_HAS_CTA_POLICY is defined.
#ifdef NCCL_HAS_CTA_POLICY
// Each getter receives the owning class object from pybind11; it is not
// needed to produce the constant, so the parameter is left unnamed.
processGroupNCCL.def_property_readonly_static(
"NCCL_CTA_POLICY_DEFAULT",
[](const py::object&) { return NCCL_CTA_POLICY_DEFAULT; });
processGroupNCCL.def_property_readonly_static(
"NCCL_CTA_POLICY_EFFICIENCY",
[](const py::object&) { return NCCL_CTA_POLICY_EFFICIENCY; });
// NCCL_CTA_POLICY_ZERO is guarded separately: when the NCCL headers do not
// define it, the Python attribute is simply not registered.
#ifdef NCCL_CTA_POLICY_ZERO // requires NCCL version >= 2.28
processGroupNCCL.def_property_readonly_static(
"NCCL_CTA_POLICY_ZERO",
[](const py::object&) { return NCCL_CTA_POLICY_ZERO; });
#endif // NCCL_CTA_POLICY_ZERO
#endif // NCCL_HAS_CTA_POLICY
module.def(
"_get_intra_node_comm_usage_counter",
&::c10d::intra_node_comm::getIntraNodeCommUsageCounter);