mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SymmMem] Tiled reduce (#162243)
Added op: `tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)` For now supports only: - NVSHMEM backed symmetric tensor; - 2D tensor and tile; - torch.float. Testing on right-bottom quadrant: ``` rank 0: tensor([[0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0') PASSED ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243 Approved by: https://github.com/ngimel
This commit is contained in:
@ -701,5 +701,54 @@ class DispatchCombineInSubgroups(MultiProcContinuousTest):
|
||||
dispatch_then_combine(self.device, align=8, group=subgroup)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMTileCommTest(MultiProcContinuousTest):
    """Multi-process tests for tile-based collectives on NVSHMEM-backed
    symmetric memory tensors."""

    @property
    def device(self) -> torch.device:
        # One CUDA device per participating rank.
        return torch.device(device_type, self.rank)

    def _init_device(self) -> None:
        """Bind this process to its device and select NVSHMEM as backend."""
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None:
        """Reduce one tile of a symmetric 2D tensor onto the root rank and
        verify that only that tile received the sum over all ranks."""
        full_size = 1024
        assert tile_size <= full_size

        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        # Symmetric input holds this rank's id everywhere; output starts zeroed.
        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        # Operate on the second tile along each dimension.
        tile = slice(tile_size, 2 * tile_size)
        inp_tile = full_inp[tile, tile]
        out_tile = full_out[tile, tile]

        # Reduce the tile onto the root rank.
        root = 0
        torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)

        # On root, the tile should hold sum(0..world_size-1); everything
        # else (including the tile on non-root ranks) stays zero.
        expected = torch.zeros_like(full_out)
        expected_tile = expected[tile, tile]
        if self.rank == root:
            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)

        torch.testing.assert_close(full_out, expected)
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()
|
||||
|
Reference in New Issue
Block a user