[SymmMem][a2av] Use more CTAs for intra-node case (#153509)
[SymmMem][a2av] Use more CTAs for intra-node case (#153509)
Previously, we launched the a2av kernel with at most 8 blocks for intra-node cases, which turned out to saturate only 57 GB/s of bandwidth. This PR adds more blocks for the intra-node case, up to 8 per peer, to increase data parallelism. The kernel now achieves 350 GB/s SOL on Hopper; see figure. It also applies a simple input-size-based tuning so that we do not jump to 8 CTAs per peer directly (i.e. 1, 2, 4, then 8). For inter-node, we cap at 8 blocks total, since 57 GB/s (about 456 Gb/s) already exceeds regular NIC bandwidths (400 Gb/s).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153509
Approved by: https://github.com/ngimel
ghstack dependencies: #153483
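For context, the block-count heuristic described above can be summarized as a standalone function. This is an illustrative sketch distilled from the diff below, not code from the PR itself; the function name is made up, while the size thresholds and the world_size > 8 inter-node test come from the change:

#include <cstdint>

// Sketch of the CTA tuning added by this PR (illustrative only).
// Intra-node (world_size <= 8): scale blocks per peer with input size, up to 8 per peer.
// Inter-node (world_size > 8): cap the whole grid at 8 blocks, since ~57 GB/s
// (~456 Gb/s) already exceeds a 400 Gb/s NIC.
int choose_num_blocks(int64_t input_bytes, int world_size) {
  const int max_blocks_per_peer =
      input_bytes < 1 * 1024 * 1024 ? 1 :
      input_bytes < 2 * 1024 * 1024 ? 2 :
      input_bytes < 4 * 1024 * 1024 ? 4 : 8;
  return world_size > 8 ? 8 : max_blocks_per_peer * world_size;
}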
@@ -202,22 +202,30 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_spli
   auto source_offsets = in_out_splits + npes * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
+  int blocks_per_peer = max(gridDim.x / npes, 1);
 
   // Calculate the output offsets
   __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
   prefixSum(peer_offsets, output_splits, npes);
   __syncthreads();
 
-  // Each block targets a different peer
-  for (int i = bid; i < npes; i += gridDim.x) {
+  // Target a different peer based on bid
+  for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) {
     int peer = (mype + i) % npes;
-    auto size = output_splits[peer] * stride;
-    auto source_offset = source_offsets[peer] * stride;
-    auto write_offset = peer_offsets[peer] * stride;
+    // Total amount from `peer`
+    auto peer_size = output_splits[peer] * stride;
+    // Amount to get from `peer` in this block
+    auto block_size = peer_size / blocks_per_peer;
+    // Being lazy here, we should handle the residual if the division is not exact
+    CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size);
+    // This block's offset in the data from `peer`
+    auto block_offset = block_size * (bid % blocks_per_peer);
+    auto source_offset = source_offsets[peer] * stride + block_offset;
+    auto write_offset = peer_offsets[peer] * stride + block_offset;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
-        size,
+        block_size,
         peer);
   }
   // Write out the output offsets (to the scratchpad line)
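As an aside, the new loop indexing is a two-level decomposition: bid / blocks_per_peer selects which peer a block starts from, and bid % blocks_per_peer selects which slice of that peer's data it copies. The following host-side sketch of the mapping is illustrative only (the names are made up, and the mype rotation and grid-stride step of the real kernel are omitted):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// For each block id, print the peer index it starts on and the byte range it copies,
// mirroring blocks_per_peer, block_size and block_offset in the kernel above.
void print_block_assignment(int grid_dim_x, int npes, int64_t peer_bytes) {
  int blocks_per_peer = std::max(grid_dim_x / npes, 1);
  int64_t block_size = peer_bytes / blocks_per_peer;  // kernel asserts exact division
  for (int bid = 0; bid < grid_dim_x; ++bid) {
    int peer_group = bid / blocks_per_peer;            // first peer this block targets
    int slice = bid % blocks_per_peer;                 // which slice of that peer's data
    int64_t offset = block_size * slice;               // byte offset within the peer's data
    std::printf("block %d -> peer %d, bytes [%lld, %lld)\n",
                bid, peer_group, (long long)offset, (long long)(offset + block_size));
  }
}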
@@ -266,11 +274,26 @@ at::Tensor nvshmem_all_to_all_vdev(
       0,
       stream);
 
-  // All to all data exchange
-  // Limit the number of blocks to 16
-  int num_blocks = std::min(world_size, 16);
+  // CTA Tuning
+  // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8.
+  // Up to 1 MB -> 1 block
+  // Up to 2 MB -> 2 blocks
+  // Up to 4 MB -> 4 blocks
+  // More -> 8 blocks
+  auto input_size = input.numel() * input.element_size();
+  const int max_blocks_per_peer = input_size < 1024 * 1024 ? 1 :
+      (input_size < 2 * 1024 * 1024 ? 2 :
+      (input_size < 4 * 1024 * 1024 ? 4 : 8));
+
+  // Inter-node: limit the total number of blocks to 8, which is able to
+  // drive 57 GB/s bandwidth in test, enough to drive a 400 Gb/s NIC.
+  // TODO: better intra vs inter detection; currently it is based on world_size
+  int num_blocks = world_size > 8 ? 8 : max_blocks_per_peer * world_size;
+
   // Stride at dim 0 (assuming input is contiguous, TODO)
   size_t stride_bytes = input.stride(0) * input.element_size();
+
+  // All to all data exchange
   void* args1[] = {
       &input_ptr,
       &output_ptr,