Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[a2av] Separate in/out splits into two tensors (#163837)
Old signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name)`

New signature:
`all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name)`

That is, `in_out_splits` is split into an IN tensor and an OUT tensor, so that the TORCH_LIBRARY signature can be defined more precisely, and so that the op lines up with the 2D version (`all_to_all_vdev_2d`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163837
Approved by: https://github.com/fduwjj
ghstack dependencies: #163886
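For illustration, a minimal sketch of how a call site changes under the new signature. This is not an excerpt from the PR; it assumes a process group and NVSHMEM-backed symmetric memory are already initialized, and that `inp`, `out`, `world_size`, `device`, and `group_name` are set up as in the test below:

```python
import torch
import torch.distributed._symmetric_memory as symm_mem  # module name assumed from the test's `symm_mem` alias

# Old call: one (3, world_size) tensor carried input splits, output splits, and output offsets.
# in_out_splits = symm_mem.empty((3, world_size), dtype=torch.int64, device=device)
# torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)

# New call: input splits are a read-only 1D tensor; output splits and offsets
# come back in a separate (2, world_size) tensor, and nothing is returned.
in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=device)
out_splits_offsets = symm_mem.empty((2, world_size), dtype=torch.int64, device=device)
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_splits, out_splits_offsets, group_name)
```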
```diff
@@ -299,28 +299,33 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest):
             torch.randn(max_inp_numel, dtype=dtype, device=self.device)
         )
         out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
-        in_out_splits = symm_mem.empty(
-            (3, self.world_size), dtype=torch.int64, device=self.device
+        in_splits = symm_mem.empty(
+            self.world_size, dtype=torch.int64, device=self.device
         )
+        out_splits_offsets = symm_mem.empty(
+            (2, self.world_size), dtype=torch.int64, device=self.device
+        )
-        # Row 0 is input splits
-        in_out_splits[0].copy_(inp_splits)
+        in_splits.copy_(inp_splits)

        # Sync all ranks to ensure remote tensors are allocated
        dist.barrier()

-        torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)
+        torch.ops.symm_mem.all_to_all_vdev(
+            inp, out, in_splits, out_splits_offsets, group_name
+        )

        # Check input splits (row 0) -- should not change
-        torch.testing.assert_close(in_out_splits[0], inp_splits)
+        torch.testing.assert_close(in_splits, inp_splits)

        # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits)
+        torch.testing.assert_close(out_splits_offsets[0], out_splits)

        # Check output offsets (row 2)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # output offsets from `all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        self.assertEqual(out_splits_offsets[1][0], 0)
+        torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1])

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
```
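The offsets row that `all_to_all_vdev` writes back is an exclusive scan of the output splits, which is exactly what the updated assertions encode. A small CPU-only illustration of that relationship (the split values are made up):

```python
import torch

out_splits = torch.tensor([3, 1, 4, 2], dtype=torch.int64)

inclusive = torch.cumsum(out_splits, dim=0)   # tensor([ 3,  4,  8, 10])
exclusive = torch.zeros_like(out_splits)
exclusive[1:] = inclusive[:-1]                # tensor([0, 3, 4, 8])

# Mirrors the test: offsets[0] == 0 and offsets[1:] == inclusive[:-1].
assert exclusive[0] == 0
torch.testing.assert_close(exclusive[1:], inclusive[:-1])
```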
```diff
@@ -502,7 +502,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
   m.def(
-      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name) -> ()");
   m.def(
       "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
   m.def(
```
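In the new schema, `Tensor in_splits` is a read-only input, `Tensor(a!) out_splits_offsets` is an aliased tensor the op mutates in place, and `-> ()` means the op no longer returns anything. As a rough sketch, the same schema string could be parsed by registering it from Python under a made-up namespace (the real op is registered in C++ via `TORCH_LIBRARY_FRAGMENT(symm_mem, m)`):

```python
from torch.library import Library

# "demo_symm_mem" is a hypothetical namespace used only for illustration.
lib = Library("demo_symm_mem", "DEF")
lib.define(
    "all_to_all_vdev(Tensor input, Tensor(a!) out, "
    "Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name) -> ()"
)
```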
```diff
@@ -180,16 +180,15 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
 // - input splits (IN)
 // - output splits (OUT) and
 // - source offsets (OUT).
-__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t team) {
+__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto input_splits = in_out_splits;
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int tid = threadIdx.x;

   CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
```
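The kernel-side pointer arithmetic changes accordingly: output splits used to live at offset `npes` and source offsets at offset `2 * npes` inside one flat `in_out_splits` buffer; now they are rows 0 and 1 of `out_splits_offsets`, while input splits arrive through their own pointer. A CPU-only sketch of the two layouts, with plain tensors standing in for device pointers and made-up values:

```python
import torch

npes = 4

# Old layout: one flat buffer of 3 * npes int64 values.
in_out_splits = torch.arange(3 * npes, dtype=torch.int64)
old_input_splits = in_out_splits[:npes]             # in_out_splits
old_output_splits = in_out_splits[npes:2 * npes]    # in_out_splits + npes
old_source_offsets = in_out_splits[2 * npes:]       # in_out_splits + npes * 2

# New layout: a separate (npes,) input tensor and a (2, npes) output tensor.
input_splits = old_input_splits.clone()
out_splits_offsets = torch.stack([old_output_splits, old_source_offsets])
output_splits = out_splits_offsets[0]               # out_splits_offsets
source_offsets = out_splits_offsets[1]              # out_splits_offsets + npes
```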
```diff
@@ -214,15 +213,15 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, nvshmem_team_t te
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
-__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, nvshmem_team_t team) {
+__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
 #ifndef _NVSHMEM_DEVICELIB_SUPPORTED
   CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
 #else
   CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
   int mype = nvshmem_team_my_pe(team);
   int npes = nvshmem_team_n_pes(team);
-  auto output_splits = in_out_splits + npes;
-  auto source_offsets = in_out_splits + npes * 2;
+  auto output_splits = out_splits_offsets;
+  auto source_offsets = out_splits_offsets + npes;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
   int blocks_per_peer = max(gridDim.x / npes, 1);
```
```diff
@@ -277,29 +276,31 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
   return std::min(num_blocks, max_blocks);
 }

-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name) {
   /* Perform AllToAllv operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
    *  - `out` is the output tensor
-   *  - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order):
-   *    input splits (IN)
-   *    output splits (OUT) and
-   *    output offsets (OUT).
+   *  - `in_splits` is a 1D tensor of size (npes), containing the input splits
+   *  - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
+   *    output splits and output offsets.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
-  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
+  auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();

   void* input_ptr = input.data_ptr();
   void* output_ptr = out.mutable_data_ptr();
-  int64_t* splits_ptr = (int64_t*)(in_out_splits.mutable_data_ptr());
+  int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
+  int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());

   TORCH_CHECK_EQ(input.device(), out.device());
   auto device = input.device();
```
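Since the host function now returns nothing, callers read everything they need from `out_splits_offsets` after the call: row 0 says how many elements arrived from each peer, and row 1 (an exclusive scan of row 0) says where each peer's chunk starts in `out`. A CPU-only sketch of consuming a packed output buffer this way, with made-up values:

```python
import torch

out_splits = torch.tensor([2, 0, 3, 1], dtype=torch.int64)   # row 0: per-peer counts
out_offsets = torch.tensor([0, 2, 2, 5], dtype=torch.int64)  # row 1: exclusive scan of row 0
out = torch.arange(6, dtype=torch.float32)                   # stand-in for the packed output tensor

# Slice the packed output back into one chunk per peer.
per_peer = [
    out[o : o + s]
    for o, s in zip(out_offsets.tolist(), out_splits.tolist())
]
# per_peer[p] holds the data received from peer p.
```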
```diff
@@ -311,7 +312,8 @@ at::Tensor all_to_all_vdev(
   // Exchange output splits and source offsets
   // Use collective launch because kernel involves nvshmem barrier
   void* args0[] = {
-      &splits_ptr,
+      &in_splits_ptr,
+      &out_splits_offsets_ptr,
       &team};
   nvshmemx_collective_launch(
       (const void*)exchangeSplitAndOffset,
```
```diff
@@ -335,7 +337,7 @@ at::Tensor all_to_all_vdev(
   void* args1[] = {
       &input_ptr,
       &output_ptr,
-      &splits_ptr,
+      &out_splits_offsets_ptr,
       &stride_bytes,
       &team};
   nvshmemx_collective_launch(
```
```diff
@@ -345,7 +347,6 @@ at::Tensor all_to_all_vdev(
       args1,
       0,
       stream);
-  return out;
 }

 // Start of `all_to_all_vdev_2d`
```
```diff
@@ -32,10 +32,11 @@ at::Tensor nvshmem_all_to_all(
     at::Tensor& out,
     std::string group_name);

-at::Tensor all_to_all_vdev(
+void all_to_all_vdev(
     at::Tensor& input,
     at::Tensor& out,
-    at::Tensor& in_out_splits,
+    at::Tensor& in_splits,
+    at::Tensor& out_splits_offsets,
     std::string group_name);

 void all_to_all_vdev_2d(
```