mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[NVSHMEM][Triton] Fix NVSHMEM triton test for wacky world sizes (#165704)
Currently assumes divisible by 4? world size Not as slick as the old setup code but more general Pull Request resolved: https://github.com/pytorch/pytorch/pull/165704 Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
382b0150de
commit
a16fd6b488
@ -1141,9 +1141,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
|
||||
vals[0, ::2] = 1
|
||||
vals[0, 1::2] = 2
|
||||
vals[1] = 1
|
||||
vals2 = vals[2].view(-1, 2, 2)
|
||||
vals2[:, 0] = 1
|
||||
vals2[:, 1] = 2
|
||||
for rank in range(world_size):
|
||||
vals[2, rank] = 1 if (rank // 2) % 2 == 0 else 2
|
||||
expected = vals.prod(-1).tolist()
|
||||
|
||||
# Synchronize before reduction
|
||||
|
Reference in New Issue
Block a user