mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[SymmMem] Tiled reduce (#162243)
Added op: `tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)` For now supports only: - NVSHMEM backed symmetric tensor; - 2D tensor and tile; - torch.float. Testing on right-bottom quandrant: ``` rank 0: tensor([[0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0') PASSED ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243 Approved by: https://github.com/ngimel
This commit is contained in:
191
benchmarks/distributed/bench_nvshmem_tile_reduce.py
Normal file
191
benchmarks/distributed/bench_nvshmem_tile_reduce.py
Normal file
@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Benchmark for NVSHMEM tile reduce operations.
|
||||
|
||||
Usage:
|
||||
python benchmarks/distributed/bench_nvshmem_tile_reduce.py
|
||||
|
||||
This benchmark measures the performance of tile reduce operations across different
|
||||
matrix sizes and tile configurations.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
from torch.testing._internal.common_distributed import MultiProcContinuousTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
requires_cuda_p2p_access,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
skipIfRocm,
|
||||
)
|
||||
|
||||
|
||||
# Decorator
|
||||
def requires_nvshmem():
|
||||
return skip_but_pass_in_sandcastle_if(
|
||||
not symm_mem.is_nvshmem_available(),
|
||||
"bench_nvshmem_tile_reduce requires NVSHMEM, skipping benchmark",
|
||||
)
|
||||
|
||||
|
||||
# So that benchmarks are written in device-agnostic way
|
||||
device_type = "cuda"
|
||||
device_module = torch.get_device_module(device_type)
|
||||
|
||||
|
||||
@requires_nvshmem()
|
||||
@requires_cuda_p2p_access()
|
||||
class NVSHMEMTileReduceBenchmark(MultiProcContinuousTest):
|
||||
def _init_device(self) -> None:
|
||||
# TODO: relieve this (seems to hang if without)
|
||||
device_module.set_device(self.device)
|
||||
# Set NVSHMEM as SymmMem backend
|
||||
symm_mem.set_backend("NVSHMEM")
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
def _benchmark_tile_reduce_single(
|
||||
self,
|
||||
full_size: int,
|
||||
tile_size: int,
|
||||
warmup_iters: int = 5,
|
||||
bench_iters: int = 10,
|
||||
) -> dict:
|
||||
"""
|
||||
Benchmark a single configuration of tile reduce.
|
||||
|
||||
Args:
|
||||
full_size: Size of the full matrix (full_size x full_size)
|
||||
warmup_iters: Number of warmup iterations
|
||||
bench_iters: Number of benchmark iterations
|
||||
|
||||
Returns:
|
||||
Dictionary with benchmark results
|
||||
"""
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
dtype = torch.float
|
||||
|
||||
# Allocate full matrices
|
||||
full_inp = symm_mem.empty(
|
||||
full_size, full_size, dtype=dtype, device=self.device
|
||||
).fill_(self.rank)
|
||||
full_out = symm_mem.empty(
|
||||
full_size, full_size, dtype=dtype, device=self.device
|
||||
).fill_(0)
|
||||
|
||||
slice_ut = slice(0, tile_size)
|
||||
inp_tile = full_inp[slice_ut, slice_ut]
|
||||
out_tile = full_out[slice_ut, slice_ut]
|
||||
|
||||
root = 0
|
||||
|
||||
# Warmup iterations
|
||||
for _ in range(warmup_iters):
|
||||
torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
|
||||
torch.cuda.synchronize(self.device)
|
||||
|
||||
# Benchmark iterations
|
||||
times = []
|
||||
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize(self.device)
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(bench_iters):
|
||||
torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
|
||||
|
||||
torch.cuda.synchronize(self.device)
|
||||
end_time = time.perf_counter()
|
||||
times.append((end_time - start_time) / bench_iters)
|
||||
|
||||
# Calculate statistics
|
||||
times = torch.tensor(times, dtype=torch.float64)
|
||||
tile_elements = tile_size * tile_size
|
||||
tile_bytes = (
|
||||
tile_elements * dtype.itemsize
|
||||
if hasattr(dtype, "itemsize")
|
||||
else tile_elements * 4
|
||||
)
|
||||
|
||||
results = {
|
||||
"full_size": full_size,
|
||||
"tile_size": tile_size,
|
||||
"tile_elements": tile_elements,
|
||||
"tile_bytes": tile_bytes,
|
||||
"world_size": self.world_size,
|
||||
"mean_time_ms": times.mean().item() * 1000,
|
||||
"std_time_ms": times.std().item() * 1000,
|
||||
"min_time_ms": times.min().item() * 1000,
|
||||
"max_time_ms": times.max().item() * 1000,
|
||||
"throughput_gb_s": tile_bytes / (times.mean().item() * 1e9),
|
||||
"elements_per_sec": tile_elements / times.mean().item(),
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
@skipIfRocm
|
||||
def test_benchmark_tile_reduce_various_sizes(self) -> None:
|
||||
"""
|
||||
Benchmark tile reduce across various matrix sizes.
|
||||
"""
|
||||
# Test various matrix sizes
|
||||
tile_sizes = [512, 1024, 2048, 4096, 8192, 16384]
|
||||
full_size = tile_sizes[-1]
|
||||
warmup_iters = 5
|
||||
bench_iters = 20
|
||||
|
||||
results = []
|
||||
|
||||
for tile_size in tile_sizes:
|
||||
try:
|
||||
result = self._benchmark_tile_reduce_single(
|
||||
full_size, tile_size, warmup_iters, bench_iters
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
if self.rank == 0:
|
||||
print(
|
||||
f"Matrix Size: {full_size}x{full_size}, Tile Size: {tile_size}x{tile_size}"
|
||||
)
|
||||
print(
|
||||
f" Mean Time: {result['mean_time_ms']:.3f} ± {result['std_time_ms']:.3f} ms"
|
||||
)
|
||||
print(f" Throughput: {result['throughput_gb_s']:.2f} GB/s")
|
||||
print(f" Bytes: {result['tile_bytes']:.0f}")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
if self.rank == 0:
|
||||
print(f"Failed to benchmark matrix size {full_size}: {e}")
|
||||
|
||||
# Print summary
|
||||
if self.rank == 0 and results:
|
||||
print("=== BENCHMARK SUMMARY ===")
|
||||
print(
|
||||
f"{'Matrix Size':<12} {'Tile Size':<10} {'Time (ms)':<12} {'Throughput (GB/s)':<18} {'Bytes':<15}"
|
||||
)
|
||||
print("-" * 70)
|
||||
|
||||
for result in results:
|
||||
print(
|
||||
f"{result['full_size']}x{result['full_size']:<7} "
|
||||
f"{result['tile_size']}x{result['tile_size']:<5} "
|
||||
f"{result['mean_time_ms']:<12.3f} "
|
||||
f"{result['throughput_gb_s']:<18.2f} "
|
||||
f"{result['tile_bytes']:<15.0f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# For standalone usage, you'd need to set up distributed environment
|
||||
# For now, this is meant to be run via the PyTorch test framework
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
run_tests()
|
@ -701,5 +701,54 @@ class DispatchCombineInSubgroups(MultiProcContinuousTest):
|
||||
dispatch_then_combine(self.device, align=8, group=subgroup)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_nvshmem()
|
||||
@requires_cuda_p2p_access()
|
||||
class NVSHMEMTileCommTest(MultiProcContinuousTest):
|
||||
def _init_device(self) -> None:
|
||||
# TODO: relieve this (seems to hang if without)
|
||||
device_module.set_device(self.device)
|
||||
# Set NVSHMEM as SymmMem backend
|
||||
symm_mem.set_backend("NVSHMEM")
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
@skipIfRocm
|
||||
@parametrize("tile_size", [32, 128, 512])
|
||||
@parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
|
||||
def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None:
|
||||
full_size = 1024
|
||||
assert tile_size <= full_size
|
||||
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
full_inp = symm_mem.empty(
|
||||
full_size, full_size, dtype=dtype, device=self.device
|
||||
).fill_(self.rank)
|
||||
full_out = symm_mem.empty(
|
||||
full_size, full_size, dtype=dtype, device=self.device
|
||||
).fill_(0)
|
||||
|
||||
slice_ut = slice(tile_size, 2 * tile_size)
|
||||
inp_tile = full_inp[slice_ut, slice_ut]
|
||||
out_tile = full_out[slice_ut, slice_ut]
|
||||
|
||||
# Reduce the tile
|
||||
root = 0
|
||||
torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
|
||||
|
||||
# Check data
|
||||
expected = torch.zeros_like(full_out)
|
||||
expected_tile = expected[slice_ut, slice_ut]
|
||||
if self.rank == root:
|
||||
expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)
|
||||
|
||||
torch.testing.assert_close(full_out, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -510,6 +510,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
|
||||
"all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
|
||||
m.def(
|
||||
"all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
|
||||
m.def(
|
||||
"tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
|
||||
|
||||
#include <ATen/ceil_div.h>
|
||||
// Use torch's cub wrapper instead of CUDA's <cub/cub.cuh>, see #55292
|
||||
#include <ATen/cuda/cub.cuh>
|
||||
|
||||
@ -863,6 +864,128 @@ void all_to_all_vdev_2d_offset(
|
||||
0,
|
||||
stream);
|
||||
}
|
||||
|
||||
/* Tiled Communication */
|
||||
|
||||
using Shape2D = nvshmemx::shape<int64_t, int64_t>;
|
||||
using Stride2D = nvshmemx::stride<int64_t, int64_t>;
|
||||
|
||||
template <typename T>
|
||||
__global__ void tile_reduce_kernel(
|
||||
T* src_ptr, T* dst_ptr, Shape2D shape, Stride2D strides, int64_t root, nvshmem_team_t* teams) {
|
||||
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
|
||||
CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
|
||||
#else
|
||||
int bid = blockIdx.x;
|
||||
auto team = teams[bid];
|
||||
CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID && " invalid team\n");
|
||||
|
||||
// Global tile shape
|
||||
auto [rows, cols] = shape;
|
||||
auto [stride0, stride1] = strides;
|
||||
|
||||
// Divide rows among CUDA blocks
|
||||
auto rows_per_block = at::ceil_div(rows, (int64_t)gridDim.x);
|
||||
auto block_start_row = rows_per_block * bid;
|
||||
auto block_shape = nvshmemx::make_shape(std::min(rows_per_block, rows - block_start_row), cols);
|
||||
auto block_layout = nvshmemx::make_layout(block_shape, strides);
|
||||
|
||||
// Start pointer of each block's sub-tile
|
||||
auto block_src_ptr = src_ptr + stride0 * block_start_row;
|
||||
auto block_dst_ptr = dst_ptr + stride0 * block_start_row;
|
||||
auto block_src_tensor = nvshmemx::Tensor(block_src_ptr, block_layout);
|
||||
auto block_dst_tensor = nvshmemx::Tensor(block_dst_ptr, block_layout);
|
||||
|
||||
// Making these empty to avoid nvshmemx::tile_sum_reduce_block() from doing
|
||||
// additional range checks
|
||||
auto start_coord = nvshmemx::make_shape();
|
||||
auto boundary = nvshmemx::make_shape();
|
||||
|
||||
// Use one-shot pull to reduce the tile
|
||||
uint64_t flag = 0;
|
||||
constexpr auto algo = nvshmemx::tile_coll_algo_t::NVLS_ONE_SHOT_PULL_NBI;
|
||||
nvshmemx::tile_sum_reduce_block<decltype(block_src_tensor), decltype(block_dst_tensor), decltype(boundary), algo>(
|
||||
team, block_src_tensor, block_dst_tensor, start_coord, boundary, root, flag /* unused */);
|
||||
|
||||
// Wait for the operation to complete
|
||||
nvshmemx::tile_collective_wait<algo>(team, flag /* unused */);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define AT_DISPATCH_CASE_CONVERT(enum_type, scalar_type, ...) \
|
||||
case enum_type: { \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
||||
using scalar_t = scalar_type; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define AT_DISPATCH_NVSHMEM_FLOATS(scalar_type, name, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
scalar_type, name, \
|
||||
AT_DISPATCH_CASE_CONVERT(at::kBFloat16, __nv_bfloat16, __VA_ARGS__); \
|
||||
AT_DISPATCH_CASE_CONVERT(at::kHalf, __half, __VA_ARGS__); \
|
||||
AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__));
|
||||
|
||||
void tile_reduce(
|
||||
at::Tensor& in_tile,
|
||||
at::Tensor& out_tile,
|
||||
int64_t root,
|
||||
std::string group_name,
|
||||
std::string reduce_op) {
|
||||
/* Perform a tile reduce operation on the input tensor, with the root rank
|
||||
* receiving the reduced tensor. */
|
||||
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
|
||||
TORCH_CHECK(in_tile.dim() == 2 && out_tile.dim() == 2, "Only 2D tensors are supported");
|
||||
TORCH_CHECK_EQ(in_tile.dtype(), out_tile.dtype());
|
||||
TORCH_CHECK_EQ(in_tile.sizes(), out_tile.sizes());
|
||||
TORCH_CHECK_EQ(in_tile.strides(), out_tile.strides());
|
||||
TORCH_CHECK_EQ(in_tile.device(), out_tile.device());
|
||||
|
||||
auto device = in_tile.device();
|
||||
c10::cuda::CUDAGuard guard(device);
|
||||
auto hdl = c10d::symmetric_memory::rendezvous(in_tile, group_name);
|
||||
c10d::symmetric_memory::rendezvous(out_tile, group_name);
|
||||
|
||||
// Ideally 16 bytes per thread
|
||||
int nblocks = at::ceil_div(
|
||||
in_tile.numel() * in_tile.element_size(),
|
||||
(int64_t)THREADS_PER_BLOCK * 16);
|
||||
nblocks = std::min(nblocks, 24);
|
||||
|
||||
// Need one team per block
|
||||
auto& team_manager = TeamManager::get(device);
|
||||
auto [teams, teams_dev] = team_manager.get_n_teams(
|
||||
group_name, hdl->get_rank_to_global_rank(), nblocks);
|
||||
TORCH_CHECK(
|
||||
root < nvshmem_team_n_pes(teams[0]),
|
||||
"root must be smaller than group size");
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Prepare launch parameters
|
||||
auto shape = nvshmemx::make_shape(in_tile.sizes()[0], in_tile.sizes()[1]);
|
||||
auto stride = nvshmemx::make_stride(in_tile.strides()[0], in_tile.strides()[1]);
|
||||
auto src_ptr = in_tile.const_data_ptr();
|
||||
auto dst_ptr = out_tile.mutable_data_ptr();
|
||||
void* args[] = {
|
||||
&src_ptr,
|
||||
&dst_ptr,
|
||||
&shape,
|
||||
&stride,
|
||||
&root,
|
||||
&teams_dev};
|
||||
|
||||
AT_DISPATCH_NVSHMEM_FLOATS(in_tile.scalar_type(), "tile_reduce", [&]() {
|
||||
nvshmemx_collective_launch(
|
||||
(const void*)tile_reduce_kernel<scalar_t>,
|
||||
dim3(nblocks),
|
||||
dim3(THREADS_PER_BLOCK),
|
||||
args,
|
||||
0,
|
||||
stream);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace c10d::nvshmem_extension
|
||||
|
||||
|
||||
@ -876,4 +999,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
|
||||
m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev);
|
||||
m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
|
||||
m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
|
||||
m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
|
||||
}
|
||||
|
@ -58,4 +58,11 @@ void all_to_all_vdev_2d_offset(
|
||||
at::Tensor& out_splits_offsets,
|
||||
std::string group_name);
|
||||
|
||||
void tile_reduce(
|
||||
at::Tensor& in_tile,
|
||||
at::Tensor& out_tile,
|
||||
int64_t root,
|
||||
std::string group_name,
|
||||
std::string reduce_op = "sum");
|
||||
|
||||
} // namespace c10d::nvshmem_extension
|
||||
|
Reference in New Issue
Block a user