[SymmMem][a2av] Use more CTAs for intra-node case (#153509)

Previously, we launched the a2av kernel with at most 8 blocks for intra-node cases, which turns out to saturate only about 57 GB/s of bandwidth.

This PR launches more blocks for the intra-node case, up to 8 per peer, increasing data parallelism. The kernel now achieves 350 GB/s SOL on Hopper; see the figure below.

It also uses a simple input-size-based tuning to avoid jumping directly to 8 CTAs per peer (i.e., ramping through 1, 2, 4, then 8).

For inter-node, we cap at 8 blocks, since 57 GB/s already exceeds typical NIC bandwidths (400 Gb/s).
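
As a quick unit check on that cap (plain conversion of the NIC figure quoted above):

$$400~\text{Gb/s} \times \frac{1~\text{GB}}{8~\text{Gb}} = 50~\text{GB/s} < 57~\text{GB/s}$$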

![all_to_all_vdev Performance on 8xH100](https://github.com/user-attachments/assets/d4b841e6-4c42-4a2e-aa9f-2bc116ba9d25)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153509
Approved by: https://github.com/ngimel
ghstack dependencies: #153483
Author: Ke Wen
Date: 2025-05-13 17:48:22 -07:00
Committed by: PyTorch MergeBot
Parent: 20dbe644c7
Commit: e2ce17c6ef

```diff
@@ -202,22 +202,30 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits
   auto source_offsets = in_out_splits + npes * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
+  int blocks_per_peer = max(gridDim.x / npes, 1);
 
   // Calculate the output offsets
   __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
   prefixSum(peer_offsets, output_splits, npes);
   __syncthreads();
 
-  // Each block targets a different peer
-  for (int i = bid; i < npes; i += gridDim.x) {
+  // Target a different peer based on bid
+  for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) {
     int peer = (mype + i) % npes;
-    auto size = output_splits[peer] * stride;
-    auto source_offset = source_offsets[peer] * stride;
-    auto write_offset = peer_offsets[peer] * stride;
+    // Total amount from `peer`
+    auto peer_size = output_splits[peer] * stride;
+    // Amount to get from `peer` in this block
+    auto block_size = peer_size / blocks_per_peer;
+    // Being lazy here, we should handle the residual if the division is not exact
+    CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size);
+    // This block's offset in the data from `peer`
+    auto block_offset = block_size * (bid % blocks_per_peer);
+    auto source_offset = source_offsets[peer] * stride + block_offset;
+    auto write_offset = peer_offsets[peer] * stride + block_offset;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
-        size,
+        block_size,
         peer);
   }
   // Write out the output offsets (to the scratchpad line)
```
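
To make the new indexing concrete, here is a small standalone sketch (hypothetical host-side helper, not part of the PR) that reproduces the same arithmetic the kernel uses to map a block id onto a (peer, chunk) pair, under the assumption that the grid is launched with exactly `blocks_per_peer * npes` blocks:

```cpp
#include <cstdio>

// Mirrors the kernel's indexing, assuming gridDim.x == blocks_per_peer * npes
// and that each peer's contribution divides evenly by blocks_per_peer
// (the kernel asserts this with CUDA_KERNEL_ASSERT).
void print_block_assignment(int npes, int blocks_per_peer, int mype) {
  int grid_dim = blocks_per_peer * npes;
  for (int bid = 0; bid < grid_dim; ++bid) {
    // Same loop start / stride as the kernel; with a full grid each block
    // executes exactly one iteration, i.e. serves exactly one peer.
    for (int i = bid / blocks_per_peer; i < npes; i += grid_dim / blocks_per_peer) {
      int peer = (mype + i) % npes;       // peer this block pulls from
      int chunk = bid % blocks_per_peer;  // which slice of that peer's data
      std::printf("bid %2d -> peer %d, chunk %d of %d\n",
                  bid, peer, chunk, blocks_per_peer);
    }
  }
}

int main() {
  // E.g., 8 ranks on one node, 2 blocks per peer, viewed from rank 0.
  print_block_assignment(/*npes=*/8, /*blocks_per_peer=*/2, /*mype=*/0);
}
```

With fewer blocks than peers (the inter-node case), `blocks_per_peer` falls back to 1 and each block loops over several peers, exactly as the old code did.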

```diff
@@ -266,11 +274,26 @@ at::Tensor nvshmem_all_to_all_vdev(
       0,
       stream);
 
-  // All to all data exchange
-  // Limit the number of blocks to 16
-  int num_blocks = std::min(world_size, 16);
+  // CTA Tuning
+  // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8.
+  // Up to 1 MB -> 1 block
+  // Up to 2 MB -> 2 blocks
+  // Up to 4 MB -> 4 blocks
+  // More -> 8 blocks
+  auto input_size = input.numel() * input.element_size();
+  const int max_blocks_per_peer = input_size < 1024 * 1024 ? 1 :
+      (input_size < 2 * 1024 * 1024 ? 2 :
+      (input_size < 4 * 1024 * 1024 ? 4 : 8));
+  // Inter-node: limit the total number of blocks to 8, which is able to
+  // drive 57 GB/s bandwidth in test, enough to drive a 400 Gb/s NIC.
+  // TODO: better intra vs inter detection; currently it is based on world_size
+  int num_blocks = world_size > 8 ? 8 : max_blocks_per_peer * world_size;
+
   // Stride at dim 0 (assuming input is contiguous, TODO)
   size_t stride_bytes = input.stride(0) * input.element_size();
 
+  // All to all data exchange
   void* args1[] = {
     &input_ptr,
     &output_ptr,
```
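
To see what this heuristic actually launches, here is a small standalone sketch (illustrative only; `pick_num_blocks` is a hypothetical helper that restates the ternary above) printing the resulting block count for a few payload sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Restates the size-based CTA tuning from the hunk above:
// <1 MB -> 1 block/peer, <2 MB -> 2, <4 MB -> 4, else 8;
// inter-node (approximated here as world_size > 8) stays capped at 8 blocks total.
int pick_num_blocks(int64_t input_size_bytes, int world_size) {
  const int max_blocks_per_peer =
      input_size_bytes < 1024 * 1024 ? 1 :
      (input_size_bytes < 2 * 1024 * 1024 ? 2 :
      (input_size_bytes < 4 * 1024 * 1024 ? 4 : 8));
  return world_size > 8 ? 8 : max_blocks_per_peer * world_size;
}

int main() {
  const int64_t sizes[] = {512 * 1024, 3 * 1024 * 1024, 64 * 1024 * 1024};
  for (int64_t size : sizes) {
    // 8 ranks: intra-node scaling; 16 ranks: treated as inter-node, capped at 8.
    std::printf("%6lld KB: world_size=8 -> %2d blocks, world_size=16 -> %d blocks\n",
                static_cast<long long>(size / 1024),
                pick_num_blocks(size, 8), pick_num_blocks(size, 16));
  }
}
```

For example, a 3 MB input on 8 intra-node ranks picks 4 blocks per peer, i.e. 32 blocks total, while any size on 16 ranks stays at 8 blocks.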