pytorch/test/distributed/test_nvshmem.py
Ke Wen d444384003 [SymmMem] Tiled reduce (#162243)
Added op: `tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)`

For now supports only:
- NVSHMEM backed symmetric tensor;
- 2D tensor and tile;
- torch.float.

Testing on the right-bottom quadrant:
```
rank 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
PASSED
```
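
For reference, a minimal usage sketch (not part of the PR; it mirrors `test_tile_reduce` at the bottom of the file, and assumes a process group has already been initialized and NVSHMEM is available):

```
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes dist.init_process_group(...) has already run on each rank.
rank = dist.get_rank()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

symm_mem.set_backend("NVSHMEM")
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

# Symmetric 2D buffers; only a tile of each participates in the reduce.
full_inp = symm_mem.empty(1024, 1024, dtype=torch.float, device=device).fill_(rank)
full_out = symm_mem.empty(1024, 1024, dtype=torch.float, device=device).fill_(0)

tile = slice(512, 1024)  # e.g. the right-bottom quadrant
root = 0
torch.ops.symm_mem.tile_reduce(full_inp[tile, tile], full_out[tile, tile], root, group_name)
# On `root`, full_out[tile, tile] now holds the element-wise sum of that tile across ranks.
```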

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243
Approved by: https://github.com/ngimel
2025-10-08 02:03:04 +00:00


# Owner(s): ["oncall: distributed"]
# To run:
# python test/distributed/test_nvshmem.py
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import (
    MultiProcContinuousTest,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda_p2p_access,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


# Decorator
def requires_nvshmem():
    return skip_but_pass_in_sandcastle_if(
        not symm_mem.is_nvshmem_available(),
        "test_nvshmem requires NVSHMEM, skipping tests",
    )


# So that tests are written in device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relax this (seems to hang without it)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_alloc(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024

        def foo():
            inp = symm_mem.empty(numel, dtype=dtype, device=self.device)
            symm_mem.rendezvous(inp, group=group_name)

        foo()

        out = symm_mem.empty(numel, dtype=dtype, device=self.device)
        symm_mem.rendezvous(out, group=group_name)

    @skipIfRocm
    def test_alloc_without_device_context(self) -> None:
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024
        out = symm_mem.empty(numel, dtype=dtype, device=self.device)
        self.assertEqual(out.device, self.device)
        symm_mem.rendezvous(out, group=group_name)

    @skipIfRocm
    def test_mempool_tensor_factory(self) -> None:
        """
        Test the effectiveness of MemPool on tensor factory ops.
        """
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024
        src_rank = 0

        allocator = symm_mem.get_mempool_allocator(self.device)
        mempool = torch.cuda.MemPool(allocator)

        with torch.cuda.use_mem_pool(mempool):
            if self.rank == src_rank:
                tensor = torch.arange(numel, dtype=dtype, device=self.device)
            else:
                tensor = torch.zeros(numel, dtype=dtype, device=self.device)

        symm_mem.rendezvous(tensor, group=group_name)
        torch.ops.symm_mem.nvshmem_broadcast(tensor, src_rank, group_name)
        self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))

    @skipIfRocm
    def test_mempool_compute_ops(self) -> None:
        """
        Apply MemPool context to a compute op that creates input to collective.
        """
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        dim = 1024
        w = torch.ones(dim, dim, dtype=dtype, device=self.device)
        x0 = torch.ones(1, dim, dtype=dtype, device=self.device)

        allocator = symm_mem.get_mempool_allocator(self.device)
        mempool = torch.cuda.MemPool(allocator)

        with torch.cuda.use_mem_pool(mempool):
            x = x0 + self.rank
            y = torch.mm(x, w)

        # y should be a symm tensor
        torch.ops.symm_mem.nvshmem_broadcast(y, 0, group_name)
        expected = torch.mm(x0, w)
        self.assertEqual(y, expected)

    @skipIfRocm
    def test_handle_offset(self) -> None:
        """
        Test if handle offset is correctly set.
        """
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024

        allocator = symm_mem.get_mempool_allocator(self.device)
        mempool = torch.cuda.MemPool(allocator)

        with torch.cuda.use_mem_pool(mempool):
            x0 = torch.empty(numel, dtype=dtype, device=self.device)
            x1 = torch.empty_like(x0)

        hdl0 = symm_mem.rendezvous(x0, group=group_name)
        hdl1 = symm_mem.rendezvous(x1, group=group_name)
        self.assertEqual(hdl0.offset, 0)
        self.assertEqual(hdl1.offset, x0.untyped_storage().nbytes())

    def test_get_remote_tensor(self) -> None:
        """
        Get a remote tensor and use regular aten ops to write to it.
        """
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024

        allocator = symm_mem.get_mempool_allocator(self.device)
        mempool = torch.cuda.MemPool(allocator)

        with torch.cuda.use_mem_pool(mempool):
            # src data stores my rank
            x = torch.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
            y = torch.empty_like(x)

        hdl_y = symm_mem.rendezvous(y, group=group_name)
        peer = (self.rank + 1) % self.world_size  # Shifting pattern
        y_remote = hdl_y.get_remote_tensor(peer, y.size(), y.dtype)
        y_remote.copy_(x)
        dist.barrier()

        # Expecting data from -1 rank
        expected = torch.empty(numel, dtype=dtype, device=self.device).fill_(
            (self.rank - 1) % self.world_size
        )
        self.assertEqual(y, expected)

    @skipIfRocm
    def test_nvshmem_put(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024
        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        hdl = symm_mem.rendezvous(tensor, group=group_name)
        signal_pad = hdl.get_signal_pad(self.rank)
        signal_val = 5

        if self.rank == 0:
            # Rank 0 puts its buffer (filled with 0) to rank 1 and sets the signal
            torch.ops.symm_mem.nvshmem_put_with_signal(
                tensor, signal_pad, signal_val, 1
            )
        elif self.rank == 1:
            # Rank 1 waits for the signal, then verifies it received rank 0's data
            torch.ops.symm_mem.nvshmem_wait_for_signal(signal_pad, signal_val, 0)
            torch.testing.assert_close(
                tensor, torch.zeros(numel, dtype=dtype, device=self.device)
            )

    @skipIfRocm
    def test_nvshmem_get(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024
        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        symm_mem.rendezvous(tensor, group=group_name)

        if self.rank == 0:
            # Rank 0 pulls rank 1's buffer (filled with 1) into its own tensor
            torch.ops.symm_mem.nvshmem_get(tensor, 1)
            # TODO: remove after we have wait_signal
            dist.barrier()
            torch.testing.assert_close(
                tensor, torch.ones(numel, dtype=dtype, device=self.device)
            )
        else:
            # handle.wait_signal(src_rank=0)
            # TODO: remove after we have wait_signal
            dist.barrier()


@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMAll2AllTest(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relax this (seems to hang without it)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel_per_peer = 10
        numel = self.world_size * numel_per_peer
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        symm_mem.rendezvous(inp, group=group_name)
        symm_mem.rendezvous(out, group=group_name)

        torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)

        expected = torch.cat(
            [
                torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
                for i in range(self.world_size)
            ]
        )
        torch.testing.assert_close(out, expected)

    @skipIfRocm
    def test_all_to_all_vdev(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        dtype = torch.float

        # Number of elements for a peer is random between [0, k)
        k = 10
        inp_splits = torch.randint(k, (self.world_size,), device=self.device)
        inp_numel = inp_splits.sum().item()
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
        out_numel = out_splits.sum().item()

        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        max_inp_numel = k * self.world_size
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_(
            torch.randn(max_inp_numel, dtype=dtype, device=self.device)
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        in_splits = symm_mem.empty(
            self.world_size, dtype=torch.int64, device=self.device
        )
        out_splits_offsets = symm_mem.empty(
            (2, self.world_size), dtype=torch.int64, device=self.device
        )
        # Copy the actual input splits into the symmetric buffer
        in_splits.copy_(inp_splits)

        # Sync all ranks to ensure remote tensors are allocated
        dist.barrier()

        torch.ops.symm_mem.all_to_all_vdev(
            inp, out, in_splits, out_splits_offsets, group_name
        )

        # Check input splits -- should not change
        torch.testing.assert_close(in_splits, inp_splits)

        # Check output splits (row 0 of `out_splits_offsets`)
        torch.testing.assert_close(out_splits_offsets[0], out_splits)

        # Check output offsets (row 1 of `out_splits_offsets`)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # Output offsets from `all_to_all_vdev` form an exclusive scan
        self.assertEqual(out_splits_offsets[1][0], 0)
        torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1])

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        dist.all_to_all_single(
            expected, inp[:inp_numel], out_splits.tolist(), inp_splits.tolist()
        )
        torch.testing.assert_close(out[:out_numel], expected)

    @skipIfRocm
    @parametrize("align", [1, 8, 16])  # `major_align` of output
    def test_all_to_all_vdev_2d(self, align: int) -> None:
        torch.manual_seed(42 + self.rank)
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        dtype = torch.float

        # Number of experts per rank
        ne = 8
        nsplits = ne * self.world_size

        # Number of elements for an expert is random between [0, k)
        k = 10
        inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=self.device)
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
        # We do a .t() here because there is a rank-major to expert-major shuffle
        out_splits_t = out_splits.reshape(self.world_size, ne).t()

        # Actual number of input elements
        inp_numel = inp_splits.sum().item()
        # Actual number of output elements
        out_numel = out_splits.sum().item()
        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        max_inp_numel = k * nsplits
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_(
            torch.randn(max_inp_numel, dtype=dtype, device=self.device)
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        in_splits = symm_mem.empty(
            nsplits, dtype=torch.int64, device=self.device
        ).copy_(inp_splits)
        # 2 rows: output splits, output offsets
        # Initializing all values to -1 to check if they are updated
        out_splits_offsets = symm_mem.empty(
            (2, nsplits), dtype=torch.int64, device=self.device
        ).fill_(-1)

        # Sync all ranks to ensure remote tensors are allocated
        dist.barrier()

        torch.ops.symm_mem.all_to_all_vdev_2d(
            inp, out, in_splits, out_splits_offsets, group_name, major_align=align
        )
        received_out_splits = out_splits_offsets[0]
        received_out_offsets = out_splits_offsets[1]

        # Check input splits -- should not change
        torch.testing.assert_close(in_splits, inp_splits)

        # Check output splits (row 0 of `out_splits_offsets`)
        torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))

        # Check output offsets (row 1 of `out_splits_offsets`)
        out_split_list = out_splits_t.tolist()
        for i in range(ne):
            expert_sum = 0
            for j in range(self.world_size):
                expert_sum += out_split_list[i][j]
            # Align up expert_sum
            expert_sum_aligned = (expert_sum + align - 1) // align * align
            # If 0, make it at least `align` (because cutlass currently does not support empty bins)
            expert_sum_aligned = max(expert_sum_aligned, align)
            # The last element absorbs the padding
            out_split_list[i][-1] += expert_sum_aligned - expert_sum

        out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
        out_offsets = torch.cumsum(out_splits_padded, dim=0)  # inclusive scan
        # Make it an exclusive scan because that's what `all_to_all_vdev_2d` returns
        out_offsets = torch.cat(
            [torch.zeros(1, device=self.device), out_offsets[:-1]]
        ).to(torch.int64)
        torch.testing.assert_close(received_out_offsets, out_offsets)

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
        out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
        dist.all_to_all_single(
            expected,
            inp[:inp_numel],
            out_splits_rank.tolist(),
            inp_splits_rank.tolist(),
        )
        # We still need to shuffle `expected` from rank-major to expert-major order
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        result_list = []
        for j in range(ne):
            for i in range(self.world_size):
                chunk_id = i * ne + j
                offset = out_offsets[chunk_id]
                chunk = expected[offset - out_splits[chunk_id] : offset]
                result_list.append(chunk)

        # Do a chunk-wise comparison
        for c, chunk in enumerate(result_list):
            start = received_out_offsets[c].item()
            split = received_out_splits[c].item()
            received_chunk = out[start : start + split]
            torch.testing.assert_close(received_chunk, chunk)

    @skipIfRocm
    def test_all_to_all_vdev_2d_offset(self) -> None:
        torch.manual_seed(42 + self.rank)
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        dtype = torch.float

        # Number of experts per rank
        ne = 8
        nsplits = ne * self.world_size

        # Number of elements for an expert is random between [0, k)
        k = 10
        inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=self.device)
        # Each input split is up-aligned to k, so the offsets are [0, k, 2k, 3k, ...]
        inp_offsets = torch.arange(
            0, k * nsplits, k, dtype=torch.int64, device=self.device
        )

        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        # Remember that each input split is up-aligned to k
        max_inp_numel = k * nsplits
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_(
            torch.randn(max_inp_numel, dtype=dtype, device=self.device)
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        # 2 rows: input splits, input offsets
        in_splits_offsets = symm_mem.empty(
            (2, nsplits), dtype=torch.int64, device=self.device
        )
        # 2 rows: output splits, output offsets
        # Initializing all values to -1 to check if they are updated
        out_splits_offsets = symm_mem.empty(
            (2, nsplits), dtype=torch.int64, device=self.device
        ).fill_(-1)
        # Row 0 is input splits
        in_splits_offsets[0].copy_(inp_splits)
        # Row 1 is input offsets
        in_splits_offsets[1].copy_(inp_offsets)

        # Sync all ranks to ensure remote tensors are allocated
        dist.barrier()

        torch.ops.symm_mem.all_to_all_vdev_2d_offset(
            inp, out, in_splits_offsets, out_splits_offsets, group_name
        )
        received_out_splits = out_splits_offsets[0]
        received_out_offsets = out_splits_offsets[1]

        # Check input splits and offsets -- should not change
        torch.testing.assert_close(in_splits_offsets[0], inp_splits)
        torch.testing.assert_close(in_splits_offsets[1], inp_offsets)

        # Check output splits (row 0 of `out_splits_offsets`)
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        # First need to transpose the input splits
        inp_splits_t = inp_splits.reshape(ne, self.world_size).t().contiguous()
        dist.all_to_all_single(out_splits, inp_splits_t)
        torch.testing.assert_close(received_out_splits, out_splits)

        # Check output offsets (row 1 of `out_splits_offsets`)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # Output offsets from `all_to_all_vdev_2d_offset` form an exclusive scan
        self.assertEqual(received_out_offsets[0], 0)
        torch.testing.assert_close(received_out_offsets[1:], out_offsets[:-1])

        # Check data
        # Let's "squeeze" the padding out of the input data first
        inp_chunks = []  # (ne, nranks)
        for i in range(ne):
            inp_chunks_e = []  # (nranks,)
            for j in range(self.world_size):
                chunk_id = i * self.world_size + j
                offset = in_splits_offsets[1][chunk_id]
                chunk = inp[offset : offset + inp_splits[chunk_id]]
                inp_chunks_e.append(chunk)
            inp_chunks.append(inp_chunks_e)

        # Transpose the 2D input chunks
        inp_chunks_t = list(zip(*inp_chunks))
        # Now it is (nranks, ne); concatenate the expert chunks per rank
        inp_chunks_t = [torch.cat(row) for row in inp_chunks_t]

        # Create empty output tensors -- each tensor is data to be received from a peer
        out_splits = out_splits.reshape(self.world_size, ne)
        # Sum the split sizes of all experts, per peer
        receive_size_per_peer = out_splits.sum(1)
        out_chunks = []  # (nranks,)
        for i in range(self.world_size):
            out_chunks.append(
                torch.empty(
                    receive_size_per_peer[i].item(), dtype=dtype, device=self.device
                )
            )

        # All-to-all
        dist.all_to_all(out_chunks, inp_chunks_t)

        # Concatenate the output chunks received from all peers
        out_expected = torch.cat(out_chunks)
        # Actual number of output elements
        out_numel = out_splits.sum().item()
        self.assertEqual(out_expected.shape[0], out_numel)

        # Check data
        torch.testing.assert_close(out_expected, out[:out_numel])


# Helper function used by multiple tests
def dispatch_then_combine(device, align: int, group) -> None:
    """
    Shuffle the tokens, then combine them, and check that the combined data is
    exactly the same as the original input data.
    """
    group_name = group.group_name
    symm_mem.enable_symm_mem_for_group(group_name)
    dtype = torch.float

    # Number of experts per rank
    ne = 8
    nsplits = ne * group.size()

    # Number of elements for an expert is random between [0, k)
    k = 10
    inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device)
    # Actual number of input elements
    inp_numel = inp_splits.sum().item()
    # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
    max_inp_numel = k * nsplits
    # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
    overflow_factor = group.size()  # worst case: one rank receives all data
    max_out_numel = max_inp_numel * overflow_factor

    # Buffers for shuffle
    inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=device).copy_(
        torch.randn(max_inp_numel, dtype=dtype, device=device)
    )
    out = symm_mem.empty(max_out_numel, dtype=dtype, device=device).fill_(-1)
    in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device).copy_(
        inp_splits
    )
    # 2 rows: output splits, output offsets
    # Initializing all values to -1 to check if they are updated
    out_splits_offsets = symm_mem.empty(
        (2, nsplits), dtype=torch.int64, device=device
    ).fill_(-1)

    # Buffers for combine
    combine_out = symm_mem.empty(max_out_numel, dtype=dtype, device=device).fill_(-1)
    # 2 rows: output splits, output offsets
    # Initializing all values to -1 to check if they are updated
    combine_out_splits_offsets = symm_mem.empty(
        (2, nsplits), dtype=torch.int64, device=device
    ).fill_(-1)

    # Wait for all ranks to finish tensor allocation before accessing them
    torch.cuda.synchronize(device)
    dist.barrier(group=group)

    # Shuffle the tokens
    torch.ops.symm_mem.all_to_all_vdev_2d(
        inp, out, in_splits, out_splits_offsets, group_name, major_align=align
    )

    # Combine the tokens
    # `out_splits_offsets` from shuffle is exactly the `input_splits_offsets` for combine
    # `out` data from shuffle is exactly the `input` data for combine
    torch.ops.symm_mem.all_to_all_vdev_2d_offset(
        out, combine_out, out_splits_offsets, combine_out_splits_offsets, group_name
    )

    # Assert the combined data is exactly the same as the original input data
    torch.testing.assert_close(combine_out[:inp_numel], inp[:inp_numel])

    # Assert the combined out splits are exactly the same as the original input splits
    torch.testing.assert_close(combine_out_splits_offsets[0], inp_splits)

    # Assert the combined out offsets are exactly the same as the original input offsets
    inp_offsets = torch.cumsum(inp_splits, dim=0)  # inclusive scan
    # Make it an exclusive scan because that's what `all_to_all_vdev_2d_offset` returns
    inp_offsets = torch.cat([torch.zeros(1, device=device), inp_offsets[:-1]]).to(
        torch.int64
    )
    torch.testing.assert_close(combine_out_splits_offsets[1], inp_offsets)

    # Wait for all ranks to finish accessing tensors before freeing them
    dist.barrier(group=group)
    torch.cuda.synchronize(device)


@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class DispatchCombineTest(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relax this (seems to hang without it)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    @parametrize("align", [1, 8, 16])  # `major_align` of output
    def test_dispatch_combine(self, align: int) -> None:
        """
        Test dispatch-and-combine over the World group.
        """
        torch.manual_seed(42 + self.rank)
        self._init_device()
        dispatch_then_combine(self.device, align, dist.group.WORLD)


@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class DispatchCombineInSubgroups(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relax this (seems to hang without it)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    # TODO: FIXIT. Currently, `MultiProcContinuousTest` treats the skip code as a
    # failure
    @skip_if_lt_x_gpu(4)
    def test_dispatch_combine_subgroup(self) -> None:
        """
        Test dispatch-and-combine over concurrent subgroups.
        """
        torch.manual_seed(42 + self.rank)
        self._init_device()
        symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)

        # Test on two concurrent subgroups
        ngroups = 2
        subgroup_size = self.world_size // ngroups
        dm = init_device_mesh(
            device_type, (ngroups, subgroup_size), mesh_dim_names=("dp", "ep")
        )
        subgroup = dm.get_group("ep")
        dispatch_then_combine(self.device, align=8, group=subgroup)


@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMTileCommTest(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relax this (seems to hang without it)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None:
        full_size = 1024
        assert tile_size <= full_size
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        slice_ut = slice(tile_size, 2 * tile_size)
        inp_tile = full_inp[slice_ut, slice_ut]
        out_tile = full_out[slice_ut, slice_ut]

        # Reduce the tile
        root = 0
        torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)

        # Check data
        expected = torch.zeros_like(full_out)
        expected_tile = expected[slice_ut, slice_ut]
        if self.rank == root:
            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)
        torch.testing.assert_close(full_out, expected)


if __name__ == "__main__":
    run_tests()