[SymmMem] Multi-root tile reduction (#164757)


Perform multiple tile reductions concurrently, with each tile reduced to a separate root.

- The number of concurrent reductions can be smaller than the world size, i.e. the roots can be a subset of all ranks; all ranks are still required to call into this API (see the usage sketch after this list).

- Currently supports NVLink SHARP scope only.
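
A minimal usage sketch, adapted from the new test below. The buffer shape, tile layout, and choice of roots are purely illustrative, and `symm_mem` is assumed to refer to the `torch.distributed._symmetric_memory` module as aliased in the test file; a process group and a per-rank CUDA device are assumed to be set up already.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank, world_size = dist.get_rank(), dist.get_world_size()

# Symmetric-memory buffers; each rank fills its input with its own rank id.
# Device selection is simplified here; real code would pick the rank's own GPU.
full_inp = symm_mem.empty(2048, 2048, dtype=torch.float, device="cuda").fill_(rank)
full_out = symm_mem.empty(2048, 2048, dtype=torch.float, device="cuda").fill_(0)

# Roots may be a subset of all ranks (here: the first half); every rank still makes the call.
roots = list(range(world_size // 2))

# One 128x128 tile per root along the first row of the buffer (layout is illustrative).
tile = lambda r: (slice(0, 128), slice(r * 128, (r + 1) * 128))
input_tiles = [full_inp[tile(r)] for r in roots]  # one input tile per root, same order as `roots`
out_tile = full_out[tile(rank)]                   # the tile this rank reduces into, if it is a root

torch.ops.symm_mem.multi_root_tile_reduce(input_tiles, out_tile, roots, group_name)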

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164757
Approved by: https://github.com/weifengpy, https://github.com/fegin
ghstack dependencies: #162243
Authored by Ke Wen on 2025-10-07 15:20:44 -07:00; committed by PyTorch MergeBot
parent 83458197d1
commit 5c827a4133
4 changed files with 159 additions and 0 deletions

@@ -749,6 +749,72 @@ class NVSHMEMTileCommTest(MultiProcContinuousTest):
        torch.testing.assert_close(full_out, expected)

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize(
        "root_ratio", [1, 2]
    )  # 1: all ranks are roots, 2: half of ranks are roots
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_multi_root_tile_reduce(
        self, tile_size: int, root_ratio: int, dtype: torch.dtype
    ) -> None:
        full_size = 2048
        num_slices_col = 2  # number of tiles on column dimension
        num_slices_row = (
            self.world_size // num_slices_col
        )  # number of tiles on row dimension
        assert tile_size * num_slices_col <= full_size
        assert tile_size * num_slices_row <= full_size
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        # Get range of each slice in terms of element indices
        slices_row = [
            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_row)
        ]
        slices_col = [
            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_col)
        ]

        # Active roots, can be a subset of all ranks
        num_active_roots = self.world_size // root_ratio
        active_roots = list(range(num_active_roots))

        # Map rank to slice indices (e.g. rank 0 -> (0, 0), rank 1 -> (0, 1), rank 2 -> (1, 0), rank 3 -> (1, 1))
        map_rank_to_slices = lambda r: (  # noqa: E731
            slices_row[r // num_slices_col],
            slices_col[r % num_slices_col],
        )

        # Populate input tiles
        input_tiles_ij = [map_rank_to_slices(r) for r in active_roots]
        input_tiles = [
            full_inp[slice_i, slice_j] for (slice_i, slice_j) in input_tiles_ij
        ]
        # My output tile (i.e. the one that I will reduce)
        out_tile_ij = map_rank_to_slices(self.rank)
        out_tile = full_out[out_tile_ij[0], out_tile_ij[1]]

        # Reduce the tiles
        torch.ops.symm_mem.multi_root_tile_reduce(
            input_tiles, out_tile, active_roots, group_name
        )

        # Check data
        expected = torch.zeros_like(full_out)
        expected_tile = expected[out_tile_ij[0], out_tile_ij[1]]
        if self.rank in active_roots:
            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)
        torch.testing.assert_close(full_out, expected)


if __name__ == "__main__":
    run_tests()
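
The `expected` value above follows from each rank filling its input with its own rank id: every active root's output tile receives the sum 0 + 1 + ... + (world_size - 1), while non-root ranks leave their output tile at zero. A quick check of that arithmetic (a world size of 4 is just an example):

# Each rank contributes its rank id, so a root's tile sums to world_size * (world_size - 1) / 2.
world_size = 4  # example value; any world size works the same way
assert sum(range(world_size)) == world_size * (world_size - 1) // 2  # 0 + 1 + 2 + 3 == 6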

@@ -514,6 +514,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
      "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
  m.def(
      "tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
  m.def(
      "multi_root_tile_reduce(Tensor[] in_tiles, Tensor(a!) out_tile, int[] roots, str group_name, str reduce_op='sum') -> ()");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {

@@ -986,6 +986,89 @@ void tile_reduce(
  });
}

/* Multi-tile Communication */

void multi_root_tile_reduce(
    at::ArrayRef<at::Tensor> in_tiles,
    at::Tensor& out_tile,
    at::ArrayRef<int64_t> roots,
    std::string group_name,
    std::string reduce_op) {
  /* Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
    Args:
     - `in_tiles` is a list of input tensors.
     - `out_tile` is the output tensor.
     - `roots` is a list of root ranks corresponding to each input tile, in the same order. A rank cannot be a root more than once.
     - `group_name` is the name of the group to use for the collective operation.
     - `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
  */
  TORCH_CHECK(reduce_op == "sum", "multi_root_tile_reduce: only sum is supported for now");
  TORCH_CHECK(out_tile.dim() == 2, "Only 2D tensors are supported");
  TORCH_CHECK(roots.size() == in_tiles.size(), "Number of roots must match number of tiles");

  // Get device and stream
  auto device = out_tile.device();
  c10::cuda::CUDAGuard guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();

  // Rendezvous all tensors, and find the tile "I" need to reduce
  auto hdl = c10d::symmetric_memory::rendezvous(out_tile, group_name);
  int rank = hdl->get_rank();
  int world_size = hdl->get_world_size();
  int i = 0, my_tile_idx = 0, root = world_size;
  // Note: if there is no tile for the current rank, my_tile_idx will remain
  // initial value 0, and root will remain `world_size`. This is OK. In
  // `nvshmemx::tile_sum_reduce_block`, this rank would skip the reduction
  // operation, but would still participate in the barrier.
  for (auto& in_tile : in_tiles) {
    TORCH_CHECK(in_tile.dim() == 2, "Only 2D tensors are supported");
    c10d::symmetric_memory::rendezvous(in_tile, group_name);
    TORCH_CHECK(roots[i] < world_size && roots[i] >= 0, "Invalid root");
    if (roots[i] == rank) {
      TORCH_CHECK(root == world_size, "Each rank can only be a root once");
      my_tile_idx = i;
      root = rank;
    }
    i++;
  }

  // Ideally 16 bytes per thread
  int nblocks = at::ceil_div(
      out_tile.numel() * out_tile.element_size(),
      (int64_t)THREADS_PER_BLOCK * 16);
  nblocks = std::min(nblocks, 24);

  // Need one team per block
  auto& team_manager = TeamManager::get(device);
  auto [teams, teams_dev] = team_manager.get_n_teams(
      group_name, hdl->get_rank_to_global_rank(), nblocks);

  // Prepare launch parameters
  auto shape = nvshmemx::make_shape(out_tile.sizes()[0], out_tile.sizes()[1]);
  auto stride = nvshmemx::make_stride(out_tile.strides()[0], out_tile.strides()[1]);
  auto in_tile_ptr = in_tiles[my_tile_idx].const_data_ptr();
  auto out_tile_ptr = out_tile.mutable_data_ptr();
  void* args[] = {
      &in_tile_ptr,
      &out_tile_ptr,
      &shape,
      &stride,
      &root,
      &teams_dev};

  AT_DISPATCH_NVSHMEM_FLOATS(out_tile.scalar_type(), "multi_root_tile_reduce", [&]() {
    nvshmemx_collective_launch(
        (const void*)tile_reduce_kernel<scalar_t>,
        dim3(nblocks),
        dim3(THREADS_PER_BLOCK),
        args,
        0,
        stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace c10d::nvshmem_extension
@@ -1000,4 +1083,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
  m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
  m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
  m.impl("multi_root_tile_reduce", c10d::nvshmem_extension::multi_root_tile_reduce);
}
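
The block count in the implementation above targets roughly 16 bytes per thread and is capped at 24 blocks, with one NVSHMEM team per block. A rough sketch of that sizing arithmetic, assuming THREADS_PER_BLOCK is 512 (the actual constant is defined elsewhere in the extension and may differ):

from math import ceil

def num_blocks(rows, cols, element_size, threads_per_block=512):
    # Mirror of the heuristic above: ~16 bytes per thread, capped at 24 blocks.
    # threads_per_block=512 is an assumption; the real THREADS_PER_BLOCK may differ.
    nbytes = rows * cols * element_size
    return min(ceil(nbytes / (threads_per_block * 16)), 24)

print(num_blocks(512, 512, 4))  # 1 MiB float32 tile -> 128 blocks uncapped -> clamped to 24
print(num_blocks(128, 128, 2))  # 32 KiB bf16 tile -> 4 blocks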

@@ -65,4 +65,11 @@ void tile_reduce(
    std::string group_name,
    std::string reduce_op = "sum");

void multi_root_tile_reduce(
    at::ArrayRef<at::Tensor> in_tiles,
    at::Tensor& out_tile,
    at::ArrayRef<int64_t> roots,
    std::string group_name,
    std::string reduce_op = "sum");

} // namespace c10d::nvshmem_extension