[NVSHMEM][Triton] Fix NVSHMEM triton test for wacky world sizes (#165704)

Currently assumes divisible by 4? world size

Not as slick as the old setup code but more general

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165704
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
Eddie Yan
2025-10-17 19:33:26 +00:00
committed by PyTorch MergeBot
parent 382b0150de
commit a16fd6b488

View File

@ -1141,9 +1141,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
vals[0, ::2] = 1
vals[0, 1::2] = 2
vals[1] = 1
vals2 = vals[2].view(-1, 2, 2)
vals2[:, 0] = 1
vals2[:, 1] = 2
for rank in range(world_size):
vals[2, rank] = 1 if (rank // 2) % 2 == 0 else 2
expected = vals.prod(-1).tolist()
# Synchronize before reduction