pytorch/torch/csrc/distributed
Ke Wen e2ce17c6ef [SymmMem][a2av] Use more CTAs for intra-node case (#153509)
Previously, we launched the a2av kernel with at most 8 blocks for the intra-node case, which turned out to reach only 57 GB/s of bandwidth.

This PR launches more blocks for the intra-node case, up to 8 per peer, to increase data parallelism. The kernel now achieves 350 GB/s SOL on Hopper. See the figure below.

It also applies simple tuning based on input size, stepping through 1, 2, 4, then 8 CTAs rather than jumping to 8 directly.

For the inter-node case, we keep the cap at 8 blocks, since 57 GB/s already exceeds typical NIC bandwidth (400 Gb/s, i.e. 50 GB/s).
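The block-count policy described above can be sketched roughly as follows. This is an illustrative Python model only, not the kernel launch code from the PR: the function names, the `step_bytes` threshold, and the reading of the 1/2/4/8 steps as a per-peer count are all assumptions made for the example. Only the limits (8 blocks per peer intra-node, 8 blocks total inter-node) and the 1, 2, 4, 8 progression come from the description.

```python
MAX_BLOCKS_PER_PEER = 8   # intra-node limit stated in the description
INTER_NODE_CAP = 8        # inter-node cap stated in the description


def blocks_per_peer(input_bytes, step_bytes=256 * 1024):
    """Pick 1, 2, 4, or 8 blocks per peer based on input size.

    step_bytes is a hypothetical threshold; the real tuning values
    are not given in the description.
    """
    per_peer = 1
    # Double the block count while the input is larger than what the
    # current count is assumed to handle, stopping at the per-peer limit.
    while per_peer < MAX_BLOCKS_PER_PEER and input_bytes > per_peer * step_bytes:
        per_peer *= 2
    return per_peer


def num_blocks(input_bytes, num_peers, intra_node):
    """Total blocks to launch for one all-to-all-v call (hypothetical model)."""
    total = blocks_per_peer(input_bytes) * num_peers
    if intra_node:
        return total
    # Inter-node: more blocks would not help once the NIC is the bottleneck.
    return min(INTER_NODE_CAP, total)


def gbit_to_gbyte(gbit_per_s):
    """Convert Gb/s to GB/s (8 bits per byte)."""
    return gbit_per_s / 8


if __name__ == "__main__":
    for size in (64 * 1024, 512 * 1024, 1 << 20, 1 << 30):
        print(size, num_blocks(size, num_peers=8, intra_node=True))
    print("400 Gb/s NIC in GB/s:", gbit_to_gbyte(400))
```

Under these assumed thresholds, small inputs stay at one block per peer and only large inputs reach the full 8 per peer, which mirrors the stated intent of not jumping to 8 CTAs directly.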

![all_to_all_vdev Performance on 8xH100](https://github.com/user-attachments/assets/d4b841e6-4c42-4a2e-aa9f-2bc116ba9d25)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153509
Approved by: https://github.com/ngimel
ghstack dependencies: #153483
2025-05-14 04:24:32 +00:00