diff --git a/benchmarks/distributed/bench_nvshmem_tile_reduce.py b/benchmarks/distributed/bench_nvshmem_tile_reduce.py
new file mode 100644
index 000000000000..da4ce796d7bb
--- /dev/null
+++ b/benchmarks/distributed/bench_nvshmem_tile_reduce.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+"""
+Benchmark for NVSHMEM tile reduce operations.
+
+Usage:
+python benchmarks/distributed/bench_nvshmem_tile_reduce.py
+
+This benchmark measures the performance of tile reduce operations across different
+matrix sizes and tile configurations.
+"""
+
+import time
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+from torch.testing._internal.common_distributed import MultiProcContinuousTest
+from torch.testing._internal.common_utils import (
+    requires_cuda_p2p_access,
+    skip_but_pass_in_sandcastle_if,
+    skipIfRocm,
+)
+
+
+# Decorator
+def requires_nvshmem():
+    return skip_but_pass_in_sandcastle_if(
+        not symm_mem.is_nvshmem_available(),
+        "bench_nvshmem_tile_reduce requires NVSHMEM, skipping benchmark",
+    )
+
+
+# So that benchmarks are written in a device-agnostic way
+device_type = "cuda"
+device_module = torch.get_device_module(device_type)
+
+
+@requires_nvshmem()
+@requires_cuda_p2p_access()
+class NVSHMEMTileReduceBenchmark(MultiProcContinuousTest):
+    def _init_device(self) -> None:
+        # TODO: relax this (seems to hang without it)
+        device_module.set_device(self.device)
+        # Set NVSHMEM as SymmMem backend
+        symm_mem.set_backend("NVSHMEM")
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(device_type, self.rank)
+
+    def _benchmark_tile_reduce_single(
+        self,
+        full_size: int,
+        tile_size: int,
+        warmup_iters: int = 5,
+        bench_iters: int = 10,
+    ) -> dict:
+        """
+        Benchmark a single configuration of tile reduce.
+
+        Args:
+            full_size: Size of the full matrix (full_size x full_size)
+            tile_size: Size of the tile to reduce (tile_size x tile_size)
+            warmup_iters: Number of warmup iterations
+            bench_iters: Number of benchmark iterations
+
+        Returns:
+            Dictionary with benchmark results
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+
+        # Allocate full matrices
+        full_inp = symm_mem.empty(
+            full_size, full_size, dtype=dtype, device=self.device
+        ).fill_(self.rank)
+        full_out = symm_mem.empty(
+            full_size, full_size, dtype=dtype, device=self.device
+        ).fill_(0)
+
+        slice_ut = slice(0, tile_size)
+        inp_tile = full_inp[slice_ut, slice_ut]
+        out_tile = full_out[slice_ut, slice_ut]
+
+        root = 0
+
+        # Warmup iterations
+        for _ in range(warmup_iters):
+            torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
+        torch.cuda.synchronize(self.device)
+
+        # Benchmark iterations
+        times = []
+
+        dist.barrier()
+        torch.cuda.synchronize(self.device)
+        start_time = time.perf_counter()
+
+        for _ in range(bench_iters):
+            torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
+
+        torch.cuda.synchronize(self.device)
+        end_time = time.perf_counter()
+        times.append((end_time - start_time) / bench_iters)
+
+        # Calculate statistics
+        times = torch.tensor(times, dtype=torch.float64)
+        tile_elements = tile_size * tile_size
+        tile_bytes = (
+            tile_elements * dtype.itemsize
+            if hasattr(dtype, "itemsize")
+            else tile_elements * 4
+        )
+
+        results = {
+            "full_size": full_size,
+            "tile_size": tile_size,
+            "tile_elements": tile_elements,
+            "tile_bytes": tile_bytes,
+            "world_size": self.world_size,
+            "mean_time_ms": times.mean().item() * 1000,
+            "std_time_ms": times.std().item() * 1000,
+            "min_time_ms": times.min().item() * 1000,
+            "max_time_ms": times.max().item() * 1000,
+            "throughput_gb_s": tile_bytes / (times.mean().item() * 1e9),
+            "elements_per_sec": tile_elements / times.mean().item(),
+        }
+
+        return results
+
+    @skipIfRocm
+    def test_benchmark_tile_reduce_various_sizes(self) -> None:
+        """
+        Benchmark tile reduce across various tile sizes.
+ """ + # Test various matrix sizes + tile_sizes = [512, 1024, 2048, 4096, 8192, 16384] + full_size = tile_sizes[-1] + warmup_iters = 5 + bench_iters = 20 + + results = [] + + for tile_size in tile_sizes: + try: + result = self._benchmark_tile_reduce_single( + full_size, tile_size, warmup_iters, bench_iters + ) + results.append(result) + + if self.rank == 0: + print( + f"Matrix Size: {full_size}x{full_size}, Tile Size: {tile_size}x{tile_size}" + ) + print( + f" Mean Time: {result['mean_time_ms']:.3f} ± {result['std_time_ms']:.3f} ms" + ) + print(f" Throughput: {result['throughput_gb_s']:.2f} GB/s") + print(f" Bytes: {result['tile_bytes']:.0f}") + print() + + except Exception as e: + if self.rank == 0: + print(f"Failed to benchmark matrix size {full_size}: {e}") + + # Print summary + if self.rank == 0 and results: + print("=== BENCHMARK SUMMARY ===") + print( + f"{'Matrix Size':<12} {'Tile Size':<10} {'Time (ms)':<12} {'Throughput (GB/s)':<18} {'Bytes':<15}" + ) + print("-" * 70) + + for result in results: + print( + f"{result['full_size']}x{result['full_size']:<7} " + f"{result['tile_size']}x{result['tile_size']:<5} " + f"{result['mean_time_ms']:<12.3f} " + f"{result['throughput_gb_s']:<18.2f} " + f"{result['tile_bytes']:<15.0f}" + ) + + +if __name__ == "__main__": + # For standalone usage, you'd need to set up distributed environment + # For now, this is meant to be run via the PyTorch test framework + from torch.testing._internal.common_utils import run_tests + + run_tests() diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py index 28b39e081781..1cdcb483be27 100644 --- a/test/distributed/test_nvshmem.py +++ b/test/distributed/test_nvshmem.py @@ -701,5 +701,54 @@ class DispatchCombineInSubgroups(MultiProcContinuousTest): dispatch_then_combine(self.device, align=8, group=subgroup) +@instantiate_parametrized_tests +@requires_nvshmem() +@requires_cuda_p2p_access() +class NVSHMEMTileCommTest(MultiProcContinuousTest): + def _init_device(self) -> None: + # TODO: relieve this (seems to hang if without) + device_module.set_device(self.device) + # Set NVSHMEM as SymmMem backend + symm_mem.set_backend("NVSHMEM") + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) + + @skipIfRocm + @parametrize("tile_size", [32, 128, 512]) + @parametrize("dtype", [torch.float, torch.half, torch.bfloat16]) + def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None: + full_size = 1024 + assert tile_size <= full_size + + self._init_device() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + full_inp = symm_mem.empty( + full_size, full_size, dtype=dtype, device=self.device + ).fill_(self.rank) + full_out = symm_mem.empty( + full_size, full_size, dtype=dtype, device=self.device + ).fill_(0) + + slice_ut = slice(tile_size, 2 * tile_size) + inp_tile = full_inp[slice_ut, slice_ut] + out_tile = full_out[slice_ut, slice_ut] + + # Reduce the tile + root = 0 + torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name) + + # Check data + expected = torch.zeros_like(full_out) + expected_tile = expected[slice_ut, slice_ut] + if self.rank == root: + expected_tile.fill_(self.world_size * (self.world_size - 1) / 2) + + torch.testing.assert_close(full_out, expected) + + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 547f2242bf86..fe488ac11ed4 100644 --- 
+++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp
@@ -510,6 +510,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
   m.def(
       "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
+  m.def(
+      "tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
index 8ac0833c5284..c8c6d6648f59 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu
@@ -9,6 +9,7 @@
 #include 
 #include 
+#include 
 
 // Use torch's cub wrapper instead of CUDA's <cub/cub.cuh>, see #55292
 #include 
@@ -863,6 +864,128 @@ void all_to_all_vdev_2d_offset(
       0,
       stream);
 }
+
+/* Tiled Communication */
+
+using Shape2D = nvshmemx::shape;
+using Stride2D = nvshmemx::stride;
+
+template <typename T>
+__global__ void tile_reduce_kernel(
+    T* src_ptr, T* dst_ptr, Shape2D shape, Stride2D strides, int64_t root, nvshmem_team_t* teams) {
+#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
+  CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
+#else
+  int bid = blockIdx.x;
+  auto team = teams[bid];
+  CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID && " invalid team\n");
+
+  // Global tile shape
+  auto [rows, cols] = shape;
+  auto [stride0, stride1] = strides;
+
+  // Divide rows among CUDA blocks
+  auto rows_per_block = at::ceil_div(rows, (int64_t)gridDim.x);
+  auto block_start_row = rows_per_block * bid;
+  auto block_shape = nvshmemx::make_shape(std::min(rows_per_block, rows - block_start_row), cols);
+  auto block_layout = nvshmemx::make_layout(block_shape, strides);
+
+  // Start pointer of each block's sub-tile
+  auto block_src_ptr = src_ptr + stride0 * block_start_row;
+  auto block_dst_ptr = dst_ptr + stride0 * block_start_row;
+  auto block_src_tensor = nvshmemx::Tensor(block_src_ptr, block_layout);
+  auto block_dst_tensor = nvshmemx::Tensor(block_dst_ptr, block_layout);
+
+  // Leave these empty so that nvshmemx::tile_sum_reduce_block() skips the
+  // additional range checks
+  auto start_coord = nvshmemx::make_shape();
+  auto boundary = nvshmemx::make_shape();
+
+  // Use one-shot pull to reduce the tile
+  uint64_t flag = 0;
+  constexpr auto algo = nvshmemx::tile_coll_algo_t::NVLS_ONE_SHOT_PULL_NBI;
+  nvshmemx::tile_sum_reduce_block<algo>(
+      team, block_src_tensor, block_dst_tensor, start_coord, boundary, root, flag /* unused */);
+
+  // Wait for the operation to complete
+  nvshmemx::tile_collective_wait<algo>(team, flag /* unused */);
+#endif
+}
+
+#define AT_DISPATCH_CASE_CONVERT(enum_type, scalar_type, ...) \
+  case enum_type: {                                           \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);              \
+    using scalar_t = scalar_type;                             \
+    return __VA_ARGS__();                                     \
+  }
+
+#define AT_DISPATCH_NVSHMEM_FLOATS(scalar_type, name, ...)    \
+  AT_DISPATCH_SWITCH(                                         \
+      scalar_type, name,                                       \
+      AT_DISPATCH_CASE_CONVERT(at::kBFloat16, __nv_bfloat16, __VA_ARGS__); \
+      AT_DISPATCH_CASE_CONVERT(at::kHalf, __half, __VA_ARGS__); \
+      AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__));
+
+void tile_reduce(
+    at::Tensor& in_tile,
+    at::Tensor& out_tile,
+    int64_t root,
+    std::string group_name,
+    std::string reduce_op) {
+  /* Perform a tile reduce operation on the input tensor, with the root rank
+   * receiving the reduced tensor. */
+  TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
+  TORCH_CHECK(in_tile.dim() == 2 && out_tile.dim() == 2, "Only 2D tensors are supported");
+  TORCH_CHECK_EQ(in_tile.dtype(), out_tile.dtype());
+  TORCH_CHECK_EQ(in_tile.sizes(), out_tile.sizes());
+  TORCH_CHECK_EQ(in_tile.strides(), out_tile.strides());
+  TORCH_CHECK_EQ(in_tile.device(), out_tile.device());
+
+  auto device = in_tile.device();
+  c10::cuda::CUDAGuard guard(device);
+  auto hdl = c10d::symmetric_memory::rendezvous(in_tile, group_name);
+  c10d::symmetric_memory::rendezvous(out_tile, group_name);
+
+  // Ideally 16 bytes per thread
+  int nblocks = at::ceil_div(
+      in_tile.numel() * in_tile.element_size(),
+      (int64_t)THREADS_PER_BLOCK * 16);
+  nblocks = std::min(nblocks, 24);
+
+  // Need one team per block
+  auto& team_manager = TeamManager::get(device);
+  auto [teams, teams_dev] = team_manager.get_n_teams(
+      group_name, hdl->get_rank_to_global_rank(), nblocks);
+  TORCH_CHECK(
+      root < nvshmem_team_n_pes(teams[0]),
+      "root must be smaller than group size");
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // Prepare launch parameters
+  auto shape = nvshmemx::make_shape(in_tile.sizes()[0], in_tile.sizes()[1]);
+  auto stride = nvshmemx::make_stride(in_tile.strides()[0], in_tile.strides()[1]);
+  auto src_ptr = in_tile.const_data_ptr();
+  auto dst_ptr = out_tile.mutable_data_ptr();
+  void* args[] = {
+      &src_ptr,
+      &dst_ptr,
+      &shape,
+      &stride,
+      &root,
+      &teams_dev};
+
+  AT_DISPATCH_NVSHMEM_FLOATS(in_tile.scalar_type(), "tile_reduce", [&]() {
+    nvshmemx_collective_launch(
+        (const void*)tile_reduce_kernel<scalar_t>,
+        dim3(nblocks),
+        dim3(THREADS_PER_BLOCK),
+        args,
+        0,
+        stream);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
+}
+
 } // namespace c10d::nvshmem_extension
 
@@ -876,4 +999,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev);
   m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
   m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
+  m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
 }
diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
index bd8d778e862b..774246132409 100644
--- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
+++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh
@@ -58,4 +58,11 @@ void all_to_all_vdev_2d_offset(
     at::Tensor& out_splits_offsets,
     std::string group_name);
 
+void tile_reduce(
+    at::Tensor& in_tile,
+    at::Tensor& out_tile,
+    int64_t root,
+    std::string group_name,
+    std::string reduce_op = "sum");
+
 } // namespace c10d::nvshmem_extension
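
Usage sketch: a minimal standalone example of calling the new torch.ops.symm_mem.tile_reduce op from Python, mirroring the test above. This is illustrative only, not part of the diff; the script name and launch command are placeholders, and it assumes a torchrun launch with one GPU per rank on an NVSHMEM-capable build.

    # Launch with: torchrun --nproc-per-node=<num_gpus> tile_reduce_example.py
    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    symm_mem.set_backend("NVSHMEM")

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    # Symmetric buffers; only the top-left 512x512 tile is reduced into rank 0.
    full_inp = symm_mem.empty(1024, 1024, dtype=torch.float, device=device).fill_(rank)
    full_out = symm_mem.empty(1024, 1024, dtype=torch.float, device=device).fill_(0)
    tile = slice(0, 512)
    torch.ops.symm_mem.tile_reduce(full_inp[tile, tile], full_out[tile, tile], 0, group_name)
    torch.cuda.synchronize(device)

    if rank == 0:
        # Each rank contributed its rank id, so the tile holds sum(range(world_size)).
        print(full_out[0, 0].item())
    dist.destroy_process_group()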