Previously, we launched the a2av kernel with at most 8 blocks for intra-node cases, which turned out to reach only 57 GB/s of bandwidth. This PR adds more blocks for intra-node, up to 8 per peer, increasing data parallelism. The kernel now achieves 350 GB/s SOL for Hopper. See figure.

It also uses a simple tuning based on input size to avoid jumping to 8 CTAs directly (i.e. 1, 2, 4, then 8).

For inter-node, we cap at 8 blocks, since 57 GB/s already exceeds regular NIC bandwidths (400 Gb/s, i.e. 50 GB/s).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153509
Approved by: https://github.com/ngimel
ghstack dependencies: #153483
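
For illustration, the block-count selection described above could be sketched roughly as follows. This is not the code from the PR: the function names, the `BYTES_PER_BLOCK` threshold, and the choice to always launch the full cap for inter-node are assumptions made for the example. Only the limits (8 blocks per peer intra-node, 8 blocks total inter-node) and the 1, 2, 4, 8 stepping come from the description.

```python
# Illustrative sketch only: names and thresholds are assumptions, not PyTorch code.
MAX_BLOCKS_INTER_NODE = 8   # inter-node cap stated in the description
MAX_BLOCKS_PER_PEER = 8     # intra-node per-peer upper bound stated in the description
BYTES_PER_BLOCK = 1 << 20   # hypothetical tuning knob: 1 MiB of input per block


def blocks_per_peer(input_bytes_per_peer: int) -> int:
    """Step through 1, 2, 4, 8 blocks as the per-peer input grows."""
    n = 1
    while n < MAX_BLOCKS_PER_PEER and input_bytes_per_peer > n * BYTES_PER_BLOCK:
        n *= 2
    return n


def num_blocks(input_bytes_per_peer: int, world_size: int, intra_node: bool) -> int:
    """Total blocks to launch for one a2av call."""
    if not intra_node:
        # Assumption: simply use the cap; the real kernel may launch fewer.
        return MAX_BLOCKS_INTER_NODE
    return blocks_per_peer(input_bytes_per_peer) * world_size
```

With these assumed thresholds, a small message stays at one block per peer, while a large intra-node message on 8 ranks would launch 8 x 8 = 64 blocks.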