[SymmMem] Multi-root tile reduction (#164757)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
- The number of concurrent reductions can be smaller than the world size, i.e. roots can be a subset of all ranks. All ranks are still required to call into this API.
- Currently supports NVLink SHARP scope only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164757
Approved by: https://github.com/weifengpy, https://github.com/fegin
ghstack dependencies: #162243
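For orientation, a minimal sketch of how the new op is meant to be called, condensed from the test added below. The 4-rank job, 2x2 tile layout, sizes, dtype, and process-group setup are illustrative assumptions, not part of this PR.

```python
# Hedged sketch: a 4-rank job (e.g. launched via torchrun) calling the new op.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

tile = 128
full_inp = symm_mem.empty(2 * tile, 2 * tile, dtype=torch.float, device=device).fill_(rank)
full_out = symm_mem.empty(2 * tile, 2 * tile, dtype=torch.float, device=device).fill_(0)

# Roots may be a subset of all ranks (here: ranks 0 and 1),
# but every rank must still call the op.
roots = [0, 1]


def tile_of(r: int) -> tuple[slice, slice]:
    # Rank r owns the tile at (row r // 2, col r % 2) of the 2x2 grid.
    return (
        slice((r // 2) * tile, (r // 2 + 1) * tile),
        slice((r % 2) * tile, (r % 2 + 1) * tile),
    )


in_tiles = [full_inp[tile_of(r)] for r in roots]
out_tile = full_out[tile_of(rank)]  # written only if this rank is one of the roots

torch.ops.symm_mem.multi_root_tile_reduce(in_tiles, out_tile, roots, group_name)
# On roots 0 and 1, out_tile now holds sum(range(world_size)) in every element.
dist.destroy_process_group()
```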
@@ -749,6 +749,72 @@ class NVSHMEMTileCommTest(MultiProcContinuousTest):
        torch.testing.assert_close(full_out, expected)

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize(
        "root_ratio", [1, 2]
    )  # 1: all ranks are roots, 2: half of ranks are roots
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_multi_root_tile_reduce(
        self, tile_size: int, root_ratio: int, dtype: torch.dtype
    ) -> None:
        full_size = 2048
        num_slices_col = 2  # number of tiles on column dimension
        num_slices_row = (
            self.world_size // num_slices_col
        )  # number of tiles on row dimension
        assert tile_size * num_slices_col <= full_size
        assert tile_size * num_slices_row <= full_size

        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        # Get range of each slice in terms of element indices
        slices_row = [
            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_row)
        ]
        slices_col = [
            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_col)
        ]

        # Active roots, can be a subset of all ranks
        num_active_roots = self.world_size // root_ratio
        active_roots = list(range(num_active_roots))

        # Map rank to slice indices (e.g. rank 0 -> (0, 0), rank 1 -> (0, 1), rank 2 -> (1, 0), rank 3 -> (1, 1))
        map_rank_to_slices = lambda r: (  # noqa: E731
            slices_row[r // num_slices_col],
            slices_col[r % num_slices_col],
        )
        # Populate input tiles
        input_tiles_ij = [map_rank_to_slices(r) for r in active_roots]
        input_tiles = [
            full_inp[slice_i, slice_j] for (slice_i, slice_j) in input_tiles_ij
        ]
        # My output tile (i.e. the one that I will reduce)
        out_tile_ij = map_rank_to_slices(self.rank)
        out_tile = full_out[out_tile_ij[0], out_tile_ij[1]]

        # Reduce the tiles
        torch.ops.symm_mem.multi_root_tile_reduce(
            input_tiles, out_tile, active_roots, group_name
        )

        # Check data
        expected = torch.zeros_like(full_out)
        expected_tile = expected[out_tile_ij[0], out_tile_ij[1]]
        if self.rank in active_roots:
            # Each rank fills its input with its own rank id, so the reduced
            # sum is 0 + 1 + ... + (world_size - 1) = world_size * (world_size - 1) / 2
            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)
        torch.testing.assert_close(full_out, expected)


if __name__ == "__main__":
    run_tests()
@@ -514,6 +514,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
      "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
  m.def(
      "tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
  m.def(
      "multi_root_tile_reduce(Tensor[] in_tiles, Tensor(a!) out_tile, int[] roots, str group_name, str reduce_op='sum') -> ()");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
@@ -986,6 +986,89 @@ void tile_reduce(
  });
}

/* Multi-tile Communication */

void multi_root_tile_reduce(
    at::ArrayRef<at::Tensor> in_tiles,
    at::Tensor& out_tile,
    at::ArrayRef<int64_t> roots,
    std::string group_name,
    std::string reduce_op) {
  /* Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
  Args:
    - `in_tiles` is a list of input tensors.
    - `out_tile` is the output tensor.
    - `roots` is a list of root ranks corresponding to each input tile, in the same order. A rank cannot be a root more than once.
    - `group_name` is the name of the group to use for the collective operation.
    - `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
  */
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
|
||||
  TORCH_CHECK(out_tile.dim() == 2, "Only 2D tensors are supported");
  TORCH_CHECK(roots.size() == in_tiles.size(), "Number of roots must match number of tiles");

  // Get device and stream
  auto device = out_tile.device();
  c10::cuda::CUDAGuard guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();

  // Rendezvous all tensors, and find the tile "I" need to reduce
  auto hdl = c10d::symmetric_memory::rendezvous(out_tile, group_name);
  int rank = hdl->get_rank();
  int world_size = hdl->get_world_size();
  int i = 0, my_tile_idx = 0, root = world_size;
  // Note: if there is no tile for the current rank, my_tile_idx will remain
  // initial value 0, and root will remain `world_size`. This is OK. In
  // `nvshmemx::tile_sum_reduce_block`, this rank would skip the reduction
  // operation, but would still participate in the barrier.
  for (auto& in_tile : in_tiles) {
    TORCH_CHECK(in_tile.dim() == 2, "Only 2D tensors are supported");
    c10d::symmetric_memory::rendezvous(in_tile, group_name);
    TORCH_CHECK(roots[i] < world_size && roots[i] >= 0, "Invalid root");
    if (roots[i] == rank) {
      TORCH_CHECK(root == world_size, "Each rank can only be a root once");
      my_tile_idx = i;
      root = rank;
    }
    i++;
  }

  // Ideally 16 bytes per thread
  int nblocks = at::ceil_div(
      out_tile.numel() * out_tile.element_size(),
      (int64_t)THREADS_PER_BLOCK * 16);
  nblocks = std::min(nblocks, 24);

  // Need one team per block
  auto& team_manager = TeamManager::get(device);
  auto [teams, teams_dev] = team_manager.get_n_teams(
      group_name, hdl->get_rank_to_global_rank(), nblocks);

  // Prepare launch parameters
  auto shape = nvshmemx::make_shape(out_tile.sizes()[0], out_tile.sizes()[1]);
  auto stride = nvshmemx::make_stride(out_tile.strides()[0], out_tile.strides()[1]);
  auto in_tile_ptr = in_tiles[my_tile_idx].const_data_ptr();
  auto out_tile_ptr = out_tile.mutable_data_ptr();

  void* args[] = {
      &in_tile_ptr,
      &out_tile_ptr,
      &shape,
      &stride,
      &root,
      &teams_dev};

  AT_DISPATCH_NVSHMEM_FLOATS(out_tile.scalar_type(), "multi_root_tile_reduce", [&]() {
    nvshmemx_collective_launch(
        (const void*)tile_reduce_kernel<scalar_t>,
        dim3(nblocks),
        dim3(THREADS_PER_BLOCK),
        args,
        0,
        stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace c10d::nvshmem_extension

@@ -1000,4 +1083,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
  m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
  m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
  m.impl("multi_root_tile_reduce", c10d::nvshmem_extension::multi_root_tile_reduce);
}
@@ -65,4 +65,11 @@ void tile_reduce(
    std::string group_name,
    std::string reduce_op = "sum");

void multi_root_tile_reduce(
    at::ArrayRef<at::Tensor> in_tiles,
    at::Tensor& out_tile,
    at::ArrayRef<int64_t> roots,
    std::string group_name,
    std::string reduce_op = "sum");

} // namespace c10d::nvshmem_extension