Compare commits

...

3 Commits

SHA1        Message                                  Date
ece42ed689  Update [ghstack-poisoned]                2025-10-31 17:42:50 -07:00
05b3b1024b  Update [ghstack-poisoned]                2025-10-31 17:38:04 -07:00
2fd0573646  Update (base update) [ghstack-poisoned]  2025-10-31 17:38:04 -07:00
4 changed files with 1004 additions and 54 deletions

@@ -231,6 +231,36 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
dist.barrier()
def get_occurrence_numbers(tensor):
"""
Transform a tensor to show the 0-based occurrence index of each element within its value group.
Example: tensor([1, 2, 1, 3, 1, 2]) -> tensor([0, 0, 1, 0, 2, 1])
"""
device = tensor.device
# Get unique values and their inverse mapping
unique_vals, inverse = torch.unique(tensor, return_inverse=True)
# Create a tensor to count occurrences for each unique value
n_unique = len(unique_vals)
n_elements = len(tensor)
# Create a matrix where each row corresponds to a unique value
# and columns correspond to positions in the original tensor
indicator_matrix = torch.zeros(
n_unique, n_elements, dtype=torch.float, device=device
)
indicator_matrix[inverse, torch.arange(n_elements)] = 1.0
# Cumulative sum along columns gives us occurrence numbers
occurrence_counts = torch.cumsum(indicator_matrix, dim=1) - indicator_matrix
# Extract the occurrence number for each position
result = occurrence_counts[inverse, torch.arange(n_elements, device=device)]
return result.long()
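# Worked trace of the helper above (derived from the code, for illustration):
#   input:   tensor([1, 2, 1, 3, 1, 2])
#   unique_vals = [1, 2, 3]; inverse = [0, 1, 0, 2, 0, 1]
#   indicator_matrix (one row per unique value, one column per position):
#     value 1: [1, 0, 1, 0, 1, 0]
#     value 2: [0, 1, 0, 0, 0, 1]
#     value 3: [0, 0, 0, 1, 0, 0]
#   cumsum(dim=1) - indicator, gathered at (inverse[i], i):
#     result:  tensor([0, 0, 1, 0, 2, 1])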
@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
@@ -557,6 +587,314 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest):
# Check data
torch.testing.assert_close(out_expected, out[:out_numel])
@skipIfRocm
def test_make_a2a_exchange_plan(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
# Number of elements for each peer is randomly drawn from [0, k)
k = 10
orig_inp_splits = torch.randint(k, (self.world_size,), device=self.device)
# Create symm_mem tensors
in_splits = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
src_offsets = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
out_splits = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
dst_offsets = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
in_splits.copy_(orig_inp_splits)
# Sync all ranks to ensure remote tensors are allocated
dist.barrier()
symm_mem.make_a2a_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
# Check input splits -- should not change
torch.testing.assert_close(in_splits, orig_inp_splits)
# Check output splits
# Exchange input splits to get output splits
expected_out_splits = torch.zeros_like(orig_inp_splits)
dist.all_to_all_single(expected_out_splits, orig_inp_splits)
torch.testing.assert_close(expected_out_splits, out_splits)
# Check src offsets
orig_src_offsets = torch.cumsum(orig_inp_splits, dim=0) # inclusive scan
# Make it exclusive
orig_src_offsets = torch.cat(
[torch.zeros(1, device=self.device), orig_src_offsets[:-1]]
).to(torch.int64)
expected_src_offsets = torch.empty_like(orig_src_offsets)
dist.all_to_all_single(expected_src_offsets, orig_src_offsets)
torch.testing.assert_close(src_offsets, expected_src_offsets)
# Check dst offsets
expected_dst_offsets = torch.cumsum(
expected_out_splits, dim=0
) # inclusive scan
self.assertEqual(dst_offsets[0], 0)
torch.testing.assert_close(dst_offsets[1:], expected_dst_offsets[:-1])
@skipIfRocm
def test_a2a_with_exchange_plan(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
# Number of elements for each peer is randomly drawn from [0, k)
k = 10
orig_inp_splits = torch.randint(k, (self.world_size,), device=self.device)
# Create splits and offsets
in_splits = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
src_offsets = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
out_splits = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
dst_offsets = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
# Create data
# Max number of input elements (must be a constant across ranks for symmetric memory allocation)
max_inp_numel = k * self.world_size
# Max number of output elements (must be a constant across ranks for symmetric memory allocation)
overflow_factor = self.world_size # worst case: one rank receives all data
max_out_numel = max_inp_numel * overflow_factor
dtype = torch.float
inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_(
torch.randn(max_inp_numel, dtype=dtype, device=self.device)
)
out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
in_splits.copy_(orig_inp_splits)
# Sync all ranks to ensure remote tensors are allocated
dist.barrier()
# Create exchange plan
plan = symm_mem.make_a2a_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
# Prepare expected output
inp_numel = in_splits.sum().item()
out_numel = out_splits.sum().item()
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
dist.all_to_all_single(
expected, inp[:inp_numel], out_splits.tolist(), in_splits.tolist()
)
# Exchange data with plan
# Loop a couple times to ensure the plan is reusable
for _ in range(3):
symm_mem.all_to_all_v(inp, out, plan, group_name)
torch.testing.assert_close(out[:out_numel], expected)
@skipIfRocm
@parametrize("align", [1]) # `major_align` of output
def test_make_a2a_2d_exchange_plan(self, align: int) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
# Number of experts per rank
ne = 8
nsplits = ne * self.world_size
# Number of elements for each expert is randomly drawn from [0, k)
k = 10
orig_inp_splits = torch.randint(
k, (nsplits,), dtype=torch.int64, device=self.device
)
# Create symm_mem tensors
in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=self.device)
src_offsets = symm_mem.empty(nsplits, dtype=torch.int64, device=self.device)
out_splits = symm_mem.empty(
nsplits, dtype=torch.int64, device=self.device
).fill_(0)
dst_offsets = symm_mem.empty(
nsplits, dtype=torch.int64, device=self.device
).fill_(0)
in_splits.copy_(orig_inp_splits)
# Sync all ranks to ensure remote tensors are allocated
dist.barrier()
plan = symm_mem.make_a2a_2d_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
# Exchange input splits to get output splits
expected_out_splits = torch.zeros_like(orig_inp_splits)
dist.all_to_all_single(expected_out_splits, orig_inp_splits)
# We do a .t() here because there is a rank-major to expert-major shuffle
expected_out_splits = expected_out_splits.reshape(self.world_size, ne).t()
torch.testing.assert_close(plan.out_splits, expected_out_splits.reshape(-1))
# Check dst offsets
out_split_list = expected_out_splits.tolist()
for i in range(ne):
expert_sum = 0
for j in range(self.world_size):
expert_sum += out_split_list[i][j]
# # Align up expert_sum
# expert_sum_aligned = (expert_sum + align - 1) // align * align
# # If 0, make it at least `align` (bc cutlass currently does not support empty bins)
# expert_sum_aligned = max(expert_sum_aligned, align)
# # last element absorbs the padding
# out_split_list[i][-1] += expert_sum_aligned - expert_sum
out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan
# Make it exclusive scan because that's what `all_to_all_vdev_2d` returns
out_offsets = torch.cat(
[torch.zeros(1, device=self.device), out_offsets[:-1]]
).to(torch.int64)
expected_dst_offsets = torch.empty(
nsplits, dtype=torch.int64, device=self.device
)
dist.all_to_all_single(
expected_dst_offsets,
out_offsets.reshape(ne, self.world_size).t().contiguous(),
)
torch.testing.assert_close(
expected_dst_offsets,
plan.dst_offsets,
msg=f"""
Expecting
{expected_dst_offsets}
Got
{plan.dst_offsets}""",
)
@skipIfRocm
def test_all_to_all_v_2d_index_push(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
# Number of experts per rank
ne = 4
tot_experts = ne * self.world_size
# Create topk indices of shape (n_tokens, topk)
topk = 2
n_tokens = 128
topk_indices = torch.randint(
tot_experts, (n_tokens, topk), dtype=torch.int64, device=self.device
)
# Convert indices to splits
orig_inp_splits = torch.histc(
topk_indices,
bins=tot_experts,
)
# Create symm_mem tensors
in_splits = symm_mem.empty(
tot_experts, dtype=torch.int64, device=self.device
).copy_(orig_inp_splits)
src_offsets = symm_mem.empty(tot_experts, dtype=torch.int64, device=self.device)
out_splits = symm_mem.empty(
tot_experts, dtype=torch.int64, device=self.device
).fill_(0)
dst_offsets = symm_mem.empty(
tot_experts, dtype=torch.int64, device=self.device
).fill_(0)
# Sync all ranks to ensure remote tensors are allocated
dist.barrier()
plan = symm_mem.make_a2a_2d_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
# Create data
max_out_tokens = n_tokens * self.world_size
dtype = torch.float
hid_dim = 1024
inp = symm_mem.empty(n_tokens, hid_dim, dtype=dtype, device=self.device).copy_(
torch.randn(n_tokens, hid_dim, dtype=dtype, device=self.device)
)
out = symm_mem.empty(
max_out_tokens, hid_dim, dtype=dtype, device=self.device
).fill_(-1)
# Figure out rank of each token in its expert chunk
occurrences = get_occurrence_numbers(topk_indices.view(-1))
# Number of CUDA blocks (random choice)
n_blocks = 2
# Evenly spread token to CUDA blocks
tokens_per_block = n_tokens // n_blocks
# Start offset of each CUDA block
b_start = torch.arange(
0, n_tokens, tokens_per_block, dtype=torch.int64, device=self.device
)
# Number of tokens for each CUDA block
b_len = torch.full(
(n_blocks,), tokens_per_block, dtype=torch.int64, device=self.device
)
# Ready signal for each CUDA block. In this test we mark all tokens as ready in one shot
b_head = b_start + b_len
dist.barrier()
torch.ops.symm_mem._all_to_all_v_2d_index_push(
inp,
out,
topk_indices,
occurrences,
plan.dst_offsets,
group_name,
b_start,
b_len,
b_head,
)
# Check data using all_to_all_vdev_2d
# Token sequence is inflated topk times
expanded_seqlen = n_tokens * topk
sorted_indices = torch.argsort(topk_indices.view(-1))
expanded_inp = symm_mem.empty(
expanded_seqlen, hid_dim, dtype=dtype, device=self.device
).copy_(inp[sorted_indices // topk])
overflow = 2
expected_out = symm_mem.empty(
expanded_seqlen * overflow, hid_dim, dtype=dtype, device=self.device
)
out_splits_offsets = symm_mem.empty(
(2, tot_experts), dtype=torch.int64, device=self.device
)
dist.barrier()
torch.ops.symm_mem.all_to_all_vdev_2d(
expanded_inp, expected_out, in_splits, out_splits_offsets, group_name
)
# Check data
out_len = out_splits_offsets[1][-1] + out_splits_offsets[0][-1]
torch.testing.assert_close(out[:out_len], expected_out[:out_len])
# Helper function used by multiple tests
def dispatch_then_combine(device, align: int, group) -> None:

@@ -510,6 +510,14 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
m.def(
"all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
m.def(
"_make_a2a_exchange_plan(Tensor in_splits, Tensor(a!) src_offsets, Tensor(a!) out_splits, Tensor(a!) dst_offsets, str group_name) -> ()");
m.def(
"_all_to_all_get(Tensor input, Tensor(a!) out, Tensor src_offsets, Tensor out_splits, Tensor dst_offsets, str group_name, Tensor? b_start, Tensor? b_len, Tensor? b_head) -> ()");
m.def(
"_make_a2a_2d_exchange_plan(Tensor in_splits, Tensor(a!) src_offsets, Tensor(a!) out_splits, Tensor(a!) dst_offsets, str group_name) -> ()");
m.def(
"_all_to_all_v_2d_index_push(Tensor input, Tensor(a!) out, Tensor topk_indices, Tensor occurrences, Tensor dst_offsets, str group_name, Tensor b_start, Tensor b_len, Tensor b_head) -> ()");
}
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {

@@ -198,19 +198,16 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
}
// This kernel is used to exchange output splits and source offsets between peers.
// `in_out_splits` is of size (3, npes) and contains:
// - input splits (IN)
// - output splits (OUT) and
// - source offsets (OUT).
__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
__device__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* source_offsets, int64_t* output_splits, nvshmem_team_t team) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
int mype = nvshmem_team_my_pe(team);
int npes = nvshmem_team_n_pes(team);
auto output_splits = out_splits_offsets;
auto source_offsets = out_splits_offsets + npes;
int tid = threadIdx.x;
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
@@ -232,27 +229,38 @@ __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_split
#endif
}
// This is a kernel wrapper for `exchangeSplitAndOffset`.
__global__ void exchangeSplitAndOffsetKernel(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, nvshmem_team_t team) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
exchangeSplitAndOffset(in_splits, src_offsets, out_splits, team);
#endif
}
// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
// `stride` is the stride at dim 0, unit in byte.
__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
__global__ void allToAllGet(void *send_data, void *recv_data, int64_t* source_offsets, int64_t* output_splits, int64_t* dst_offsets, size_t stride, nvshmem_team_t team) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
int mype = nvshmem_team_my_pe(team);
int npes = nvshmem_team_n_pes(team);
auto output_splits = out_splits_offsets;
auto source_offsets = out_splits_offsets + npes;
int bid = blockIdx.x;
int tid = threadIdx.x;
int blocks_per_peer = max(gridDim.x / npes, 1);
// Calculate the output offsets
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
// Calculate the output offsets if not provided
__shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
prefixSum(peer_offsets, output_splits, npes);
__syncthreads();
int64_t* out_offsets = dst_offsets;
if (out_offsets == nullptr) {
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
prefixSum(peer_offsets, output_splits, npes);
__syncthreads();
out_offsets = peer_offsets;
}
// Target a different peer based on bid
for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) {
@@ -267,22 +275,81 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_
// This block's offset in the data from `peer`
auto block_offset = block_size * (bid % blocks_per_peer);
auto source_offset = source_offsets[peer] * stride + block_offset;
auto write_offset = peer_offsets[peer] * stride + block_offset;
auto write_offset = out_offsets[peer] * stride + block_offset;
nvshmemx_getmem_nbi_block(
(char*)recv_data + write_offset,
(char*)send_data + source_offset,
block_size,
peer_global);
}
// Write out the output offsets (to the scratchpad line)
if (bid == 0 && tid < npes) {
source_offsets[tid] = peer_offsets[tid];
// Write out the output offsets if not provided
if (dst_offsets == nullptr) {
if (bid == 0 && tid < npes) {
// source_offsets aliases the dst_offsets space when dst_offsets is not provided
source_offsets[tid] = out_offsets[tid];
}
}
// Make sure getmem_nbi calls finish
nvshmem_quiet();
#endif
}
// This kernel extends `allToAllGet` with an out signal that indicates readiness of a "token" after each fetch.
// `b_start`, `b_len` and `b_head` are fed into `_allToAllV_2d_index_push_kernel`.
__global__ void allToAllGet_signal_out(
void *send_data, void *recv_data, int64_t* source_offsets, int64_t* output_splits, int64_t* dst_offsets, size_t stride, nvshmem_team_t team,
int64_t* b_start, int64_t* b_len, int64_t* b_head) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
int mype = nvshmem_team_my_pe(team);
int npes = nvshmem_team_n_pes(team);
int bid = blockIdx.x;
CUDA_KERNEL_ASSERT(gridDim.x % npes == 0 && " Number of blocks must be multiple of npes\n");
int blocks_per_peer = gridDim.x / npes;
// dst_offset must be provided to this kernel e.g. via exchange plan
CUDA_KERNEL_ASSERT(dst_offsets != nullptr && " dst_offset must be provided\n");
int64_t* out_offsets = dst_offsets;
// Target a different peer based on bid, in shifting manner
int peer_shift = bid / blocks_per_peer;
int peer = (mype + peer_shift) % npes;
auto peer_global = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD);
// Total number of tokens from `peer`
auto peer_tokens = output_splits[peer];
// Amount to get from `peer` in this block
int64_t block_tokens = peer_tokens / blocks_per_peer;
// Being lazy here, we should handle the residual if the division is not exact
CUDA_KERNEL_ASSERT(block_tokens * blocks_per_peer == peer_tokens);
// Assign to b_len
b_len[bid] = block_tokens;
// This block's offset in the data from `peer`, all 3 offsets below are in unit of token
auto block_offset = block_tokens * (bid % blocks_per_peer);
auto source_offset = source_offsets[peer] + block_offset;
int64_t write_offset = out_offsets[peer] + block_offset;
// Assign to b_start
b_start[bid] = write_offset;
auto b_head_ptr = b_head + bid;
// Now let's start data fetch, token by token
for (int i = 0; i < block_tokens; i++) {
auto write_head = write_offset + i;
// Use blocking API for now
nvshmemx_getmem_block(
(char*)recv_data + write_head * stride,
(char*)send_data + (source_offset + i) * stride,
stride, // byte size of 1 token
peer_global);
// Advance ready signal, use release here as a memory fence, scope is local GPU.
// Writing `write_head + 1` because that's the protocol with downstream kernel.
asm volatile("st.release.gpu.global.s64 [%0], %1;" :: "l"(__cvta_generic_to_global(b_head_ptr)), "l"(write_head + 1) : "memory");
}
#endif
}
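// Summary of the b_start/b_len/b_head handshake used above and by
// `_allToAllV_2d_index_push_kernel` below: each CUDA block `bid` owns the token
// range [b_start[bid], b_start[bid] + b_len[bid]). The producer advances
// b_head[bid] to (index of the most recently ready token) + 1 with a release
// store at GPU scope; the consumer spins on a volatile load of b_head[bid] and
// issues an acquire fence (GPU scope) before reading the token data.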
static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
// Check user setting first
int num_blocks = c10d::symmetric_memory::getenv_nblocks();
@@ -298,6 +365,88 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
return std::min(num_blocks, max_blocks);
}
// Perform a 1D AllToAllv shuffle, using the provided source offset information and NVSHMEM get operations.
// ** Caller must ensure that the splits and offsets are (i) symmetric addresses and (ii) have been rendezvoused. **
void _all_to_all_get_inner(
at::Tensor& input,
at::Tensor& out,
int64_t* src_offsets_ptr,
const int64_t* out_splits_ptr,
int64_t* dst_offsets_ptr,
std::string group_name,
std::optional<nvshmem_team_t> team,
const std::optional<at::Tensor>& b_start = std::nullopt,
const std::optional<at::Tensor>& b_len = std::nullopt,
const std::optional<at::Tensor>& b_head = std::nullopt) {
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
c10d::symmetric_memory::rendezvous(out, group_name);
// Verify inputs
TORCH_CHECK_EQ(input.device(), out.device());
TORCH_CHECK(input.dtype() == out.dtype(), "input and out must have the same dtype");
TORCH_CHECK(input.stride(0) == out.stride(0), "input and out must have the same stride at dim 0");
TORCH_CHECK(input.is_contiguous() && out.is_contiguous(), "input and out must be contiguous");
auto device = input.device();
c10::cuda::CUDAGuard guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device.index());
// Get team if not provided
nvshmem_team_t team_to_use;
if (team.has_value()) {
team_to_use = team.value();
} else {
auto& team_manager = TeamManager::get(device);
team_to_use = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
}
bool has_signal = b_head.has_value();
// CTA Tuning
auto input_size = input.numel() * input.element_size();
int num_blocks = has_signal ?
b_head.value().size(0) :
get_a2a_nblocks(
input_size,
input_hdl->get_world_size(),
input_hdl->world_within_direct_access());
// Stride at dim 0
size_t stride_bytes = input.stride(0) * input.element_size();
auto input_ptr = input.const_data_ptr();
auto output_ptr = out.mutable_data_ptr();
// Required args
std::vector<void*> args = {
&input_ptr,
&output_ptr,
&src_offsets_ptr,
&out_splits_ptr,
&dst_offsets_ptr,
&stride_bytes,
&team_to_use};
// Optional args for signals
void *b_start_ptr, *b_len_ptr, *b_head_ptr;
if (has_signal) {
b_start_ptr = b_start.value().mutable_data_ptr();
b_len_ptr = b_len.value().mutable_data_ptr();
b_head_ptr = b_head.value().mutable_data_ptr();
args.push_back(b_start_ptr);
args.push_back(b_len_ptr);
args.push_back(b_head_ptr);
}
auto functor = has_signal ? (const void*)allToAllGet_signal_out : (const void*)allToAllGet;
C10_CUDA_CHECK(cudaLaunchKernel(
functor,
dim3(num_blocks),
dim3(THREADS_PER_BLOCK),
args.data(),
0,
stream));
}
void all_to_all_vdev(
at::Tensor& input,
at::Tensor& out,
@@ -312,63 +461,40 @@ void all_to_all_vdev(
* - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
output splits and output offsets.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();
auto symm_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
int world_size = symm_hdl->get_world_size();
void* input_ptr = input.data_ptr();
void* output_ptr = out.mutable_data_ptr();
int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());
TORCH_CHECK_EQ(input.device(), out.device());
auto device = input.device();
TORCH_CHECK_EQ(in_splits.device(), out_splits_offsets.device());
auto device = in_splits.device();
c10::cuda::CUDAGuard guard(device);
auto& team_manager = TeamManager::get(device);
auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
auto team = team_manager.get_team(group_name, symm_hdl->get_rank_to_global_rank());
auto stream = at::cuda::getCurrentCUDAStream(device.index());
// Exchange output splits and source offsets
// Use collective launch because kernel involves nvshmem barrier
// Borrowing the space of out_offsets as a temporary exchange pad for source offsets.
auto tmp_src_offsets_ptr = out_splits_offsets_ptr + world_size;
void* args0[] = {
&in_splits_ptr,
&tmp_src_offsets_ptr,
&out_splits_offsets_ptr,
&team};
// Use collective launch because kernel involves nvshmem barrier
nvshmemx_collective_launch(
(const void*)exchangeSplitAndOffset,
(const void*)exchangeSplitAndOffsetKernel,
dim3(1),
dim3(THREADS_PER_BLOCK),
args0,
0,
stream);
// CTA Tuning
auto input_size = input.numel() * input.element_size();
int num_blocks = get_a2a_nblocks(
input_size,
input_hdl->get_world_size(),
input_hdl->world_within_direct_access());
// Stride at dim 0 (assuming input is contiguous, TODO)
size_t stride_bytes = input.stride(0) * input.element_size();
// All to all data exchange
void* args1[] = {
&input_ptr,
&output_ptr,
&out_splits_offsets_ptr,
&stride_bytes,
&team};
nvshmemx_collective_launch(
(const void*)allToAllV,
dim3(num_blocks),
dim3(THREADS_PER_BLOCK),
args1,
0,
stream);
// Get data based on exchange plan
_all_to_all_get_inner(
input, out, tmp_src_offsets_ptr, out_splits_offsets_ptr, nullptr, group_name, team);
}
// Start of `all_to_all_vdev_2d`
@@ -863,6 +989,376 @@ void all_to_all_vdev_2d_offset(
0,
stream);
}
// This kernel exchanges output splits and source offsets between peers, then computes the destination offsets.
__global__ void makeExchangePlan(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, int64_t* dst_offsets, nvshmem_team_t team) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
// Given the input splits, exchange them to obtain the source offsets and output splits
exchangeSplitAndOffset(in_splits, src_offsets, out_splits, team);
// Calculate the output offsets
int npes = nvshmem_team_n_pes(team);
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
__shared__ int64_t out_offsets[THREADS_PER_BLOCK];
prefixSum(out_offsets, out_splits, npes);
__syncthreads();
// Write out the output offsets
int tid = threadIdx.x;
if (tid < npes) {
dst_offsets[tid] = out_offsets[tid];
}
#endif
}
void _make_a2a_exchange_plan(
at::Tensor& in_splits,
at::Tensor& src_offsets,
at::Tensor& out_splits,
at::Tensor& dst_offsets,
std::string group_name) {
// Make an exchange plan for a 1D AllToAllv shuffle operation.
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
auto src_offsets_hdl = c10d::symmetric_memory::rendezvous(src_offsets, group_name);
auto out_splits_hdl = c10d::symmetric_memory::rendezvous(out_splits, group_name);
auto dst_offsets_hdl = c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
// Verify inputs
auto npes = in_splits_hdl->get_world_size();
TORCH_CHECK(npes <= THREADS_PER_BLOCK, "Number of peers must not exceed THREADS_PER_BLOCK: ", THREADS_PER_BLOCK);
TORCH_CHECK(in_splits.size(0) == npes && src_offsets.size(0) == npes && out_splits.size(0) == npes && dst_offsets.size(0) == npes,
"in_splits, src_offsets, out_splits and dst_offsets must have the same size as world_size");
TORCH_CHECK(in_splits.scalar_type() == at::kLong && src_offsets.scalar_type() == at::kLong
&& out_splits.scalar_type() == at::kLong && dst_offsets.scalar_type() == at::kLong,
"splits and offsets must be int64");
auto in_splits_ptr = in_splits.const_data_ptr<int64_t>();
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
auto out_splits_ptr = out_splits.mutable_data_ptr<int64_t>();
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
auto device = in_splits.device();
c10::cuda::CUDAGuard guard(device);
auto& team_manager = TeamManager::get(device);
auto team = team_manager.get_team(group_name, in_splits_hdl->get_rank_to_global_rank());
auto stream = at::cuda::getCurrentCUDAStream(device.index());
// Exchange output splits and source offsets
// Use collective launch because kernel involves nvshmem barrier
void* args0[] = {
&in_splits_ptr,
&src_offsets_ptr,
&out_splits_ptr,
&dst_offsets_ptr,
&team};
nvshmemx_collective_launch(
(const void*)makeExchangePlan,
dim3(1),
dim3(THREADS_PER_BLOCK),
args0,
0,
stream);
}
void _all_to_all_get(
at::Tensor& input,
at::Tensor& out,
at::Tensor& src_offsets,
at::Tensor& out_splits,
at::Tensor& dst_offsets,
std::string group_name,
const std::optional<at::Tensor>& b_start = std::nullopt,
const std::optional<at::Tensor>& b_len = std::nullopt,
const std::optional<at::Tensor>& b_head = std::nullopt) {
// Perform a 1D AllToAllv shuffle, using the provided source offset information and NVSHMEM get operations.
c10d::symmetric_memory::rendezvous(src_offsets, group_name);
c10d::symmetric_memory::rendezvous(out_splits, group_name);
c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
auto out_splits_ptr = out_splits.const_data_ptr<int64_t>();
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
_all_to_all_get_inner(
input, out, src_offsets_ptr, out_splits_ptr, dst_offsets_ptr, group_name,
/*team=*/ std::nullopt, // determined inside
b_start, b_len, b_head // optional signals
);
}
/* 2D all-to-all-v exchange plan */
#define MAX_N_PEERS (THREADS_PER_BLOCK / WARP_SIZE)
#define MAX_LOCAL_EXPERTS 8
#define MAX_N_EXPERTS (MAX_N_PEERS * MAX_LOCAL_EXPERTS)
// This kernel is used to exchange output splits and dest offsets between peers.
__global__ void make2dExchangePlan(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, int64_t* dst_offsets, nvshmem_team_t team, int ne) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
int peer = threadIdx.x / WARP_SIZE;
int my_pe = nvshmem_team_my_pe(team);
int npes = nvshmem_team_n_pes(team);
int nsplits = ne * npes;
CUDA_KERNEL_ASSERT(npes <= MAX_N_PEERS && " Number of peers must not exceed MAX_N_PEERS\n");
CUDA_KERNEL_ASSERT(nsplits <= MAX_N_EXPERTS && " Number of splits must not exceed MAX_N_EXPERTS\n");
// Gather input splits from peers
__shared__ int64_t in_splits_all[MAX_N_PEERS][MAX_N_EXPERTS];
// Use 1 warp per peer, to get into shared memory
if (peer < npes) {
auto global_peer_id = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD);
nvshmemx_int64_get_warp(in_splits_all[peer], in_splits, nsplits, global_peer_id);
}
__syncthreads();
// Write to output splits
int local_expert = threadIdx.x % WARP_SIZE;
if (local_expert < ne) {
// experts I own
int expert_id = ne * my_pe + local_expert;
// out_splits is (ne, npes) flattened
out_splits[local_expert * npes + peer] = in_splits_all[peer][expert_id];
}
__syncthreads();
// This is the aggregate per expert. There are a total of npes * ne experts,
// and we arrange it as a 2D matrix
__shared__ int64_t aggregate_per_expert[MAX_N_PEERS][MAX_LOCAL_EXPERTS];
// Calculate the cusum of splits for an expert
// One thread per expert, for all experts
auto& cusum_per_expert = in_splits_all; // in-place prefix sum
if (peer < npes && local_expert < ne) {
int expert = peer * ne + local_expert;
int64_t tmp = 0, cusum = 0;
// This prefix sum is in the row direction, thus we cannot use cub helper
for (int rank = 0; rank < npes; rank++) {
tmp = in_splits_all[rank][expert];
cusum_per_expert[rank][expert] = cusum;
cusum += tmp;
}
aggregate_per_expert[peer][local_expert] = cusum;
}
__syncthreads();
// Calculate each expert's offset within their ranks
// There are two solutions below. The first one uses warpScan, and the second
// one uses manual prefix sum. warpScan's result is not stable, so we use the
// second one for now. (Since ne is usually small, the performance difference
// is negligible)
#if 0
// Solution 1: use warpScan
__shared__ int64_t cusum_per_rank[MAX_N_PEERS][MAX_LOCAL_EXPERTS];
// One warp per target rank
prefixSum_warp<MAX_N_PEERS>(cusum_per_rank[peer], aggregate_per_expert[peer], ne);
#else
// Solution 2: manual sum
auto& cusum_per_rank = aggregate_per_expert; // in-place prefix sum
if (peer < npes && local_expert == 0) {
int64_t tmp = 0, cusum = 0;
// Summing the experts of a rank
for (int e = 0; e < ne; e++) {
tmp = aggregate_per_expert[peer][e];
cusum_per_rank[peer][e] = cusum;
cusum += tmp;
}
}
#endif
__syncthreads();
// Now add the in-rank offsets to the in-expert offsets, then "I" will know
// where to write my data in the dest rank
if (peer < npes && local_expert < ne) {
int expert = peer * ne + local_expert;
dst_offsets[expert] = cusum_per_expert[my_pe][expert] + cusum_per_rank[peer][local_expert];
}
#endif
}
void _make_a2a_2d_exchange_plan(
at::Tensor& in_splits,
at::Tensor& src_offsets,
at::Tensor& out_splits,
at::Tensor& dst_offsets,
std::string group_name) {
// Make an exchange plan for an AllToAllv_2D shuffle operation.
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
auto src_offsets_hdl = c10d::symmetric_memory::rendezvous(src_offsets, group_name);
auto out_splits_hdl = c10d::symmetric_memory::rendezvous(out_splits, group_name);
auto dst_offsets_hdl = c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
// Verify inputs
auto npes = in_splits_hdl->get_world_size();
int nsplits = in_splits.size(0);
TORCH_CHECK(nsplits % npes == 0, "Number of splits must be a multiple of number of peers");
TORCH_CHECK(src_offsets.size(0) == nsplits && out_splits.size(0) == nsplits && dst_offsets.size(0) == nsplits,
"in_splits, src_offsets, out_splits and dst_offsets must have the same size");
TORCH_CHECK(in_splits.scalar_type() == at::kLong && src_offsets.scalar_type() == at::kLong
&& out_splits.scalar_type() == at::kLong && dst_offsets.scalar_type() == at::kLong,
"splits and offsets must be int64");
// Number of experts per rank
int ne = nsplits / npes;
auto in_splits_ptr = in_splits.const_data_ptr<int64_t>();
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
auto out_splits_ptr = out_splits.mutable_data_ptr<int64_t>();
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
auto device = in_splits.device();
c10::cuda::CUDAGuard guard(device);
auto& team_manager = TeamManager::get(device);
auto team = team_manager.get_team(group_name, in_splits_hdl->get_rank_to_global_rank());
auto stream = at::cuda::getCurrentCUDAStream(device.index());
// Exchange output splits and source offsets
// Use collective launch because kernel involves nvshmem barrier
void* args0[] = {
&in_splits_ptr,
&src_offsets_ptr,
&out_splits_ptr,
&dst_offsets_ptr,
&team,
&ne};
nvshmemx_collective_launch(
(const void*)make2dExchangePlan,
dim3(1),
dim3(THREADS_PER_BLOCK),
args0,
0,
stream);
}
/* ld.volatile.global: This refers to a volatile load from global memory. This
* ensures that the data is read directly from memory every time, preventing the
* compiler from optimizing the access by referring to a cached value. */
__device__ __forceinline__ int64_t ld_volatile_global(int64_t *ptr) {
int64_t ans;
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ans) : "l"((uintptr_t)__cvta_generic_to_global(ptr)) : "memory");
return ans;
}
/* This function ensures that no memory access instruction that appears after the
* ld.acquire in the program can be reordered to execute before it. Scope: within
* the GPU itself. */
__device__ __forceinline__ void fence_acquire_gpu() {
static __device__ int dummy;
int tmp;
asm volatile("ld.acquire.gpu.s32 %0,[%1];" : "=r"(tmp) : "l"(&dummy) : "memory");
dummy = tmp;
}
__global__ void _allToAllV_2d_index_push_kernel(
void *send_data, void *recv_data,
int64_t* topk_indices, int64_t* occurrences, int64_t* dst_offsets,
int topk, int n_local_experts, size_t stride, nvshmem_team_t team,
int64_t* b_start, int64_t* b_len, int64_t* b_head) {
/* Args:
* send_data: the data to be sent
* recv_data: the data to be received
* topk_indices: (n_tokens, topk), the experts to send current token to
* dst_offsets: (n_experts), the destination offsets of the expert chunk to be sent, within the destination rank
* occurrences: (n_tokens, topk), the rank of current token within the tokens to be sent to an expert
* stride: the stride of the data to be sent, i.e. token hidden size * element size, in bytes
* b_start: the start offsets of the data to be sent by each CUDA block (equivalent to offset of data received from each rail)
* b_len: the length of the data to be sent by each CUDA block
* b_head: the most recent ready-to-send token index, plus 1, for each CUDA block
*/
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
auto bid = blockIdx.x;
auto tail = b_start[bid];
auto end = b_start[bid] + b_len[bid];
// Use volatile because we need to re-read its value
volatile auto head_ptr = b_head + bid;
while (tail < end) {
while (tail >= ld_volatile_global(head_ptr)) {
// Wait for producer kernel to mark newer ready tokens
}
// Ready signal has been provided by producer kernel (on the same GPU).
// To make sure the token data is readily written by the producer kernel, we
// use an acquire fence here, with scope of the same GPU.
fence_acquire_gpu();
// New token has arrived
auto send_ptr = (char*)send_data + tail * stride;
// Loop over the topk experts
for (int k = 0; k < topk; k++) {
auto expert = topk_indices[tail * topk + k];
// Get the destination rank
auto dst_rank = expert / n_local_experts;
auto dst_rank_global = nvshmem_team_translate_pe(team, dst_rank, NVSHMEM_TEAM_WORLD);
// Get the destination offset
auto dst_offset = dst_offsets[expert] + occurrences[tail * topk + k];
// Get the destination pointer
auto dst_ptr = (char*)recv_data + dst_offset * stride;
// Send the data, i.e. 1 token
nvshmemx_putmem_nbi_block(dst_ptr, send_ptr, stride, dst_rank_global);
}
tail++;
}
// Make sure all data has been sent
nvshmem_quiet();
#endif
}
void _all_to_all_v_2d_index_push(
at::Tensor& input,
at::Tensor& out,
at::Tensor& topk_indices,
at::Tensor& occurrences,
at::Tensor& dst_offsets,
std::string group_name,
at::Tensor& b_start,
at::Tensor& b_len,
at::Tensor& b_head) {
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
c10d::symmetric_memory::rendezvous(out, group_name);
c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
auto input_ptr = input.data_ptr();
auto out_ptr = out.data_ptr();
auto topk_indices_ptr = topk_indices.data_ptr<int64_t>();
auto occurrences_ptr = occurrences.data_ptr<int64_t>();
auto dst_offsets_ptr = dst_offsets.data_ptr<int64_t>();
auto b_start_ptr = b_start.data_ptr<int64_t>();
auto b_len_ptr = b_len.data_ptr<int64_t>();
auto b_head_ptr = b_head.data_ptr<int64_t>();
auto topk = topk_indices.size(1);
auto n_local_experts = dst_offsets.size(0) / input_hdl->get_world_size();
auto stride_bytes = input.stride(0) * input.element_size();
auto device = input.device();
c10::cuda::CUDAGuard guard(device);
auto& team_manager = TeamManager::get(device);
auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
auto stream = at::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(
b_start.size(0) == b_len.size(0) && b_start.size(0) == b_head.size(0),
"Block start, len and head should have same size");
int nblocks = b_start.size(0);
void* args[] = {
&input_ptr, &out_ptr,
&topk_indices_ptr, &occurrences_ptr, &dst_offsets_ptr,
&topk, &n_local_experts, &stride_bytes, &team,
&b_start_ptr, &b_len_ptr, &b_head_ptr
};
C10_CUDA_CHECK(cudaLaunchKernel(
(const void*)_allToAllV_2d_index_push_kernel,
dim3(nblocks),
dim3(THREADS_PER_BLOCK),
args,
0,
stream));
}
} // namespace c10d::nvshmem_extension
@@ -876,4 +1372,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev);
m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
m.impl("_make_a2a_exchange_plan", c10d::nvshmem_extension::_make_a2a_exchange_plan);
m.impl("_all_to_all_get", c10d::nvshmem_extension::_all_to_all_get);
m.impl("_make_a2a_2d_exchange_plan", c10d::nvshmem_extension::_make_a2a_2d_exchange_plan);
m.impl("_all_to_all_v_2d_index_push", c10d::nvshmem_extension::_all_to_all_v_2d_index_push);
}

@@ -4,12 +4,13 @@ import math
import os
import socket
import uuid
from collections import namedtuple
from collections.abc import Callable, Generator
from contextlib import contextmanager
from datetime import timedelta
from enum import Enum
from functools import partial
from typing import Any, Literal
from typing import Any, Literal, Optional
import torch
import torch.distributed._functional_collectives as funcol
@@ -1824,4 +1825,107 @@ def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def]
return _SymmetricMemory.get_mempool_allocator(torch.device(device))
__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
# Create a type, ExchangePlan.
""" A namedtuple consisting of meta information which accelerates all_to_all operations.
- in_splits: splits of my input towards different peers.
- src_offsets: offsets within peers' input from which I should fetch data.
- out_splits: splits of peers' contribution to my output.
- dst_offsets: offsets within my output where I should store peers' contribution.
"""
ExchangePlan = namedtuple(
"ExchangePlan", ["in_splits", "src_offsets", "out_splits", "dst_offsets"]
)
def make_a2a_exchange_plan(
in_splits: torch.Tensor,
src_offsets: torch.Tensor,
out_splits: torch.Tensor,
dst_offsets: torch.Tensor,
group_name: str,
) -> ExchangePlan:
r"""
Create an all-to-all exchange plan given the input splits. This is a
collective operation.
Args:
in_splits (:class:`torch.Tensor`): the input splits for the exchange plan (IN).
src_offsets (:class:`torch.Tensor`): the source offsets for the exchange plan (OUT).
out_splits (:class:`torch.Tensor`): the output splits for the exchange plan (OUT).
dst_offsets (:class:`torch.Tensor`): the destination offsets for the exchange plan (OUT).
group_name (str): the group over which to exchange the splits and offsets.
Returns:
An `ExchangePlan` capturing the above tensors.
"""
torch.ops.symm_mem._make_a2a_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
return ExchangePlan(in_splits, src_offsets, out_splits, dst_offsets)
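# Minimal usage sketch (mirrors `test_make_a2a_exchange_plan` in this change;
# assumes an initialized process group, an NVSHMEM-backed symmetric memory
# backend, and an illustrative per-peer split tensor `my_splits`; `dist` is
# torch.distributed):
#
#   in_splits = empty(world_size, dtype=torch.int64, device=device).copy_(my_splits)
#   src_offsets = empty(world_size, dtype=torch.int64, device=device)
#   out_splits = empty(world_size, dtype=torch.int64, device=device)
#   dst_offsets = empty(world_size, dtype=torch.int64, device=device)
#   dist.barrier()  # ensure all peers have allocated their symmetric tensors
#   plan = make_a2a_exchange_plan(in_splits, src_offsets, out_splits, dst_offsets, group_name)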
def all_to_all_v(
input: torch.Tensor,
out: torch.Tensor,
plan: ExchangePlan,
group_name: str,
b_start: Optional[torch.Tensor] = None,
b_len: Optional[torch.Tensor] = None,
b_head: Optional[torch.Tensor] = None,
) -> None:
r"""
Perform an all-to-all-v operation given an `ExchangePlan`.
Args:
input (:class:`torch.Tensor`): the input tensor for the all-to-all operation (IN).
out (:class:`torch.Tensor`): the output tensor for the all-to-all operation (OUT).
plan (`ExchangePlan`): a tuple consisting of (in_splits, src_offsets, out_splits, dst_offsets).
group_name (str): the group over which to perform the all-to-all.
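b_start (:class:`torch.Tensor`, optional): per-CUDA-block start offsets of the received tokens (OUT).
b_len (:class:`torch.Tensor`, optional): per-CUDA-block number of received tokens (OUT).
b_head (:class:`torch.Tensor`, optional): per-CUDA-block ready signal, advanced as tokens arrive; holds the index of the most recently received token plus 1 (OUT).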
"""
# For now we use the get style; in the future we can extend it to support the
# put style too, e.g. gated by a flag.
torch.ops.symm_mem._all_to_all_get(
input,
out,
plan.src_offsets,
plan.out_splits,
plan.dst_offsets,
group_name,
b_start,
b_len,
b_head,
)
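# Usage sketch (mirrors `test_a2a_with_exchange_plan` in this change): a plan is
# reusable, so the exchange can be issued repeatedly without re-planning.
# `inp` and `out` are symmetric-memory tensors sized for the worst case:
#
#   plan = make_a2a_exchange_plan(in_splits, src_offsets, out_splits, dst_offsets, group_name)
#   for _ in range(3):
#       all_to_all_v(inp, out, plan, group_name)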
def make_a2a_2d_exchange_plan(
in_splits: torch.Tensor,
src_offsets: torch.Tensor,
out_splits: torch.Tensor,
dst_offsets: torch.Tensor,
group_name: str,
) -> ExchangePlan:
r"""
Create an all-to-all-2d exchange plan given the input splits. This is a
collective operation.
Args:
in_splits (:class:`torch.Tensor`): the input splits for the exchange plan (IN).
src_offsets (:class:`torch.Tensor`): the source offsets for the exchange plan (OUT).
out_splits (:class:`torch.Tensor`): the output splits for the exchange plan (OUT).
dst_offsets (:class:`torch.Tensor`): the destination offsets for the exchange plan (OUT).
group_name (str): the group over which to exchange the splits and offsets.
Returns:
An `ExchangePlan` capturing the above tensors.
"""
torch.ops.symm_mem._make_a2a_2d_exchange_plan(
in_splits, src_offsets, out_splits, dst_offsets, group_name
)
return ExchangePlan(in_splits, src_offsets, out_splits, dst_offsets)
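# Layout note (a sketch based on `test_make_a2a_2d_exchange_plan` in this change):
# for the 2D plan, splits are indexed per expert rather than per peer. With `ne`
# local experts per rank, `in_splits` has `ne * world_size` entries in rank-major
# order, while `plan.out_splits` comes back in expert-major order, i.e. the
# equivalent of:
#
#   rank_major_out_splits.reshape(world_size, ne).t().reshape(-1)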
__all__ = [
"empty",
"rendezvous",
"is_nvshmem_available",
"set_backend",
"get_backend",
"ExchangePlan",
"make_a2a_exchange_plan",
"all_to_all_v",
]