[SymmMem] Tiled reduce (#162243)

Added op: `tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)`

For now supports only:
- NVSHMEM backed symmetric tensor;
- 2D tensor and tile;
- torch.float.

Testing on the bottom-right quadrant:
```
rank 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
PASSED
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243
Approved by: https://github.com/ngimel
This commit is contained in:
Ke Wen
2025-10-07 15:20:43 -07:00
committed by PyTorch MergeBot
parent 3040a5d294
commit d444384003
5 changed files with 373 additions and 0 deletions

View File

@ -701,5 +701,54 @@ class DispatchCombineInSubgroups(MultiProcContinuousTest):
dispatch_then_combine(self.device, align=8, group=subgroup)
@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMTileCommTest(MultiProcContinuousTest):
    """Tests for tile-wise collectives on NVSHMEM-backed symmetric memory."""

    @property
    def device(self) -> torch.device:
        # One CUDA-like device per rank.
        return torch.device(device_type, self.rank)

    def _init_device(self) -> None:
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None:
        """Reduce one tile of a symmetric 2D tensor onto the root rank and
        verify that only the root's output tile holds the sum of all ranks."""
        full_size = 1024
        assert tile_size <= full_size
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        # Symmetric buffers: input carries this rank's id, output starts zeroed.
        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        # Operate on the second tile along each dimension.
        tile = slice(tile_size, 2 * tile_size)
        root = 0
        torch.ops.symm_mem.tile_reduce(
            full_inp[tile, tile], full_out[tile, tile], root, group_name
        )

        # Only root receives the reduction: sum of rank ids 0..W-1 = W*(W-1)/2.
        expected = torch.zeros_like(full_out)
        if self.rank == root:
            expected[tile, tile].fill_(self.world_size * (self.world_size - 1) / 2)
        torch.testing.assert_close(full_out, expected)
# Standard PyTorch test entry point: dispatch to the shared test runner.
if __name__ == "__main__":
    run_tests()