diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py
index 1cdcb483be27..2a601957ebf9 100644
--- a/test/distributed/test_nvshmem.py
+++ b/test/distributed/test_nvshmem.py
@@ -749,6 +749,80 @@ class NVSHMEMTileCommTest(MultiProcContinuousTest):
 
         torch.testing.assert_close(full_out, expected)
 
+    @skipIfRocm
+    @parametrize("tile_size", [32, 128, 512])
+    @parametrize(
+        "root_ratio", [1, 2]
+    )  # 1: all ranks are roots, 2: half of ranks are roots
+    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
+    def test_multi_root_tile_reduce(
+        self, tile_size: int, root_ratio: int, dtype: torch.dtype
+    ) -> None:
+        """Reduce one tile per active root and verify each rank's output."""
+        full_size = 2048
+        num_slices_col = 2  # number of tiles on column dimension
+        # Every rank needs its own tile in the (row, col) grid; fail on all ranks
+        # up front instead of indexing past `slices_row` on the last rank only.
+        assert self.world_size % num_slices_col == 0
+        num_slices_row = (
+            self.world_size // num_slices_col
+        )  # number of tiles on row dimension
+        assert tile_size * num_slices_col <= full_size
+        assert tile_size * num_slices_row <= full_size
+
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        full_inp = symm_mem.empty(
+            full_size, full_size, dtype=dtype, device=self.device
+        ).fill_(self.rank)
+        full_out = symm_mem.empty(
+            full_size, full_size, dtype=dtype, device=self.device
+        ).fill_(0)
+
+        # Get range of each slice in terms of element indices
+        slices_row = [
+            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_row)
+        ]
+        slices_col = [
+            slice(s * tile_size, (s + 1) * tile_size) for s in range(num_slices_col)
+        ]
+
+        # Active roots, can be a subset of all ranks
+        num_active_roots = self.world_size // root_ratio
+        active_roots = list(range(num_active_roots))
+
+        # Map rank to slice indices (e.g. rank 0 -> (0, 0), rank 1 -> (0, 1),
+        # rank 2 -> (1, 0), rank 3 -> (1, 1))
+        def map_rank_to_slices(r):
+            return (
+                slices_row[r // num_slices_col],
+                slices_col[r % num_slices_col],
+            )
+
+        # Populate input tiles
+        input_tiles_ij = [map_rank_to_slices(r) for r in active_roots]
+        input_tiles = [
+            full_inp[slice_i, slice_j] for (slice_i, slice_j) in input_tiles_ij
+        ]
+        # My output tile (i.e. the one that I will reduce)
+        out_tile_ij = map_rank_to_slices(self.rank)
+        out_tile = full_out[out_tile_ij[0], out_tile_ij[1]]
+
+        # Reduce the tiles
+        torch.ops.symm_mem.multi_root_tile_reduce(
+            input_tiles, out_tile, active_roots, group_name
+        )
+
+        # Check data
+        expected = torch.zeros_like(full_out)
+        expected_tile = expected[out_tile_ij[0], out_tile_ij[1]]
+        if self.rank in active_roots:
+            # Sum of the rank-valued inputs: 0 + 1 + ... + (world_size - 1)
+            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)
+        torch.testing.assert_close(full_out, expected)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
index b8cfe3edec34..b0e08b49d374 100644
--- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
+++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
@@ -514,6 +514,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
   m.def(
       "tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
+  m.def(
+      "multi_root_tile_reduce(Tensor[] in_tiles, Tensor(a!) out_tile, int[] roots, str group_name, str reduce_op='sum') -> ()");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
index c8c6d6648f59..cb5d40ef4183 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
@@ -986,6 +986,96 @@ void tile_reduce(
   });
 }
 
+/* Multi-tile Communication */
+
+void multi_root_tile_reduce(
+    at::ArrayRef<at::Tensor> in_tiles,
+    at::Tensor& out_tile,
+    at::ArrayRef<int64_t> roots,
+    std::string group_name,
+    std::string reduce_op) {
+  /* Perform multiple tile reductions concurrently, with each tile reduced to a separate root.
+    Args:
+      - `in_tiles` is a list of input tensors. Must not be empty.
+      - `out_tile` is the output tensor.
+      - `roots` is a list of root ranks corresponding to each input tile, in the same order. A rank cannot be a root more than once.
+      - `group_name` is the name of the group to use for the collective operation.
+      - `reduce_op` is the reduction operation to perform. Currently only "sum" is supported.
+  */
+  TORCH_CHECK(reduce_op == "sum", "multi_root_tile_reduce: only sum is supported for now");
+  TORCH_CHECK(out_tile.dim() == 2, "Only 2D tensors are supported");
+  TORCH_CHECK(roots.size() == in_tiles.size(), "Number of roots must match number of tiles");
+  // `in_tiles[my_tile_idx]` is dereferenced below even when this rank is not a
+  // root, so an empty list must be rejected here.
+  TORCH_CHECK(!in_tiles.empty(), "At least one input tile is required");
+
+  // Get device and stream
+  auto device = out_tile.device();
+  c10::cuda::CUDAGuard guard(device);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // Rendezvous all tensors, and find the tile "I" need to reduce
+  auto hdl = c10d::symmetric_memory::rendezvous(out_tile, group_name);
+  int rank = hdl->get_rank();
+  int world_size = hdl->get_world_size();
+  int i = 0, my_tile_idx = 0, root = world_size;
+  // Note: if there is no tile for the current rank, my_tile_idx will remain
+  // initial value 0, and root will remain `world_size`. This is OK. In
+  // `nvshmemx::tile_sum_reduce_block`, this rank would skip the reduction
+  // operation, but would still participate in the barrier.
+  for (auto& in_tile : in_tiles) {
+    TORCH_CHECK(in_tile.dim() == 2, "Only 2D tensors are supported");
+    c10d::symmetric_memory::rendezvous(in_tile, group_name);
+    TORCH_CHECK(roots[i] < world_size && roots[i] >= 0, "Invalid root");
+    if (roots[i] == rank) {
+      TORCH_CHECK(root == world_size, "Each rank can only be a root once");
+      // The launch below uses `out_tile`'s shape and dtype for both pointers, so
+      // the tile this rank reduces has to agree with `out_tile`.
+      TORCH_CHECK(in_tile.sizes() == out_tile.sizes(), "Input tile and output tile must have the same shape");
+      TORCH_CHECK(in_tile.scalar_type() == out_tile.scalar_type(), "Input tile and output tile must have the same dtype");
+      my_tile_idx = i;
+      root = rank;
+    }
+    i++;
+  }
+
+  // Ideally 16 bytes per thread
+  int nblocks = at::ceil_div(
+      out_tile.numel() * out_tile.element_size(),
+      (int64_t)THREADS_PER_BLOCK * 16);
+  nblocks = std::min(nblocks, 24);
+
+  // Need one team per block
+  auto& team_manager = TeamManager::get(device);
+  auto [teams, teams_dev] = team_manager.get_n_teams(
+      group_name, hdl->get_rank_to_global_rank(), nblocks);
+
+  // Prepare launch parameters
+  auto shape = nvshmemx::make_shape(out_tile.sizes()[0], out_tile.sizes()[1]);
+  auto stride = nvshmemx::make_stride(out_tile.strides()[0], out_tile.strides()[1]);
+  auto in_tile_ptr = in_tiles[my_tile_idx].const_data_ptr();
+  auto out_tile_ptr = out_tile.mutable_data_ptr();
+
+  void* args[] = {
+      &in_tile_ptr,
+      &out_tile_ptr,
+      &shape,
+      &stride,
+      &root,
+      &teams_dev};
+
+  AT_DISPATCH_NVSHMEM_FLOATS(out_tile.scalar_type(), "multi_root_tile_reduce", [&]() {
+    nvshmemx_collective_launch(
+        (const void*)tile_reduce_kernel<scalar_t>,
+        dim3(nblocks),
+        dim3(THREADS_PER_BLOCK),
+        args,
+        0,
+        stream);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
+}
+
 } // namespace c10d::nvshmem_extension
 
 
@@ -1000,4 +1090,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
   m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
   m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
+  m.impl("multi_root_tile_reduce", c10d::nvshmem_extension::multi_root_tile_reduce);
 }
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
index 774246132409..50b9e268cba7 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
@@ -65,4 +65,12 @@ void tile_reduce(
     std::string group_name,
     std::string reduce_op = "sum");
 
+// Reduce several tiles concurrently, each one onto its own root rank.
+void multi_root_tile_reduce(
+    at::ArrayRef<at::Tensor> in_tiles,
+    at::Tensor& out_tile,
+    at::ArrayRef<int64_t> roots,
+    std::string group_name,
+    std::string reduce_op = "sum");
+
 } // namespace c10d::nvshmem_extension