Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-05 16:44:58 +08:00)
Compare commits (3): cpp-docs-d... ... ciflow/h10... — ece42ed689, 05b3b1024b, 2fd0573646
@@ -231,6 +231,36 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def get_occurrence_numbers(tensor):
|
||||
"""
|
||||
Transform tensor to show which occurrence each element is.
|
||||
|
||||
Example: tensor([1, 2, 1, 3, 1, 2]) -> tensor([1, 1, 2, 1, 3, 2])
|
||||
"""
|
||||
device = tensor.device
|
||||
# Get unique values and their inverse mapping
|
||||
unique_vals, inverse = torch.unique(tensor, return_inverse=True)
|
||||
|
||||
# Create a tensor to count occurrences for each unique value
|
||||
n_unique = len(unique_vals)
|
||||
n_elements = len(tensor)
|
||||
|
||||
# Create a matrix where each row corresponds to a unique value
|
||||
# and columns correspond to positions in the original tensor
|
||||
indicator_matrix = torch.zeros(
|
||||
n_unique, n_elements, dtype=torch.float, device=device
|
||||
)
|
||||
indicator_matrix[inverse, torch.arange(n_elements)] = 1.0
|
||||
|
||||
# Cumulative sum along columns gives us occurrence numbers
|
||||
occurrence_counts = torch.cumsum(indicator_matrix, dim=1) - indicator_matrix
|
||||
|
||||
# Extract the occurrence number for each position
|
||||
result = occurrence_counts[inverse, torch.arange(n_elements, device=device)]
|
||||
|
||||
return result.long()
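
For reference, a minimal usage sketch of this helper (illustrative only; it assumes `torch` is imported as elsewhere in this file). It shows the 0-based occurrence indices that are later added to `dst_offsets` to place each token within its expert chunk:

import torch

idx = torch.tensor([1, 2, 1, 3, 1, 2])
occ = get_occurrence_numbers(idx)
# occ == tensor([0, 0, 1, 0, 2, 1]): e.g. the `1` at position 4 has two earlier
# occurrences, so it would land at slot dst_offsets[expert] + 2 of its expert chunk.
print(occ)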
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_nvshmem()
|
||||
@requires_cuda_p2p_access()
|
||||
@@ -557,6 +587,314 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest):
|
||||
# Check data
|
||||
torch.testing.assert_close(out_expected, out[:out_numel])
|
||||
|
||||
@skipIfRocm
|
||||
def test_make_a2a_exchange_plan(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
# Number of elements for a peer is random between [0, k)
|
||||
k = 10
|
||||
orig_inp_splits = torch.randint(k, (self.world_size,), device=self.device)
|
||||
|
||||
# Create symm_mem tensors
|
||||
in_splits = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
src_offsets = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
out_splits = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
dst_offsets = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
in_splits.copy_(orig_inp_splits)
|
||||
|
||||
# Sync all ranks to ensure remote tensors are allocated
|
||||
dist.barrier()
|
||||
|
||||
symm_mem.make_a2a_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
|
||||
# Check input splits -- should not change
|
||||
torch.testing.assert_close(in_splits, orig_inp_splits)
|
||||
|
||||
# Check output splits
|
||||
# Exchange input splits to get output splits
|
||||
expected_out_splits = torch.zeros_like(orig_inp_splits)
|
||||
dist.all_to_all_single(expected_out_splits, orig_inp_splits)
|
||||
torch.testing.assert_close(expected_out_splits, out_splits)
|
||||
|
||||
# Check src offsets
|
||||
orig_src_offsets = torch.cumsum(orig_inp_splits, dim=0) # inclusive scan
|
||||
# Make it exclusive
|
||||
orig_src_offsets = torch.cat(
|
||||
[torch.zeros(1, device=self.device), orig_src_offsets[:-1]]
|
||||
).to(torch.int64)
|
||||
expected_src_offsets = torch.empty_like(orig_src_offsets)
|
||||
dist.all_to_all_single(expected_src_offsets, orig_src_offsets)
|
||||
torch.testing.assert_close(src_offsets, expected_src_offsets)
|
||||
|
||||
# Check dst offsets
|
||||
expected_dst_offsets = torch.cumsum(
|
||||
expected_out_splits, dim=0
|
||||
) # inclusive scan
|
||||
self.assertEqual(dst_offsets[0], 0)
|
||||
torch.testing.assert_close(dst_offsets[1:], expected_dst_offsets[:-1])
|
||||
|
||||
@skipIfRocm
|
||||
def test_a2a_with_exchange_plan(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
# Number of elements for a peer is random between [0, k)
|
||||
k = 10
|
||||
orig_inp_splits = torch.randint(k, (self.world_size,), device=self.device)
|
||||
|
||||
# Create splits and offsets
|
||||
in_splits = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
src_offsets = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
out_splits = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
dst_offsets = symm_mem.empty(
|
||||
self.world_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
# Create data
|
||||
# Max number of input elements (must be a constant across ranks for symmetric memory allocation)
|
||||
max_inp_numel = k * self.world_size
|
||||
# Max number of output elements (must be a constant across ranks for symmetric memory allocation)
|
||||
overflow_factor = self.world_size # worst case: one rank receives all data
|
||||
max_out_numel = max_inp_numel * overflow_factor
|
||||
dtype = torch.float
|
||||
inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_(
|
||||
torch.randn(max_inp_numel, dtype=dtype, device=self.device)
|
||||
)
|
||||
out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
|
||||
|
||||
in_splits.copy_(orig_inp_splits)
|
||||
|
||||
# Sync all ranks to ensure remote tensors are allocated
|
||||
dist.barrier()
|
||||
|
||||
# Create exchange plan
|
||||
plan = symm_mem.make_a2a_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
|
||||
# Prepare expected output
|
||||
inp_numel = in_splits.sum().item()
|
||||
out_numel = out_splits.sum().item()
|
||||
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
|
||||
dist.all_to_all_single(
|
||||
expected, inp[:inp_numel], out_splits.tolist(), in_splits.tolist()
|
||||
)
|
||||
|
||||
# Exchange data with plan
|
||||
# Loop a couple times to ensure the plan is reusable
|
||||
for _ in range(3):
|
||||
symm_mem.all_to_all_v(inp, out, plan, group_name)
|
||||
torch.testing.assert_close(out[:out_numel], expected)
|
||||
|
||||
@skipIfRocm
|
||||
@parametrize("align", [1]) # `major_align` of output
|
||||
def test_make_a2a_2d_exchange_plan(self, align: int) -> None:
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
# Number of experts per rank
|
||||
ne = 8
|
||||
nsplits = ne * self.world_size
|
||||
|
||||
# Number of elements for an expert is random between [0, k)
|
||||
k = 10
|
||||
orig_inp_splits = torch.randint(
|
||||
k, (nsplits,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
# Create symm_mem tensors
|
||||
in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=self.device)
|
||||
src_offsets = symm_mem.empty(nsplits, dtype=torch.int64, device=self.device)
|
||||
out_splits = symm_mem.empty(
|
||||
nsplits, dtype=torch.int64, device=self.device
|
||||
).fill_(0)
|
||||
dst_offsets = symm_mem.empty(
|
||||
nsplits, dtype=torch.int64, device=self.device
|
||||
).fill_(0)
|
||||
|
||||
in_splits.copy_(orig_inp_splits)
|
||||
|
||||
# Sync all ranks to ensure remote tensors are allocated
|
||||
dist.barrier()
|
||||
|
||||
plan = symm_mem.make_a2a_2d_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
|
||||
# Exchange input splits to get output splits
|
||||
expected_out_splits = torch.zeros_like(orig_inp_splits)
|
||||
dist.all_to_all_single(expected_out_splits, orig_inp_splits)
|
||||
# We do a .t() here because there is a rank-major to expert-major shuffle
|
||||
expected_out_splits = expected_out_splits.reshape(self.world_size, ne).t()
|
||||
torch.testing.assert_close(plan.out_splits, expected_out_splits.reshape(-1))
|
||||
|
||||
# Check dst offsets
|
||||
out_split_list = expected_out_splits.tolist()
|
||||
for i in range(ne):
|
||||
expert_sum = 0
|
||||
for j in range(self.world_size):
|
||||
expert_sum += out_split_list[i][j]
|
||||
# # Align up expert_sum
|
||||
# expert_sum_aligned = (expert_sum + align - 1) // align * align
|
||||
# # If 0, make it at least `align` (bc cutlass currently does not support empty bins)
|
||||
# expert_sum_aligned = max(expert_sum_aligned, align)
|
||||
# # last element absorbs the padding
|
||||
# out_split_list[i][-1] += expert_sum_aligned - expert_sum
|
||||
|
||||
out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
|
||||
out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan
|
||||
# Make it exclusive scan because that's what `all_to_all_vdev_2d` returns
|
||||
out_offsets = torch.cat(
|
||||
[torch.zeros(1, device=self.device), out_offsets[:-1]]
|
||||
).to(torch.int64)
|
||||
expected_dst_offsets = torch.empty(
|
||||
nsplits, dtype=torch.int64, device=self.device
|
||||
)
|
||||
dist.all_to_all_single(
|
||||
expected_dst_offsets,
|
||||
out_offsets.reshape(ne, self.world_size).t().contiguous(),
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
expected_dst_offsets,
|
||||
plan.dst_offsets,
|
||||
msg=f"""
|
||||
Expecting
|
||||
{expected_dst_offsets}
|
||||
Got
|
||||
{plan.dst_offsets}""",
|
||||
)
|
||||
|
||||
@skipIfRocm
|
||||
def test_all_to_all_v_2d_index_push(self) -> None:
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
# Number of experts per rank
|
||||
ne = 4
|
||||
tot_experts = ne * self.world_size
|
||||
|
||||
# Create topk indices of shape (n_tokens, topk)
|
||||
topk = 2
|
||||
n_tokens = 128
|
||||
topk_indices = torch.randint(
|
||||
tot_experts, (n_tokens, topk), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
# Convert indices to splits
|
||||
orig_inp_splits = torch.histc(
|
||||
topk_indices,
|
||||
bins=tot_experts,
|
||||
)
|
||||
|
||||
# Create symm_mem tensors
|
||||
in_splits = symm_mem.empty(
|
||||
tot_experts, dtype=torch.int64, device=self.device
|
||||
).copy_(orig_inp_splits)
|
||||
src_offsets = symm_mem.empty(tot_experts, dtype=torch.int64, device=self.device)
|
||||
out_splits = symm_mem.empty(
|
||||
tot_experts, dtype=torch.int64, device=self.device
|
||||
).fill_(0)
|
||||
dst_offsets = symm_mem.empty(
|
||||
tot_experts, dtype=torch.int64, device=self.device
|
||||
).fill_(0)
|
||||
|
||||
# Sync all ranks to ensure remote tensors are allocated
|
||||
dist.barrier()
|
||||
|
||||
plan = symm_mem.make_a2a_2d_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
|
||||
# Create data
|
||||
max_out_tokens = n_tokens * self.world_size
|
||||
dtype = torch.float
|
||||
hid_dim = 1024
|
||||
inp = symm_mem.empty(n_tokens, hid_dim, dtype=dtype, device=self.device).copy_(
|
||||
torch.randn(n_tokens, hid_dim, dtype=dtype, device=self.device)
|
||||
)
|
||||
out = symm_mem.empty(
|
||||
max_out_tokens, hid_dim, dtype=dtype, device=self.device
|
||||
).fill_(-1)
|
||||
|
||||
# Figure out rank of each token in its expert chunk
|
||||
occurrences = get_occurrence_numbers(topk_indices.view(-1))
|
||||
|
||||
# Number of CUDA blocks (random choice)
|
||||
n_blocks = 2
|
||||
# Evenly spread tokens across CUDA blocks
|
||||
tokens_per_block = n_tokens // n_blocks
|
||||
# Start offset of each CUDA block
|
||||
b_start = torch.arange(
|
||||
0, n_tokens, tokens_per_block, dtype=torch.int64, device=self.device
|
||||
)
|
||||
# Number of tokens for each CUDA block
|
||||
b_len = torch.full(
|
||||
(n_blocks,), tokens_per_block, dtype=torch.int64, device=self.device
|
||||
)
|
||||
# Ready signal for each CUDA block. In this test we set all tokens as ready in one shot
|
||||
b_head = b_start + b_len
|
||||
|
||||
dist.barrier()
|
||||
|
||||
torch.ops.symm_mem._all_to_all_v_2d_index_push(
|
||||
inp,
|
||||
out,
|
||||
topk_indices,
|
||||
occurrences,
|
||||
plan.dst_offsets,
|
||||
group_name,
|
||||
b_start,
|
||||
b_len,
|
||||
b_head,
|
||||
)
|
||||
|
||||
# Check data using all_to_all_vdev_2d
|
||||
# Token sequence is inflated topk times
|
||||
expanded_seqlen = n_tokens * topk
|
||||
sorted_indices = torch.argsort(topk_indices.view(-1))
|
||||
expanded_inp = symm_mem.empty(
|
||||
expanded_seqlen, hid_dim, dtype=dtype, device=self.device
|
||||
).copy_(inp[sorted_indices // topk])
|
||||
overflow = 2
|
||||
expected_out = symm_mem.empty(
|
||||
expanded_seqlen * overflow, hid_dim, dtype=dtype, device=self.device
|
||||
)
|
||||
out_splits_offsets = symm_mem.empty(
|
||||
(2, tot_experts), dtype=torch.int64, device=self.device
|
||||
)
|
||||
dist.barrier()
|
||||
torch.ops.symm_mem.all_to_all_vdev_2d(
|
||||
expanded_inp, expected_out, in_splits, out_splits_offsets, group_name
|
||||
)
|
||||
|
||||
# Check data
|
||||
out_len = out_splits_offsets[1][-1] + out_splits_offsets[0][-1]
|
||||
torch.testing.assert_close(out[:out_len], expected_out[:out_len])
|
||||
|
||||
|
||||
# Helper function used by multiple tests
|
||||
def dispatch_then_combine(device, align: int, group) -> None:
|
||||
|
||||
@@ -510,6 +510,14 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
|
||||
"all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
|
||||
m.def(
|
||||
"all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
|
||||
m.def(
|
||||
"_make_a2a_exchange_plan(Tensor in_splits, Tensor(a!) src_offsets, Tensor(a!) out_splits, Tensor(a!) dst_offsets, str group_name) -> ()");
|
||||
m.def(
|
||||
"_all_to_all_get(Tensor input, Tensor(a!) out, Tensor src_offsets, Tensor out_splits, Tensor dst_offsets, str group_name, Tensor? b_start, Tensor? b_len, Tensor? b_head) -> ()");
|
||||
m.def(
|
||||
"_make_a2a_2d_exchange_plan(Tensor in_splits, Tensor(a!) src_offsets, Tensor(a!) out_splits, Tensor(a!) dst_offsets, str group_name) -> ()");
|
||||
m.def(
|
||||
"_all_to_all_v_2d_index_push(Tensor input, Tensor(a!) out, Tensor topk_indices, Tensor occurrences, Tensor dst_offsets, str group_name, Tensor b_start, Tensor b_len, Tensor b_head) -> ()");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
|
||||
|
||||
@@ -198,19 +198,16 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
|
||||
}
|
||||
|
||||
// This device function exchanges output splits and source offsets between peers.
// Arguments:
// - `input_splits` (IN): this rank's input splits, one per peer,
// - `source_offsets` (OUT): for each peer, the offset within that peer's input from which this rank should fetch,
// - `output_splits` (OUT): for each peer, the number of elements that peer contributes to this rank's output.
|
||||
__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) {
|
||||
__device__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* source_offsets, int64_t* output_splits, nvshmem_team_t team) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
|
||||
int mype = nvshmem_team_my_pe(team);
|
||||
int npes = nvshmem_team_n_pes(team);
|
||||
auto output_splits = out_splits_offsets;
|
||||
auto source_offsets = out_splits_offsets + npes;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
|
||||
@@ -232,27 +229,38 @@ __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_split
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a kernel wrapper for `exchangeSplitAndOffset`.
|
||||
__global__ void exchangeSplitAndOffsetKernel(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, nvshmem_team_t team) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
exchangeSplitAndOffset(in_splits, src_offsets, out_splits, team);
|
||||
#endif
|
||||
}
|
||||
|
||||
// This kernel does the actual data exchange using get operations.
// `source_offsets` and `output_splits` have the same meaning as in `exchangeSplitAndOffset`;
// `dst_offsets`, if provided, gives the per-peer output offsets (otherwise they are computed via a prefix sum of `output_splits`).
// `stride` is the stride at dim 0, in bytes.
|
||||
__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) {
|
||||
__global__ void allToAllGet(void *send_data, void *recv_data, int64_t* source_offsets, int64_t* output_splits, int64_t* dst_offsets, size_t stride, nvshmem_team_t team) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
|
||||
int mype = nvshmem_team_my_pe(team);
|
||||
int npes = nvshmem_team_n_pes(team);
|
||||
auto output_splits = out_splits_offsets;
|
||||
auto source_offsets = out_splits_offsets + npes;
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int blocks_per_peer = max(gridDim.x / npes, 1);
|
||||
|
||||
// Calculate the output offsets
|
||||
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
|
||||
// Calculate the output offsets if not provided
|
||||
__shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
|
||||
prefixSum(peer_offsets, output_splits, npes);
|
||||
__syncthreads();
|
||||
int64_t* out_offsets = dst_offsets;
|
||||
if (out_offsets == nullptr) {
|
||||
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
|
||||
prefixSum(peer_offsets, output_splits, npes);
|
||||
__syncthreads();
|
||||
out_offsets = peer_offsets;
|
||||
}
|
||||
|
||||
// Target a different peer based on bid
|
||||
for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) {
|
||||
@@ -267,22 +275,81 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_
|
||||
// This block's offset in the data from `peer`
|
||||
auto block_offset = block_size * (bid % blocks_per_peer);
|
||||
auto source_offset = source_offsets[peer] * stride + block_offset;
|
||||
auto write_offset = peer_offsets[peer] * stride + block_offset;
|
||||
auto write_offset = out_offsets[peer] * stride + block_offset;
|
||||
nvshmemx_getmem_nbi_block(
|
||||
(char*)recv_data + write_offset,
|
||||
(char*)send_data + source_offset,
|
||||
block_size,
|
||||
peer_global);
|
||||
}
|
||||
// Write out the output offsets (to the scratchpad line)
|
||||
if (bid == 0 && tid < npes) {
|
||||
source_offsets[tid] = peer_offsets[tid];
|
||||
|
||||
// Write out the output offsets if not provided
|
||||
if (dst_offsets == nullptr) {
|
||||
if (bid == 0 && tid < npes) {
|
||||
// source_offsets alias dst_offsets space when dst_offsets is not provided
|
||||
source_offsets[tid] = out_offsets[tid];
|
||||
}
|
||||
}
|
||||
// Make sure getmem_nbi calls finish
|
||||
nvshmem_quiet();
|
||||
#endif
|
||||
}
|
||||
|
||||
// This kernel extends `allToAllGet` with an out signal that indicates readiness of a "token" after each fetch.
|
||||
// `b_start`, `b_len` and `b_head` are fed into `_allToAllV_2d_index_push_kernel`.
|
||||
__global__ void allToAllGet_signal_out(
|
||||
void *send_data, void *recv_data, int64_t* source_offsets, int64_t* output_splits, int64_t* dst_offsets, size_t stride, nvshmem_team_t team,
|
||||
int64_t* b_start, int64_t* b_len, int64_t* b_head) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID);
|
||||
int mype = nvshmem_team_my_pe(team);
|
||||
int npes = nvshmem_team_n_pes(team);
|
||||
int bid = blockIdx.x;
|
||||
CUDA_KERNEL_ASSERT(gridDim.x % npes == 0 && " Number of blocks must be multiple of npes\n");
|
||||
int blocks_per_peer = gridDim.x / npes;
|
||||
|
||||
// dst_offset must be provided to this kernel e.g. via exchange plan
|
||||
CUDA_KERNEL_ASSERT(dst_offsets != nullptr && " dst_offset must be provided\n");
|
||||
int64_t* out_offsets = dst_offsets;
|
||||
|
||||
// Target a different peer based on bid, in shifting manner
|
||||
int peer_shift = bid / blocks_per_peer;
|
||||
int peer = (mype + peer_shift) % npes;
|
||||
auto peer_global = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD);
|
||||
// Total number of tokens from `peer`
|
||||
auto peer_tokens = output_splits[peer];
|
||||
// Amount to get from `peer` in this block
|
||||
int64_t block_tokens = peer_tokens / blocks_per_peer;
|
||||
// Being lazy here, we should handle the residual if the division is not exact
|
||||
CUDA_KERNEL_ASSERT(block_tokens * blocks_per_peer == peer_tokens);
|
||||
// Assign to b_len
|
||||
b_len[bid] = block_tokens;
|
||||
// This block's offset in the data from `peer`, all 3 offsets below are in unit of token
|
||||
auto block_offset = block_tokens * (bid % blocks_per_peer);
|
||||
auto source_offset = source_offsets[peer] + block_offset;
|
||||
int64_t write_offset = out_offsets[peer] + block_offset;
|
||||
// Assign to b_start
|
||||
b_start[bid] = write_offset;
|
||||
auto b_head_ptr = b_head + bid;
|
||||
|
||||
// Now let's start data fetch, token by token
|
||||
for (int i = 0; i < block_tokens; i++) {
|
||||
auto write_head = write_offset + i;
|
||||
// Use blocking API for now
|
||||
nvshmemx_getmem_block(
|
||||
(char*)recv_data + write_head * stride,
|
||||
(char*)send_data + (source_offset + i) * stride,
|
||||
stride, // byte size of 1 token
|
||||
peer_global);
|
||||
// Advance ready signal, use release here as a memory fence, scope is local GPU.
|
||||
// Writing `write_head + 1` because that's the protocol with downstream kernel.
|
||||
asm volatile("st.release.gpu.global.s64 [%0], %1;" :: "l"(__cvta_generic_to_global(b_head_ptr)), "l"(write_head + 1) : "memory");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
|
||||
// Check user setting first
|
||||
int num_blocks = c10d::symmetric_memory::getenv_nblocks();
|
||||
@@ -298,6 +365,88 @@ static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) {
|
||||
return std::min(num_blocks, max_blocks);
|
||||
}
|
||||
|
||||
// Perform a 1D AllToAllv shuffle using get operations, driven by the given source offset information.
// ** Caller must ensure that the splits and offsets are (i) symmetric addresses and (ii) have been rendezvoused. **
|
||||
void _all_to_all_get_inner(
|
||||
at::Tensor& input,
|
||||
at::Tensor& out,
|
||||
int64_t* src_offsets_ptr,
|
||||
const int64_t* out_splits_ptr,
|
||||
int64_t* dst_offsets_ptr,
|
||||
std::string group_name,
|
||||
std::optional<nvshmem_team_t> team,
|
||||
const std::optional<at::Tensor>& b_start = std::nullopt,
|
||||
const std::optional<at::Tensor>& b_len = std::nullopt,
|
||||
const std::optional<at::Tensor>& b_head = std::nullopt) {
|
||||
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
|
||||
c10d::symmetric_memory::rendezvous(out, group_name);
|
||||
|
||||
// Verify inputs
|
||||
TORCH_CHECK_EQ(input.device(), out.device());
|
||||
TORCH_CHECK(input.dtype() == out.dtype(), "input and out must have the same dtype");
|
||||
TORCH_CHECK(input.stride(0) == out.stride(0), "input and out must have the same stride at dim 0");
|
||||
TORCH_CHECK(input.is_contiguous() && out.is_contiguous(), "input and out must be contiguous");
|
||||
|
||||
auto device = input.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
// Get team if not provided
|
||||
nvshmem_team_t team_to_use;
|
||||
if (team.has_value()) {
|
||||
team_to_use = team.value();
|
||||
} else {
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
team_to_use = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
|
||||
}
|
||||
|
||||
bool has_signal = b_head.has_value();
|
||||
|
||||
// CTA Tuning
|
||||
auto input_size = input.numel() * input.element_size();
|
||||
int num_blocks = has_signal ?
|
||||
b_head.value().size(0) :
|
||||
get_a2a_nblocks(
|
||||
input_size,
|
||||
input_hdl->get_world_size(),
|
||||
input_hdl->world_within_direct_access());
|
||||
|
||||
// Stride at dim 0
|
||||
size_t stride_bytes = input.stride(0) * input.element_size();
|
||||
auto input_ptr = input.const_data_ptr();
|
||||
auto output_ptr = out.mutable_data_ptr();
|
||||
|
||||
// Required args
|
||||
std::vector<void*> args = {
|
||||
&input_ptr,
|
||||
&output_ptr,
|
||||
&src_offsets_ptr,
|
||||
&out_splits_ptr,
|
||||
&dst_offsets_ptr,
|
||||
&stride_bytes,
|
||||
&team_to_use};
|
||||
|
||||
// Optional args for signals
|
||||
void *b_start_ptr, *b_len_ptr, *b_head_ptr;
|
||||
if (has_signal) {
|
||||
b_start_ptr = b_start.value().mutable_data_ptr();
|
||||
b_len_ptr = b_len.value().mutable_data_ptr();
|
||||
b_head_ptr = b_head.value().mutable_data_ptr();
|
||||
args.push_back(b_start_ptr);
|
||||
args.push_back(b_len_ptr);
|
||||
args.push_back(b_head_ptr);
|
||||
}
|
||||
|
||||
auto functor = has_signal ? (const void*)allToAllGet_signal_out : (const void*)allToAllGet;
|
||||
C10_CUDA_CHECK(cudaLaunchKernel(
|
||||
functor,
|
||||
dim3(num_blocks),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args.data(),
|
||||
0,
|
||||
stream));
|
||||
}
|
||||
|
||||
void all_to_all_vdev(
|
||||
at::Tensor& input,
|
||||
at::Tensor& out,
|
||||
@@ -312,63 +461,40 @@ void all_to_all_vdev(
|
||||
* - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order):
|
||||
output splits and output offsets.
|
||||
*/
|
||||
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
|
||||
auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
|
||||
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
|
||||
auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
|
||||
int rank = input_hdl->get_rank();
|
||||
int world_size = input_hdl->get_world_size();
|
||||
auto symm_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
|
||||
c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name);
|
||||
int world_size = symm_hdl->get_world_size();
|
||||
|
||||
void* input_ptr = input.data_ptr();
|
||||
void* output_ptr = out.mutable_data_ptr();
|
||||
int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr());
|
||||
int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr());
|
||||
|
||||
TORCH_CHECK_EQ(input.device(), out.device());
|
||||
auto device = input.device();
|
||||
TORCH_CHECK_EQ(in_splits.device(), out_splits_offsets.device());
|
||||
auto device = in_splits.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
|
||||
auto team = team_manager.get_team(group_name, symm_hdl->get_rank_to_global_rank());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
// Exchange output splits and source offsets
|
||||
// Use collective launch because kernel involves nvshmem barrier
|
||||
// Borrowing the space of out_offsets as a temporary exchange pad for source offsets.
|
||||
auto tmp_src_offsets_ptr = out_splits_offsets_ptr + world_size;
|
||||
void* args0[] = {
|
||||
&in_splits_ptr,
|
||||
&tmp_src_offsets_ptr,
|
||||
&out_splits_offsets_ptr,
|
||||
&team};
|
||||
// Use collective launch because kernel involves nvshmem barrier
|
||||
nvshmemx_collective_launch(
|
||||
(const void*)exchangeSplitAndOffset,
|
||||
(const void*)exchangeSplitAndOffsetKernel,
|
||||
dim3(1),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args0,
|
||||
0,
|
||||
stream);
|
||||
|
||||
// CTA Tuning
|
||||
auto input_size = input.numel() * input.element_size();
|
||||
int num_blocks = get_a2a_nblocks(
|
||||
input_size,
|
||||
input_hdl->get_world_size(),
|
||||
input_hdl->world_within_direct_access());
|
||||
|
||||
// Stride at dim 0 (assuming input is contiguous, TODO)
|
||||
size_t stride_bytes = input.stride(0) * input.element_size();
|
||||
|
||||
// All to all data exchange
|
||||
void* args1[] = {
|
||||
&input_ptr,
|
||||
&output_ptr,
|
||||
&out_splits_offsets_ptr,
|
||||
&stride_bytes,
|
||||
&team};
|
||||
nvshmemx_collective_launch(
|
||||
(const void*)allToAllV,
|
||||
dim3(num_blocks),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args1,
|
||||
0,
|
||||
stream);
|
||||
// Get data based on exchange plan
|
||||
_all_to_all_get_inner(
|
||||
input, out, tmp_src_offsets_ptr, out_splits_offsets_ptr, nullptr, group_name, team);
|
||||
}
|
||||
|
||||
// Start of `all_to_all_vdev_2d`
|
||||
@@ -863,6 +989,376 @@ void all_to_all_vdev_2d_offset(
|
||||
0,
|
||||
stream);
|
||||
}
|
||||
|
||||
// This kernel is used to exchange output splits and source offsets between peers.
|
||||
__global__ void makeExchangePlan(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, int64_t* dst_offsets, nvshmem_team_t team) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
// Given the input splits, exchange them to obtain the source offsets and output splits
|
||||
exchangeSplitAndOffset(in_splits, src_offsets, out_splits, team);
|
||||
|
||||
// Calculate the output offsets
|
||||
int npes = nvshmem_team_n_pes(team);
|
||||
CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK);
|
||||
__shared__ int64_t out_offsets[THREADS_PER_BLOCK];
|
||||
prefixSum(out_offsets, out_splits, npes);
|
||||
__syncthreads();
|
||||
|
||||
// Write out the output offsets
|
||||
int tid = threadIdx.x;
|
||||
if (tid < npes) {
|
||||
dst_offsets[tid] = out_offsets[tid];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void _make_a2a_exchange_plan(
|
||||
at::Tensor& in_splits,
|
||||
at::Tensor& src_offsets,
|
||||
at::Tensor& out_splits,
|
||||
at::Tensor& dst_offsets,
|
||||
std::string group_name) {
|
||||
// Make an exchange plan for a 1D AllToAllv shuffle operation.
|
||||
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
|
||||
auto src_offsets_hdl = c10d::symmetric_memory::rendezvous(src_offsets, group_name);
|
||||
auto out_splits_hdl = c10d::symmetric_memory::rendezvous(out_splits, group_name);
|
||||
auto dst_offsets_hdl = c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
|
||||
|
||||
// Verify inputs
|
||||
auto npes = in_splits_hdl->get_world_size();
|
||||
TORCH_CHECK(npes <= THREADS_PER_BLOCK, "Number of peers must not exceed THREADS_PER_BLOCK: ", THREADS_PER_BLOCK);
|
||||
TORCH_CHECK(in_splits.size(0) == npes && src_offsets.size(0) == npes && out_splits.size(0) == npes && dst_offsets.size(0) == npes,
|
||||
"in_splits, src_offsets, out_splits and dst_offsets must have the same size as world_size");
|
||||
TORCH_CHECK(in_splits.scalar_type() == at::kLong && src_offsets.scalar_type() == at::kLong
|
||||
&& out_splits.scalar_type() == at::kLong && dst_offsets.scalar_type() == at::kLong,
|
||||
"splits and offsets must be int64");
|
||||
|
||||
auto in_splits_ptr = in_splits.const_data_ptr<int64_t>();
|
||||
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
|
||||
auto out_splits_ptr = out_splits.mutable_data_ptr<int64_t>();
|
||||
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
|
||||
|
||||
auto device = in_splits.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
auto team = team_manager.get_team(group_name, in_splits_hdl->get_rank_to_global_rank());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
// Exchange output splits and source offsets
|
||||
// Use collective launch because kernel involves nvshmem barrier
|
||||
void* args0[] = {
|
||||
&in_splits_ptr,
|
||||
&src_offsets_ptr,
|
||||
&out_splits_ptr,
|
||||
&dst_offsets_ptr,
|
||||
&team};
|
||||
nvshmemx_collective_launch(
|
||||
(const void*)makeExchangePlan,
|
||||
dim3(1),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args0,
|
||||
0,
|
||||
stream);
|
||||
}
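
For readers cross-checking the kernel above, here is a host-side reference model, in Python, of what the 1D exchange plan computes (a sketch only, mirroring the checks in `test_make_a2a_exchange_plan`; it assumes an initialized `torch.distributed` process group and uses `all_to_all_single` rather than NVSHMEM):

import torch
import torch.distributed as dist

def reference_exchange_plan(in_splits: torch.Tensor):
    # Offsets of each peer's chunk within my own input (exclusive prefix sum)
    my_src_offsets = torch.cumsum(in_splits, dim=0) - in_splits
    # out_splits[p]: how many elements I will receive from peer p
    out_splits = torch.empty_like(in_splits)
    dist.all_to_all_single(out_splits, in_splits)
    # src_offsets[p]: offset within peer p's input from which I should fetch
    src_offsets = torch.empty_like(my_src_offsets)
    dist.all_to_all_single(src_offsets, my_src_offsets)
    # dst_offsets[p]: where peer p's chunk lands in my output (exclusive prefix sum)
    dst_offsets = torch.cumsum(out_splits, dim=0) - out_splits
    return src_offsets, out_splits, dst_offsets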
|
||||
|
||||
void _all_to_all_get(
|
||||
at::Tensor& input,
|
||||
at::Tensor& out,
|
||||
at::Tensor& src_offsets,
|
||||
at::Tensor& out_splits,
|
||||
at::Tensor& dst_offsets,
|
||||
std::string group_name,
|
||||
const std::optional<at::Tensor>& b_start = std::nullopt,
|
||||
const std::optional<at::Tensor>& b_len = std::nullopt,
|
||||
const std::optional<at::Tensor>& b_head = std::nullopt) {
|
||||
// Perform a 1D AllToAllv shuffle operation, with source offset information, and get operations.
|
||||
c10d::symmetric_memory::rendezvous(src_offsets, group_name);
|
||||
c10d::symmetric_memory::rendezvous(out_splits, group_name);
|
||||
c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
|
||||
|
||||
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
|
||||
auto out_splits_ptr = out_splits.const_data_ptr<int64_t>();
|
||||
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
|
||||
|
||||
_all_to_all_get_inner(
|
||||
input, out, src_offsets_ptr, out_splits_ptr, dst_offsets_ptr, group_name,
|
||||
/*team=*/ std::nullopt, // determined inside
|
||||
b_start, b_len, b_head // optional signals
|
||||
);
|
||||
}
|
||||
|
||||
/* 2D all-to-all-v exchange plan */
|
||||
|
||||
#define MAX_N_PEERS (THREADS_PER_BLOCK / WARP_SIZE)
|
||||
#define MAX_LOCAL_EXPERTS 8
|
||||
#define MAX_N_EXPERTS (MAX_N_PEERS * MAX_LOCAL_EXPERTS)
|
||||
|
||||
|
||||
// This kernel is used to exchange output splits and dest offsets between peers.
|
||||
__global__ void make2dExchangePlan(int64_t* in_splits, int64_t* src_offsets, int64_t* out_splits, int64_t* dst_offsets, nvshmem_team_t team, int ne) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
int peer = threadIdx.x / WARP_SIZE;
|
||||
int my_pe = nvshmem_team_my_pe(team);
|
||||
int npes = nvshmem_team_n_pes(team);
|
||||
int nsplits = ne * npes;
|
||||
CUDA_KERNEL_ASSERT(npes <= MAX_N_PEERS && " Number of peers must not exceed MAX_N_PEERS\n");
CUDA_KERNEL_ASSERT(nsplits <= MAX_N_EXPERTS && " Number of splits must not exceed MAX_N_EXPERTS\n");
|
||||
|
||||
// Gather input splits from peers
|
||||
__shared__ int64_t in_splits_all[MAX_N_PEERS][MAX_N_EXPERTS];
|
||||
// Use 1 warp per peer, to get into shared memory
|
||||
if (peer < npes) {
|
||||
auto global_peer_id = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD);
|
||||
nvshmemx_int64_get_warp(in_splits_all[peer], in_splits, nsplits, global_peer_id);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write to output splits
|
||||
int local_expert = threadIdx.x % WARP_SIZE;
|
||||
if (local_expert < ne) {
|
||||
// experts I own
|
||||
int expert_id = ne * my_pe + local_expert;
|
||||
// out_splits is (ne, npes) flattened
|
||||
out_splits[local_expert * npes + peer] = in_splits_all[peer][expert_id];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// This is the aggregate per expert. There are a total of npes * ne experts,
|
||||
// and we arrange it as a 2D matrix
|
||||
__shared__ int64_t aggregate_per_expert[MAX_N_PEERS][MAX_LOCAL_EXPERTS];
|
||||
// Calculate the cusum of splits for an expert
|
||||
// One thread per expert, for all experts
|
||||
auto& cusum_per_expert = in_splits_all; // in-place prefix sum
|
||||
if (peer < npes && local_expert < ne) {
|
||||
int expert = peer * ne + local_expert;
|
||||
int64_t tmp = 0, cusum = 0;
|
||||
// This prefix sum is in the row direction, thus we cannot use cub helper
|
||||
for (int rank = 0; rank < npes; rank++) {
|
||||
tmp = in_splits_all[rank][expert];
|
||||
cusum_per_expert[rank][expert] = cusum;
|
||||
cusum += tmp;
|
||||
}
|
||||
aggregate_per_expert[peer][local_expert] = cusum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Calculate each expert's offset within their ranks
|
||||
// There are two solutions below. The first one uses warpScan, and the second
|
||||
// one uses manual prefix sum. warpScan's result is not stable, so we use the
|
||||
// second one for now. (Since ne is usually small, the performance difference
|
||||
// is negligible)
|
||||
#if 0
|
||||
// Solution 1: use warpScan
|
||||
__shared__ int64_t cusum_per_rank[MAX_N_PEERS][MAX_LOCAL_EXPERTS];
|
||||
// One warp per target rank
|
||||
prefixSum_warp<MAX_N_PEERS>(cusum_per_rank[peer], aggregate_per_expert[peer], ne);
|
||||
#else
|
||||
// Solution 2: manual sum
|
||||
auto& cusum_per_rank = aggregate_per_expert; // in-place prefix sum
|
||||
if (peer < npes && local_expert == 0) {
|
||||
int64_t tmp = 0, cusum = 0;
|
||||
// Summing the experts of a rank
|
||||
for (int e = 0; e < ne; e++) {
|
||||
tmp = aggregate_per_expert[peer][e];
|
||||
cusum_per_rank[peer][e] = cusum;
|
||||
cusum += tmp;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
__syncthreads();
|
||||
|
||||
// Now add the in-rank offsets to the in-expert offsets, then "I" will know
|
||||
// where to write my data in the dest rank
|
||||
if (peer < npes && local_expert < ne) {
|
||||
int expert = peer * ne + local_expert;
|
||||
dst_offsets[expert] = cusum_per_expert[my_pe][expert] + cusum_per_rank[peer][local_expert];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void _make_a2a_2d_exchange_plan(
|
||||
at::Tensor& in_splits,
|
||||
at::Tensor& src_offsets,
|
||||
at::Tensor& out_splits,
|
||||
at::Tensor& dst_offsets,
|
||||
std::string group_name) {
|
||||
// Make an exchange plan for a AllToAllv_2D shuffle operation.
|
||||
auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name);
|
||||
auto src_offsets_hdl = c10d::symmetric_memory::rendezvous(src_offsets, group_name);
|
||||
auto out_splits_hdl = c10d::symmetric_memory::rendezvous(out_splits, group_name);
|
||||
auto dst_offsets_hdl = c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
|
||||
|
||||
// Verify inputs
|
||||
auto npes = in_splits_hdl->get_world_size();
|
||||
int nsplits = in_splits.size(0);
|
||||
TORCH_CHECK(nsplits % npes == 0, "Number of splits must be a multiple of number of peers");
|
||||
TORCH_CHECK(src_offsets.size(0) == nsplits && out_splits.size(0) == nsplits && dst_offsets.size(0) == nsplits,
|
||||
"in_splits, src_offsets, out_splits and dst_offsets must have the same size");
|
||||
TORCH_CHECK(in_splits.scalar_type() == at::kLong && src_offsets.scalar_type() == at::kLong
|
||||
&& out_splits.scalar_type() == at::kLong && dst_offsets.scalar_type() == at::kLong,
|
||||
"splits and offsets must be int64");
|
||||
// Number of experts per rank
|
||||
int ne = nsplits / npes;
|
||||
|
||||
auto in_splits_ptr = in_splits.const_data_ptr<int64_t>();
|
||||
auto src_offsets_ptr = src_offsets.mutable_data_ptr<int64_t>();
|
||||
auto out_splits_ptr = out_splits.mutable_data_ptr<int64_t>();
|
||||
auto dst_offsets_ptr = dst_offsets.mutable_data_ptr<int64_t>();
|
||||
|
||||
auto device = in_splits.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
auto team = team_manager.get_team(group_name, in_splits_hdl->get_rank_to_global_rank());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
// Exchange output splits and source offsets
|
||||
// Use collective launch because kernel involves nvshmem barrier
|
||||
void* args0[] = {
|
||||
&in_splits_ptr,
|
||||
&src_offsets_ptr,
|
||||
&out_splits_ptr,
|
||||
&dst_offsets_ptr,
|
||||
&team,
|
||||
&ne};
|
||||
nvshmemx_collective_launch(
|
||||
(const void*)make2dExchangePlan,
|
||||
dim3(1),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args0,
|
||||
0,
|
||||
stream);
|
||||
}
|
||||
|
||||
/* ld.volatile.global: This refers to a volatile load from global memory. This
|
||||
* ensures that the data is read directly from memory every time, preventing the
|
||||
* compiler from optimizing the access by referring to a cached value. */
|
||||
__device__ __forceinline__ int64_t ld_volatile_global(int64_t *ptr) {
|
||||
int64_t ans;
|
||||
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ans) : "l"((uintptr_t)__cvta_generic_to_global(ptr)) : "memory");
|
||||
return ans;
|
||||
}
|
||||
|
||||
/* This function ensures that no memory access instruction that appears after the
|
||||
* ld.acquire in the program can be reordered to execute before it. Scope: within
|
||||
* the GPU itself. */
|
||||
__device__ __forceinline__ void fence_acquire_gpu() {
|
||||
static __device__ int dummy;
|
||||
int tmp;
|
||||
asm volatile("ld.acquire.gpu.s32 %0,[%1];" : "=r"(tmp) : "l"(&dummy) : "memory");
|
||||
dummy = tmp;
|
||||
}
|
||||
|
||||
__global__ void _allToAllV_2d_index_push_kernel(
|
||||
void *send_data, void *recv_data,
|
||||
int64_t* topk_indices, int64_t* occurrences, int64_t* dst_offsets,
|
||||
int topk, int n_local_experts, size_t stride, nvshmem_team_t team,
|
||||
int64_t* b_start, int64_t* b_len, int64_t* b_head) {
|
||||
/* Args:
|
||||
* send_data: the data to be sent
|
||||
* recv_data: the data to be received
|
||||
* topk_indices: (n_tokens, topk), the experts to send current token to
|
||||
* dst_offsets: (n_experts), the destination offsets of the expert chunk to be sent, within the destination rank
|
||||
* occurrences: (n_tokens, topk), the rank of current token within the tokens to be sent to an expert
|
||||
* stride: the stride of the data to be sent, i.e. token hidden size * element size, in bytes
|
||||
* b_start: the start offsets of the data to be sent by each CUDA block (equivalent to offset of data received from each rail)
|
||||
* b_len: the length of the data to be sent by each CUDA block
|
||||
* b_head: the most recent ready-to-send token index, plus 1, for each CUDA block
|
||||
*/
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
auto bid = blockIdx.x;
|
||||
auto tail = b_start[bid];
|
||||
auto end = b_start[bid] + b_len[bid];
|
||||
// Use a volatile load because we need to re-read its value
|
||||
volatile auto head_ptr = b_head + bid;
|
||||
while (tail < end) {
|
||||
while (tail >= ld_volatile_global(head_ptr)) {
|
||||
// Wait for producer kernel to mark newer ready tokens
|
||||
}
|
||||
// Ready signal has been provided by producer kernel (on the same GPU).
|
||||
// To make sure the token data is readily written by the producer kernel, we
|
||||
// use an acquire fence here, with scope of the same GPU.
|
||||
fence_acquire_gpu();
|
||||
|
||||
// New token has arrived
|
||||
auto send_ptr = (char*)send_data + tail * stride;
|
||||
// Loop over the topk experts
|
||||
for (int k = 0; k < topk; k++) {
|
||||
auto expert = topk_indices[tail * topk + k];
|
||||
// Get the destination rank
|
||||
auto dst_rank = expert / n_local_experts;
|
||||
auto dst_rank_global = nvshmem_team_translate_pe(team, dst_rank, NVSHMEM_TEAM_WORLD);
|
||||
// Get the destination offset
|
||||
auto dst_offset = dst_offsets[expert] + occurrences[tail * topk + k];
|
||||
// Get the destination pointer
|
||||
auto dst_ptr = (char*)recv_data + dst_offset * stride;
|
||||
// Send the data, i.e. 1 token
|
||||
nvshmemx_putmem_nbi_block(dst_ptr, send_ptr, stride, dst_rank_global);
|
||||
}
|
||||
tail++;
|
||||
}
|
||||
// Make sure all data has been sent
|
||||
nvshmem_quiet();
|
||||
#endif
|
||||
}
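
The `b_head` handshake above is a per-block single-producer/single-consumer protocol: the producer (`allToAllGet_signal_out`) publishes `write_head + 1` with a release store once a token is materialized, and this consumer spins until its `tail` is below the published head, issuing an acquire fence before reading the token. A minimal CPU-side sketch of the same idea (illustrative only; Python threads and a lock stand in for the two kernels and the release/acquire ordering):

import threading

def producer(head, tokens, lock):
    for i, tok in enumerate(tokens):
        tokens[i] = tok * 2          # "materialize" the token
        with lock:                   # stands in for the release store
            head[0] = i + 1          # publish: tokens [0, i] are ready

def consumer(head, tokens, out, lock):
    tail, end = 0, len(tokens)
    while tail < end:
        while True:
            with lock:               # stands in for volatile load + acquire fence
                ready = head[0]
            if tail < ready:
                break                # a newly published token is ready
        out.append(tokens[tail])     # safe to read token `tail` now
        tail += 1

head, tokens, out = [0], [1, 2, 3, 4], []
lock = threading.Lock()
t1 = threading.Thread(target=producer, args=(head, tokens, lock))
t2 = threading.Thread(target=consumer, args=(head, tokens, out, lock))
t1.start(); t2.start(); t1.join(); t2.join()
assert out == [2, 4, 6, 8]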
|
||||
|
||||
void _all_to_all_v_2d_index_push(
|
||||
at::Tensor& input,
|
||||
at::Tensor& out,
|
||||
at::Tensor& topk_indices,
|
||||
at::Tensor& occurrences,
|
||||
at::Tensor& dst_offsets,
|
||||
std::string group_name,
|
||||
at::Tensor& b_start,
|
||||
at::Tensor& b_len,
|
||||
at::Tensor& b_head) {
|
||||
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
|
||||
c10d::symmetric_memory::rendezvous(out, group_name);
|
||||
c10d::symmetric_memory::rendezvous(dst_offsets, group_name);
|
||||
|
||||
auto input_ptr = input.data_ptr();
|
||||
auto out_ptr = out.data_ptr();
|
||||
auto topk_indices_ptr = topk_indices.data_ptr<int64_t>();
|
||||
auto occurrences_ptr = occurrences.data_ptr<int64_t>();
|
||||
auto dst_offsets_ptr = dst_offsets.data_ptr<int64_t>();
|
||||
auto b_start_ptr = b_start.data_ptr<int64_t>();
|
||||
auto b_len_ptr = b_len.data_ptr<int64_t>();
|
||||
auto b_head_ptr = b_head.data_ptr<int64_t>();
|
||||
|
||||
auto topk = topk_indices.size(1);
|
||||
auto n_local_experts = dst_offsets.size(0) / input_hdl->get_world_size();
|
||||
auto stride_bytes = input.stride(0) * input.element_size();
|
||||
|
||||
auto device = input.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
TORCH_CHECK(
|
||||
b_start.size(0) == b_len.size(0) && b_start.size(0) == b_head.size(0),
|
||||
"Block start, len and head should have same size");
|
||||
|
||||
int nblocks = b_start.size(0);
|
||||
void* args[] = {
|
||||
&input_ptr, &out_ptr,
|
||||
&topk_indices_ptr, &occurrences_ptr, &dst_offsets_ptr,
|
||||
&topk, &n_local_experts, &stride_bytes, &team,
|
||||
&b_start_ptr, &b_len_ptr, &b_head_ptr
|
||||
};
|
||||
C10_CUDA_CHECK(cudaLaunchKernel(
|
||||
(const void*)_allToAllV_2d_index_push_kernel,
|
||||
dim3(nblocks),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args,
|
||||
0,
|
||||
stream));
|
||||
}
|
||||
|
||||
} // namespace c10d::nvshmem_extension
|
||||
|
||||
|
||||
@@ -876,4 +1372,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
|
||||
m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev);
|
||||
m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
|
||||
m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
|
||||
m.impl("_make_a2a_exchange_plan", c10d::nvshmem_extension::_make_a2a_exchange_plan);
|
||||
m.impl("_all_to_all_get", c10d::nvshmem_extension::_all_to_all_get);
|
||||
m.impl("_make_a2a_2d_exchange_plan", c10d::nvshmem_extension::_make_a2a_2d_exchange_plan);
|
||||
m.impl("_all_to_all_v_2d_index_push", c10d::nvshmem_extension::_all_to_all_v_2d_index_push);
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import math
|
||||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
@@ -1824,4 +1825,107 @@ def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def]
|
||||
return _SymmetricMemory.get_mempool_allocator(torch.device(device))
|
||||
|
||||
|
||||
__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
|
||||
# Create a type, ExchangePlan.
ExchangePlan = namedtuple(
"ExchangePlan", ["in_splits", "src_offsets", "out_splits", "dst_offsets"]
)
ExchangePlan.__doc__ = """A namedtuple of metadata that accelerates all_to_all operations.
- in_splits: splits of my input towards different peers.
- src_offsets: offsets within peers' input from which I should fetch data.
- out_splits: splits of peers' contribution to my output.
- dst_offsets: offsets within my output where I should store peers' contribution.
"""
|
||||
|
||||
|
||||
def make_a2a_exchange_plan(
|
||||
in_splits: torch.Tensor,
|
||||
src_offsets: torch.Tensor,
|
||||
out_splits: torch.Tensor,
|
||||
dst_offsets: torch.Tensor,
|
||||
group_name: str,
|
||||
) -> ExchangePlan:
|
||||
r"""
|
||||
Create an all-to-all exchange plan given the input splits. This is a
|
||||
collective operation.
|
||||
Args:
|
||||
in_splits (:class:`torch.Tensor`): the input splits for the exchange plan (IN).
src_offsets (:class:`torch.Tensor`): the source offsets for the exchange plan (OUT).
out_splits (:class:`torch.Tensor`): the output splits for the exchange plan (OUT).
dst_offsets (:class:`torch.Tensor`): the destination offsets for the exchange plan (OUT).
group_name (str): the group over which to exchange the splits and offsets.
|
||||
Returns:
|
||||
An `ExchangePlan` capturing the above tensors.
|
||||
"""
|
||||
torch.ops.symm_mem._make_a2a_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
return ExchangePlan(in_splits, src_offsets, out_splits, dst_offsets)
|
||||
|
||||
|
||||
def all_to_all_v(
|
||||
input: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
plan: ExchangePlan,
|
||||
group_name: str,
|
||||
b_start: Optional[torch.Tensor] = None,
|
||||
b_len: Optional[torch.Tensor] = None,
|
||||
b_head: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Perform an all-to-all-v operation given an `ExchangePlan`.
|
||||
Args:
|
||||
input (:class:`torch.Tensor`): the input tensor for the all-to-all operation (IN).
out (:class:`torch.Tensor`): the output tensor for the all-to-all operation (OUT).
plan (`ExchangePlan`): a tuple consisting of (in_splits, src_offsets, out_splits, dst_offsets).
group_name (str): the group over which to perform the all-to-all.
b_start, b_len, b_head (Optional[:class:`torch.Tensor`]): optional per-CUDA-block signal tensors (start offset, length, and ready head) written by the underlying kernel to signal token readiness to a downstream consumer (OUT).
"""
|
||||
# For now we use the get style; in the future we can extend it to support the
# put style as well, e.g. selected via a flag.
|
||||
torch.ops.symm_mem._all_to_all_get(
|
||||
input,
|
||||
out,
|
||||
plan.src_offsets,
|
||||
plan.out_splits,
|
||||
plan.dst_offsets,
|
||||
group_name,
|
||||
b_start,
|
||||
b_len,
|
||||
b_head,
|
||||
)
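
A condensed usage sketch mirroring `test_a2a_with_exchange_plan` above (illustrative only; it assumes one CUDA device per rank, an initialized process group, and made-up sizes):

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
device = torch.device("cuda", torch.cuda.current_device())

# Splits and offsets must live in symmetric memory
in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=device)
src_offsets = symm_mem.empty(world_size, dtype=torch.int64, device=device)
out_splits = symm_mem.empty(world_size, dtype=torch.int64, device=device)
dst_offsets = symm_mem.empty(world_size, dtype=torch.int64, device=device)
in_splits.copy_(torch.randint(10, (world_size,), device=device))

# Data buffers are sized to a worst case that is identical on all ranks
max_inp_numel = 10 * world_size
inp = symm_mem.empty(max_inp_numel, dtype=torch.float, device=device).normal_()
out = symm_mem.empty(max_inp_numel * world_size, dtype=torch.float, device=device)

dist.barrier()  # make sure remote symmetric buffers have been allocated
plan = symm_mem.make_a2a_exchange_plan(
    in_splits, src_offsets, out_splits, dst_offsets, group_name
)
symm_mem.all_to_all_v(inp, out, plan, group_name)  # the plan is reusable across calls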
|
||||
|
||||
|
||||
def make_a2a_2d_exchange_plan(
|
||||
in_splits: torch.Tensor,
|
||||
src_offsets: torch.Tensor,
|
||||
out_splits: torch.Tensor,
|
||||
dst_offsets: torch.Tensor,
|
||||
group_name: str,
|
||||
) -> ExchangePlan:
|
||||
r"""
|
||||
Create an all-to-all-2d exchange plan given the input splits. This is a
|
||||
collective operation.
|
||||
Args:
|
||||
in_splits (:class:`torch.Tensor`): the input splits for the exchange plan (IN).
src_offsets (:class:`torch.Tensor`): the source offsets for the exchange plan (OUT).
out_splits (:class:`torch.Tensor`): the output splits for the exchange plan (OUT).
dst_offsets (:class:`torch.Tensor`): the destination offsets for the exchange plan (OUT).
group_name (str): the group over which to exchange the splits and offsets.
|
||||
Returns:
|
||||
An `ExchangePlan` capturing the above tensors.
|
||||
"""
|
||||
torch.ops.symm_mem._make_a2a_2d_exchange_plan(
|
||||
in_splits, src_offsets, out_splits, dst_offsets, group_name
|
||||
)
|
||||
return ExchangePlan(in_splits, src_offsets, out_splits, dst_offsets)
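
To make the rank-major to expert-major reshuffle performed by the 2D plan concrete, a small standalone illustration (values and sizes are made up):

import torch

world_size, ne = 2, 3  # 2 ranks, 3 local experts (illustrative)
# Splits as a plain all-to-all would deliver them, grouped by source rank:
rank_major = torch.tensor([10, 11, 12, 20, 21, 22])
# The 2D plan stores them expert-major: for each local expert, the contributions
# of all source ranks are contiguous, matching the data layout of all_to_all_vdev_2d.
expert_major = rank_major.reshape(world_size, ne).t().reshape(-1)
print(expert_major)  # tensor([10, 20, 11, 21, 12, 22])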
|
||||
|
||||
|
||||
__all__ = [
|
||||
"empty",
|
||||
"rendezvous",
|
||||
"is_nvshmem_available",
|
||||
"set_backend",
|
||||
"get_backend",
|
||||
"ExchangePlan",
|
||||
"make_a2a_exchange_plan",
|
||||
"all_to_all_v",
|
||||
]
|
||||
|
||||