pytorch/test/distributed/test_nvshmem_triton.py
Wei Wang d7e275d4b4 [CI][CUDA] Add periodic b200 distributed job (#159323)
1. Run the distributed job on a B200 runner, periodically.
2. Discovered a generic distributed-test issue: certain unit tests hard-code ranks, which calls for a require_exact_world_size(world_size) API instead of require_world_size(world_size) (sketched below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159323
Approved by: https://github.com/eqy

Co-authored-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
2025-10-16 21:54:04 +00:00
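
For illustration, a decorator that enforces an exact world size might look roughly like the sketch below. The name follows the commit note; the body, signature, and placement are hypothetical and not the API actually added to torch.testing._internal.

import functools
import unittest

import torch.distributed as dist


def require_exact_world_size(n: int):
    # Hypothetical sketch: skip the test unless it runs with exactly `n` ranks.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            world_size = dist.get_world_size() if dist.is_initialized() else 1
            if world_size != n:
                raise unittest.SkipTest(
                    f"requires world_size == {n}, got {world_size}"
                )
            return fn(*args, **kwargs)

        return wrapper

    return decorator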


# Owner(s): ["oncall: distributed"]
# To run:
# python test/distributed/test_nvshmem_triton.py
import sys
import triton.language as tl
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import triton
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
from torch.testing._internal.common_cuda import SM100OrLater
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import IS_H100, requires_triton
if not symm_mem.is_nvshmem_available():
print("NVSHMEM not available, skipping tests")
sys.exit(0)
def requires_h100():
return skip_but_pass_in_sandcastle_if(
not IS_H100,
"NVSHMEM requires H100. Skipping test on non-H100 GPU.",
)
# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
# Shared Triton JIT kernels
@requires_nvshmem
@triton.jit
def my_put_kernel(
dest,
src,
nelems,
pe,
):
nvshmem.put(dest, src, nelems, pe)
@requires_nvshmem
@triton.jit
def my_get_kernel(
dest,
src,
nelems,
pe,
nbi: tl.constexpr, # use nonblocking interface if True
):
if nbi:
nvshmem.get_nbi(dest, src, nelems, pe)
nvshmem.quiet()
else:
nvshmem.get(dest, src, nelems, pe)
@requires_nvshmem
@triton.jit
def my_putmem_signal_block_kernel(
dst,
src,
size_bytes,
signal,
sig_val,
sig_op,
peer,
):
nvshmem.putmem_signal_block(dst, src, size_bytes, signal, sig_val, sig_op, peer)
@requires_nvshmem
@triton.jit
def my_signal_wait_until_kernel(signal, cmp_op, cmp_val):
nvshmem.signal_wait_until(signal, cmp_op, cmp_val)
@requires_nvshmem
@triton.jit
def my_signal_op_kernel(
sig_addr,
signal,
sig_op,
peer,
):
nvshmem.signal_op(sig_addr, signal, sig_op, peer)
@requires_nvshmem
@triton.jit
def my_wait_until_kernel(
ivar,
cmp_op,
cmp_val,
):
nvshmem.wait_until(ivar, cmp_op, cmp_val)
@requires_nvshmem
@triton.jit
def my_fence_kernel():
nvshmem.fence()
@requires_nvshmem
@triton.jit
def my_put_with_fence_kernel(
dst1,
src1,
dst2,
src2,
flag_dst,
flag_src,
nelems,
peer,
):
# First put
nvshmem.put(dst1, src1, nelems, peer)
# Ensure the first put is ordered before the next.
nvshmem.fence()
# Second put
nvshmem.put(dst2, src2, nelems, peer)
# Order the second put before flag update.
nvshmem.fence()
# Write the flag (single int64) to signal completion.
nvshmem.put(flag_dst, flag_src, 1, peer)
@requires_nvshmem
@triton.jit
def my_put_with_quiet_kernel(
dst,
src,
flag_dst,
flag_src,
nelems,
peer,
):
# Put data
nvshmem.put(dst, src, nelems, peer)
# Call quiet to ensure put is complete
nvshmem.quiet()
# Only after quiet, set the completion flag.
# This ensures the data put is complete before the flag is set.
nvshmem.put(flag_dst, flag_src, 1, peer)
@requires_nvshmem
@triton.jit
def my_barrier_test_kernel(
dst,
src,
nelems,
):
# Testing barrier_all() requires coordinated operations across PEs within
# the same kernel execution. Unlike other kernels that just wrap NVSHMEM
# primitives, this one implements the full test logic to properly verify
# device-side barrier synchronization.
my_pe = nvshmem.my_pe()
n_pes = nvshmem.n_pes()
# Rank 0 broadcasts its value to all other ranks
if my_pe == 0:
# Write initial value
p_src = src.to(tl.pointer_type(tl.int32))
tl.store(p_src, 42)
# Put to all other ranks
i = 1
while i < n_pes:
nvshmem.put(dst, src, nelems, i)
i += 1
# Synchronize all PEs
nvshmem.barrier_all()
# Non-zero ranks increment the received value
if my_pe != 0:
p_dst = dst.to(tl.pointer_type(tl.int32))
received = tl.load(p_dst)
tl.store(p_dst, received + 1)
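# For example, with two PEs: after this kernel, rank 0's src holds 42 and
# rank 1's dst holds 43 (the received 42 incremented once).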
@requires_nvshmem
@triton.jit
def my_barrier_all_kernel():
nvshmem.barrier_all()
@requires_nvshmem
@triton.jit
def my_sync_test_kernel(
local_data,
remote_data,
nelems,
):
my_pe = nvshmem.my_pe()
n_pes = nvshmem.n_pes()
# Each PE writes a unique value to its local memory
p_local = local_data.to(tl.pointer_type(tl.int32))
unique_value = my_pe + 100 # PE 0 writes 100, PE 1 writes 101, etc.
tl.store(p_local, unique_value)
# sync_all() ensures local stores are visible to other PEs
# but doesn't guarantee completion of any remote operations
nvshmem.sync_all()
# Now each PE reads from the next PE's memory to verify visibility
# PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0
next_pe = (my_pe + 1) % n_pes
nvshmem.get(remote_data, local_data, nelems, next_pe)
# The get should now see the value that the next PE wrote locally
# because sync_all() made those local stores visible
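# For example, with two PEs: PE 0 ends with local_data == 100 and
# remote_data == 101, while PE 1 ends with local_data == 101 and
# remote_data == 100.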
@requires_nvshmem
@triton.jit
def my_alltoall_kernel(
team_handle,
dst,
src,
nelems_per_pe,
):
nvshmem.alltoall(team_handle, dst, src, nelems_per_pe)
@requires_nvshmem
@triton.jit
def my_broadcast_kernel(
team_handle,
dst,
src,
nelems,
pe_root,
):
nvshmem.broadcast(team_handle, dst, src, nelems, pe_root)
@requires_nvshmem
@triton.jit
def my_reduce_kernel(
team_handle,
dest_tensor,
source_tensor,
nreduce,
operation: tl.constexpr,
):
nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation)
@skip_but_pass_in_sandcastle_if(
SM100OrLater,
"Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897",
)
@instantiate_parametrized_tests
class NVSHMEMTritonTest(MultiProcContinuousTest):
def _init_device(self) -> None:
# TODO: relax this (the test seems to hang without it)
device_module.set_device(self.device)
# Set NVSHMEM as SymmMem backend
symm_mem.set_backend("NVSHMEM")
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_put(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
# Configuration
nelems = 5 # number of elements to transfer
dtype = torch.int64
val = 42 + rank # Each rank has different data
# Create symmetric tensors
src = symm_mem.empty(nelems, dtype=dtype, device=self.device)
dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999)
# Fill source tensor with rank-specific pattern
for i in range(nelems):
src[i] = (
val * 10 + i
) # Rank 0: [420, 421, 422, 423, 424], Rank 1: [430, 431, ...]
# Rendezvous
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
# Synchronize before operation
dist.barrier()
peer = 1 - rank
if rank == 0:
# Rank 0 puts its data to Rank 1
my_put_kernel[(1,)](
dst,
src,
nelems,
peer,
)
# Synchronize after operation
dist.barrier()
if rank == 1:
# Verify that rank 1 received rank 0's data
expected = [420 + i for i in range(nelems)]
torch.testing.assert_close(
dst, torch.tensor(expected, device=self.device, dtype=dtype)
)
@skipIfRocm
@requires_triton()
@requires_h100()
@parametrize("nbi", [False, True]) # Test both blocking and nonblocking interfaces
def test_triton_get(self, nbi: bool) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
# Configuration
numel = 8
dtype = torch.int8
val = 7
# Create symmetric tensors
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
val if rank == 0 else -1
)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp, group=group_name)
symm_mem.rendezvous(out, group=group_name)
dist.barrier()
peer = 1 - rank
if rank == 1:
# Rank 1 gets data from rank 0 using the tensor-aware API
my_get_kernel[(1,)](
out,
inp,
numel,
peer,
nbi=nbi,
)
if rank == 1:
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_get_ring(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
world_size = dist.get_world_size()
# Configuration
numel = 8
dtype = torch.int8
# Each rank fills its input buffer with its own rank value
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp, group=group_name)
symm_mem.rendezvous(out, group=group_name)
dist.barrier()
# Ring topology: each rank gets data from the rank to its left
# rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc.
peer = (rank - 1) % world_size
# All ranks execute the get operation using the tensor-aware API
my_get_kernel[(1,)](
out,
inp,
numel,
peer,
nbi=False,
)
expected_value = peer
torch.testing.assert_close(
out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_put_signal_set(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Data buffers
val = 11
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad attached to the output symmetric memory handle
# as the flag buffer for signaling completion.
flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
peer = 1 - rank
NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set
SIGNAL_VAL = 1 # Signal completion value
NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until
if rank == 0:
# Rank 0 puts into Rank 1
my_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
signal=flag,
sig_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
)
if rank == 1:
# Wait until signal flag is set by Rank 0
my_signal_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
)
# After wait completes, verify data and flag contents
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_put_signal_add(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Data buffers
val = 11
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad attached to the output symmetric memory handle
# as the flag buffer for signaling completion.
flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
peer = 1 - rank
NVSHMEM_SIGNAL_ADD = 5 # atomic add operation
SIGNAL_VAL = 16 # expected flag value after the atomic add (flag starts at 0)
NVSHMEM_CMP_EQ = 0
if rank == 0:
# Rank 0 puts into Rank 1
my_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
signal=flag,
sig_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_ADD,
peer=peer,
)
if rank == 1:
my_signal_wait_until_kernel[(1, 1, 1)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_wait_until(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
NVSHMEM_CMP_EQ = 0 # equal comparison
FLAG_INITIAL_VALUE = 0
FLAG_FINAL_VALUE = 42
# Use a single int32 symmetric tensor as our synchronization flag.
flag = symm_mem.empty(1, dtype=torch.int32, device=self.device).fill_(
FLAG_INITIAL_VALUE
)
symm_mem.rendezvous(flag, group=group_name)
expected_flag = torch.tensor(
[FLAG_FINAL_VALUE], dtype=torch.int32, device=self.device
)
if rank == 0:
# Rank 0 (the waiter)
my_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=FLAG_FINAL_VALUE,
)
# Verification
torch.testing.assert_close(
flag,
expected_flag,
)
if rank == 1:
# Rank 1 (the signaler)
# Launch a kernel to put the value to Rank 0's flag tensor.
my_put_kernel[(1,)](
flag, # Destination symmetric tensor on the remote PE
expected_flag, # Source data tensor (local)
1, # Number of elements
peer, # The target PE (Rank 0)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_signal_wait_until(self) -> None:
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
# NVSHMEM constants from documentation
NVSHMEM_CMP_EQ = 0 # equal comparison
NVSHMEM_SIGNAL_SET = 0 # atomic set operation
# Message configuration
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
val_to_put = 123 # arbitrary test value
COMPLETION_FLAG_VAL = 1
# Producer (rank 0) prepares the data to send
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val_to_put)
symm_mem.rendezvous(inp, group=group_name)
# Consumer (rank 1) prepares the destination buffer
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad for synchronization, as in previous tests
flag_dtype = torch.int64
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
if rank == 0:
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
my_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
signal=flag,
sig_val=COMPLETION_FLAG_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
)
elif rank == 1:
# Consumer (rank 1): Waits on the signal variable using `signal_wait_until`.
my_signal_wait_until_kernel[(1, 1, 1)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=COMPLETION_FLAG_VAL,
)
# After the wait returns, verify data and flag
torch.testing.assert_close(
out, val_to_put * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag,
torch.tensor(
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
),
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_fence(self) -> None:
"""
Rank 0 performs two put operations into Rank 1's buffers with a fence
between them, followed by another fence and a flag update. Rank 1 waits
for the flag, then verifies that both destination buffers contain the
expected values. The flag is transferred after the final fence, so
its arrival implies that both preceding puts have been delivered in
order.
"""
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
# Message configuration
dtype = torch.int8
numel = 8
val1 = 10
val2 = 20
flag_val = 1
# Symmetric buffers
inp1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val1)
inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2)
out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp1, group=group_name)
symm_mem.rendezvous(inp2, group=group_name)
symm_mem.rendezvous(out1, group=group_name)
symm_mem.rendezvous(out2, group=group_name)
# Use regular symmetric memory tensor for flag
flag = symm_mem.empty(1, dtype=torch.int32, device=self.device).fill_(0)
symm_mem.rendezvous(flag, group=group_name)
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int32, device=self.device
)
NVSHMEM_CMP_EQ = 0 # compare equal
if rank == 0:
my_put_with_fence_kernel[(1,)](
out1,
inp1,
out2,
inp2,
flag,
flag_update_val,
nelems=numel,
peer=peer,
)
elif rank == 1:
# Wait until flag is set by Rank 0
my_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
)
# Verify ordered data arrival.
torch.testing.assert_close(
out1, val1 * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
out2, val2 * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([flag_val], dtype=torch.int32, device=self.device)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_quiet(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
dtype = torch.int8
numel = 8
val = 15
flag_val = 42
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
flag = symm_mem.empty(1, dtype=torch.int32, device=self.device).fill_(0)
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int32, device=self.device
)
symm_mem.rendezvous(inp, group=group_name)
symm_mem.rendezvous(out, group=group_name)
symm_mem.rendezvous(flag, group=group_name)
NVSHMEM_CMP_EQ = 0
dist.barrier()
if rank == 1:
my_put_with_quiet_kernel[(1,)](
out,
inp,
flag,
flag_update_val,
nelems=numel,
peer=peer,
)
elif rank == 0:
my_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
dist.barrier()
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_barrier(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
numel = 1
dtype = torch.int32
src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0)
dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0)
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
my_barrier_test_kernel[(1,)](
dst,
src,
nelems=numel,
launch_cooperative_grid=True,
num_ctas=1,
)
dist.barrier()
if rank == 0:
torch.testing.assert_close(
src, torch.tensor([42], device=self.device, dtype=dtype)
)
else:
torch.testing.assert_close(
dst, torch.tensor([43], device=self.device, dtype=dtype)
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_sync(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
numel = 1
dtype = torch.int32
# Create symmetric buffers
local_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0)
remote_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0)
symm_mem.rendezvous(local_data, group=group_name)
symm_mem.rendezvous(remote_data, group=group_name)
# Launch kernel with cooperative grid
my_sync_test_kernel[(1,)](
local_data,
remote_data,
nelems=numel,
launch_cooperative_grid=True,
num_ctas=1,
)
# Verify results
# Each PE should have written rank + 100 to its local_data
expected_local = rank + 100
torch.testing.assert_close(
local_data, torch.tensor([expected_local], device=self.device, dtype=dtype)
)
# Each PE should have read (next_rank + 100) into its remote_data
# PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0
next_rank = (rank + 1) % self.world_size
expected_remote = next_rank + 100
torch.testing.assert_close(
remote_data,
torch.tensor([expected_remote], device=self.device, dtype=dtype),
)
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_alltoall(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
rank = self.rank
# Each PE will send 2 int64 elements to every other PE
nelems_per_pe = 2
dtype = torch.int64
# Source buffer: contains data for all PEs
# Layout: [data_for_pe0, data_for_pe1, ...]
src_size = nelems_per_pe * world_size
src = symm_mem.empty(src_size, dtype=dtype, device=self.device)
# Fill source with rank-specific data
# Formula: rank * 100 + destination_pe
for i in range(world_size):
value = rank * 100 + i
src[i * nelems_per_pe : (i + 1) * nelems_per_pe] = value
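# For example, with world_size == 2: rank 0's src is [0, 0, 1, 1] and rank 1's
# is [100, 100, 101, 101]; after the alltoall, rank 0's dst should read
# [0, 0, 100, 100] and rank 1's [1, 1, 101, 101].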
# Destination buffer
dst = symm_mem.empty(src_size, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
# Synchronize before alltoall
dist.barrier()
team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0
# Launch the kernel using the new tensor-aware API
my_alltoall_kernel[(1,)](
team_handle,
dst,
src,
nelems_per_pe,
launch_cooperative_grid=True,
)
# Synchronize after alltoall
dist.barrier()
# Verify results
for i in range(world_size):
# After alltoall, we should receive data from PE i that was intended for us
# PE i sends (i * 100 + rank) to us
expected = i * 100 + rank
actual = dst[i * nelems_per_pe : (i + 1) * nelems_per_pe]
torch.testing.assert_close(actual, torch.full_like(actual, expected))
@skipIfRocm
@requires_triton()
@requires_h100()
def test_triton_broadcast(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
# Configuration
nelems = 4 # number of elements
dtype = torch.int64
# Source buffer - only root will have meaningful data
pe_root = 0 # PE 0 will be the root
src = symm_mem.empty(nelems, dtype=dtype, device=self.device)
# Destination buffer
dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999)
if rank == pe_root:
# Root fills with specific pattern
for i in range(nelems):
src[i] = 100 + i
else:
# Non-root PEs have dummy data
src.fill_(-1)
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
# Synchronize before broadcast
dist.barrier()
# Execute broadcast
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_broadcast_kernel[(1,)](
team_handle,
dst,
src,
nelems,
pe_root,
launch_cooperative_grid=True,
)
# Synchronize after broadcast
dist.barrier()
# Verify results - all ranks should have the root's data
expected = [100 + i for i in range(nelems)]
torch.testing.assert_close(
dst, torch.tensor(expected, device=self.device, dtype=dtype)
)
@skipIfRocm
@requires_triton()
@requires_h100()
@parametrize(
"dtype",
[
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.float16,
torch.float32,
# torch.float64, # Tensor-likes are not close
torch.bfloat16,
],
)
def test_triton_sum_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
rank = self.rank
# Configuration
nreduce = 3 # number of separate reductions
# Source buffer - each rank contributes different values
src = symm_mem.empty(nreduce, dtype=dtype, device=self.device)
for i in range(nreduce):
src[i] = (rank + 1) * (i + 1) # Rank 0: [1,2,3], Rank 1: [2,4,6], etc.
# Destination buffer
dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
# Calculate expected results
expected = []
for i in range(nreduce):
# Sum across all ranks: sum((rank+1)*(i+1) for rank in range(world_size))
total = sum((r + 1) * (i + 1) for r in range(world_size))
expected.append(total)
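# For example, with world_size == 2 the sources are [1, 2, 3] (rank 0) and
# [2, 4, 6] (rank 1), so expected == [3, 6, 9].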
# Synchronize before reduction
dist.barrier()
# Execute sum reduction across all ranks
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_reduce_kernel[(1,)](
team_handle,
dst,
src,
nreduce,
operation="sum",
launch_cooperative_grid=True,
)
# Synchronize after reduction
dist.barrier()
# Verify results
torch.testing.assert_close(
dst, torch.tensor(expected, device=self.device, dtype=dtype)
)
@skipIfRocm
@requires_triton()
@requires_h100()
@parametrize(
"dtype",
[
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
],
)
def test_triton_minmax_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
rank = self.rank
# Configuration
nreduce = 2 # number of values to reduce
# Source buffers for min and max
src_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device)
src_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device)
# Each rank contributes different values
# For min: rank 0: [10, 20], rank 1: [15, 5], etc.
# For max: same values
for i in range(nreduce):
if i == 0:
src_min[i] = 10 + rank * 5 # 10, 15, 20, ...
src_max[i] = 10 + rank * 5
else:
src_min[i] = 20 - rank * 15 # 20, 5, -10, ...
src_max[i] = 20 - rank * 15
# Destination buffers
dst_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1)
dst_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(src_min, group=group_name)
symm_mem.rendezvous(src_max, group=group_name)
symm_mem.rendezvous(dst_min, group=group_name)
symm_mem.rendezvous(dst_max, group=group_name)
# Calculate expected results
all_values = []
for i in range(nreduce):
values = []
for r in range(world_size):
if i == 0:
values.append(10 + r * 5)
else:
values.append(20 - r * 15)
all_values.append(values)
expected_min = [min(vals) for vals in all_values]
expected_max = [max(vals) for vals in all_values]
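# For example, with world_size == 2: all_values == [[10, 15], [20, 5]],
# so expected_min == [10, 5] and expected_max == [15, 20].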
dist.barrier()
# Execute MIN reduction
team_handle = 0
my_reduce_kernel[(1,)](
team_handle,
dst_min,
src_min,
nreduce,
operation="min",
launch_cooperative_grid=True,
)
# Execute MAX reduction
my_reduce_kernel[(1,)](
team_handle,
dst_max,
src_max,
nreduce,
operation="max",
launch_cooperative_grid=True,
)
dist.barrier()
# Verify results
torch.testing.assert_close(
dst_min, torch.tensor(expected_min, device=self.device, dtype=dtype)
)
torch.testing.assert_close(
dst_max, torch.tensor(expected_max, device=self.device, dtype=dtype)
)
@skipIfRocm
@requires_triton()
@requires_h100()
@parametrize(
"dtype",
[
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
# torch.float64, # Tensor-likes are not close
torch.bfloat16,
],
)
def test_triton_prod_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
rank = self.rank
# Configuration
nreduce = 3 # number of separate reductions
# Source buffer - each rank contributes different values
# Use very small values to avoid overflow, especially for small integer types
src = symm_mem.empty(nreduce, dtype=dtype, device=self.device)
for i in range(nreduce):
# Use values that won't overflow even for int8: all values 1 or 2
if i == 0:
# For first element: rank 0,2,4... gets 1, rank 1,3,5... gets 2
src[i] = 1 if rank % 2 == 0 else 2
elif i == 1:
# For second element: all get 1 (no multiplication effect)
src[i] = 1
else:
# For third element: rank 0,1 get 1, rank 2,3 get 2, etc. (groups of 2)
src[i] = 1 if (rank // 2) % 2 == 0 else 2
# Destination buffer
dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
# Calculate expected results
vals = torch.empty(nreduce, world_size, dtype=dtype)
vals[0, ::2] = 1
vals[0, 1::2] = 2
vals[1] = 1
vals2 = vals[2].view(-1, 2, 2)
vals2[:, 0] = 1
vals2[:, 1] = 2
expected = vals.prod(-1).tolist()
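# For example, with world_size == 4 the per-element contributions are
# [1, 2, 1, 2], [1, 1, 1, 1], and [1, 1, 2, 2], giving expected == [4, 1, 4].
# (The view(-1, 2, 2) above assumes world_size is a multiple of 4.)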
# Synchronize before reduction
dist.barrier()
# Execute product reduction across all ranks
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_reduce_kernel[(1,)](
team_handle,
dst,
src,
nreduce,
operation="prod",
launch_cooperative_grid=True,
)
# Synchronize after reduction
dist.barrier()
# Verify results
torch.testing.assert_close(
dst, torch.tensor(expected, device=self.device, dtype=dtype)
)
if __name__ == "__main__":
run_tests()