[a2av] Separate in/out splits into two tensors (#163837)

Old signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name)`
New signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name)`

That is, split `in_out_splits` into a separate IN tensor and OUT tensor, so that the TORCH_LIBRARY signature can be defined more precisely. This also brings the op in line with the 2D version (`all_to_all_vdev_2d`).
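
For illustration, a minimal Python sketch of a call site after this change, mirroring the updated test below (the helper name and arguments such as `inp`, `out`, `inp_splits` are placeholders, not part of this PR):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def a2av_with_device_splits(inp, out, inp_splits, group_name, world_size, device):
    # Splits now live in two tensors instead of one (3, world_size) buffer:
    #   in_splits:          1D (world_size,), read-only input splits
    #   out_splits_offsets: 2D (2, world_size), written by the op
    #                       (row 0 = output splits, row 1 = output offsets)
    in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=device)
    out_splits_offsets = symm_mem.empty(
        (2, world_size), dtype=torch.int64, device=device
    )
    in_splits.copy_(inp_splits)
    dist.barrier()  # ensure remote symmetric buffers are allocated (as in the test)
    torch.ops.symm_mem.all_to_all_vdev(
        inp, out, in_splits, out_splits_offsets, group_name
    )
    return out_splits_offsets
```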

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163837
Approved by: https://github.com/fduwjj
ghstack dependencies: #163886
Author: Ke Wen
Date: 2025-09-25 18:00:38 -07:00
Committed by: PyTorch MergeBot
Parent: 5daa79fd6e
Commit: bbf8aa43ef
4 changed files with 36 additions and 29 deletions


@@ -299,28 +299,33 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest):
             torch.randn(max_inp_numel, dtype=dtype, device=self.device)
         )
         out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
-        in_out_splits = symm_mem.empty(
-            (3, self.world_size), dtype=torch.int64, device=self.device
+        in_splits = symm_mem.empty(
+            self.world_size, dtype=torch.int64, device=self.device
         )
+        out_splits_offsets = symm_mem.empty(
+            (2, self.world_size), dtype=torch.int64, device=self.device
+        )
         # Row 0 is input splits
-        in_out_splits[0].copy_(inp_splits)
+        in_splits.copy_(inp_splits)
         # Sync all ranks to ensure remote tensors are allocated
         dist.barrier()
-        torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)
+        torch.ops.symm_mem.all_to_all_vdev(
+            inp, out, in_splits, out_splits_offsets, group_name
+        )
         # Check input splits (row 0) -- should not change
-        torch.testing.assert_close(in_out_splits[0], inp_splits)
+        torch.testing.assert_close(in_splits, inp_splits)
         # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits)
+        torch.testing.assert_close(out_splits_offsets[0], out_splits)
         # Check output offsets (row 2)
         out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
         # output offsets from `all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        self.assertEqual(out_splits_offsets[1][0], 0)
+        torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1])
         # Check data
         expected = torch.empty(out_numel, dtype=dtype, device=self.device)

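The offsets check above relies on the relationship between `torch.cumsum` (an inclusive scan) and the exclusive-scan offsets that `all_to_all_vdev` writes into row 1 of `out_splits_offsets`. A small standalone sketch of that relationship, with made-up split values:

```python
import torch

out_splits = torch.tensor([3, 1, 4, 2])       # elements received from each peer
inclusive = torch.cumsum(out_splits, dim=0)   # tensor([ 3,  4,  8, 10])
exclusive = torch.zeros_like(inclusive)
exclusive[1:] = inclusive[:-1]                # tensor([0, 3, 4, 8])
# exclusive[r] is where peer r's data starts in the output buffer,
# which is what the test compares against out_splits_offsets[1].
```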

@@ -502,7 +502,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
   m.def(
-      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name) -> ()");
   m.def(
       "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
   m.def(

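Note that the return type also changes from `Tensor(a!)` to `()`, matching `all_to_all_vdev_2d`: the op no longer returns the output tensor (the `return out;` in the C++ implementation is dropped below), so results are read from the mutated arguments. A sketch of the difference at the call site, reusing the tensor names from the test above:

```python
# Old schema: the op aliased and returned `out`.
#   result = torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)

# New schema: returns None; data lands in `out`, splits/offsets in `out_splits_offsets`.
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_splits, out_splits_offsets, group_name)
result = out
```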

@@ -180,16 +180,15 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
 // - input splits (IN)
 // - output splits (OUT) and
 // - source offsets (OUT).
-__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t team) {
+__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto input_splits = in_out_splits;
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int tid = threadIdx.x;
   CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
@@ -214,15 +213,15 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t te
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
-__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, nvshmem_team_t team) {
+__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
   int blocks_per_peer = max(gridDim.x / npes, 1);
@@ -277,29 +276,31 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
   return std::min(num_blocks, max_blocks);
 }
-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name) {
   /* Perform AllToAllv operation using NVSHMEM, with split information provided on device.
   * Arguments:
   * - `input` is the input tensor
   * - `out` is the output tensor
-  * - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order):
-      input splits (IN)
-      output splits (OUT) and
-      output offsets (OUT).
+  * - `in_splits` is a 1D tensor of size (npes), containing the input splits
+  * - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
+      output splits and output offsets.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
-  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
+  auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
   void* input_ptr = input.data_ptr();
   void* output_ptr = out.mutable_data_ptr();
-  int64_t* splits_ptr = (int64_t*)(in_out_splits.mutable_data_ptr());
+  int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
+  int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());
   TORCH_CHECK_EQ(input.device(), out.device());
   auto device = input.device();
@@ -311,7 +312,8 @@ at::Tensor all_to_all_vdev(
   // Exchange output splits and source offsets
   // Use collective launch because kernel involves nvshmem barrier
   void* args0[] = {
-      &splits_ptr,
+      &in_splits_ptr,
+      &out_splits_offsets_ptr,
       &team};
   nvshmemx_collective_launch(
       (const void*)exchangeSplitAndOffset,
@@ -335,7 +337,7 @@ at::Tensor all_to_all_vdev(
   void* args1[] = {
       &input_ptr,
       &output_ptr,
-      &splits_ptr,
+      &out_splits_offsets_ptr,
       &stride_bytes,
       &team};
   nvshmemx_collective_launch(
@@ -345,7 +347,6 @@ at::Tensor all_to_all_vdev(
       args1,
       0,
       stream);
-  return out;
 }
 // Start of `all_to_all_vdev_2d`

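Given the (2, npes) layout documented above, a hedged sketch of how a caller might slice the flat output buffer back into per-peer chunks after the collective (names follow the test; this helper is illustrative, not part of the PR):

```python
def split_output_by_peer(out, out_splits_offsets):
    # Row 0: number of elements received from each peer.
    # Row 1: exclusive-scan offsets into `out` where each peer's data starts.
    splits = out_splits_offsets[0].tolist()
    offsets = out_splits_offsets[1].tolist()
    return [out[off : off + sp] for off, sp in zip(offsets, splits)]
```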

@@ -32,10 +32,11 @@ at::Tensor nvshmem_all_to_all(
     at::Tensor& out,
     std::string group_name);
-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name);
 void all_to_all_vdev_2d(