[CI/Build] Replace mean with torch.all in test_pynccl.py (#10876)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent 381ac93bb5
commit d2bd88b122
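
Background on the assertion change: tensor.mean() only checks the average value, so a tensor containing a mix of too-high and too-low elements can still pass, while torch.all(tensor == expected) requires every element to match. A minimal standalone sketch of the difference (illustrative values only, not taken from the test file):

    import torch

    expected = 2.0
    # The mean is exactly 2.0 even though no element equals 2.0, so a
    # mean-based assertion passes while the elementwise check fails.
    skewed = torch.tensor([1.0, 3.0])
    assert skewed.mean().item() == expected
    assert not torch.all(skewed == expected).item()

    # A genuinely correct result passes the stricter check as well.
    correct = torch.full((4,), expected)
    assert torch.all(correct == expected).item()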
@@ -62,8 +62,7 @@ def worker_fn():
     with pynccl_comm.change_state(enable=True):
         tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
-    assert result == pynccl_comm.world_size
+    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -88,13 +87,11 @@ def multiple_allreduce_worker_fn():
             tensor = pynccl_comm.all_reduce(tensor)
             tensor = pynccl_comm.all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 4
+            assert torch.all(tensor == 4).cpu().item()
         else:
             tensor = pynccl_comm.all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 2
+            assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -116,13 +113,11 @@ def multiple_allreduce_with_vllm_worker_fn():
             tensor = tensor_model_parallel_all_reduce(tensor)
             tensor = tensor_model_parallel_all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 4
+            assert torch.all(tensor == 4).cpu().item()
         else:
             tensor = tensor_model_parallel_all_reduce(tensor)
             torch.cuda.synchronize()
-            result = tensor.mean().cpu().item()
-            assert result == 2
+            assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -149,7 +144,7 @@ def worker_fn_with_cudagraph():
         torch.cuda.synchronize()
         graph.replay()
         torch.cuda.synchronize()
-        assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
+        assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
 
 
 @worker_fn_wrapper
@@ -249,8 +244,7 @@ def send_recv_worker_fn():
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
-    assert result == 1
+    assert torch.all(tensor == 1).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -289,11 +283,10 @@ def multiple_send_recv_worker_fn():
                          src=(pynccl_comm.rank - 1) %
                          pynccl_comm.world_size)
     torch.cuda.synchronize()
-    result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
-        assert result == 1
+        assert torch.all(tensor == 1).cpu().item()
     else:
-        assert result == 2
+        assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
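
For reference, a minimal CPU-only sketch of how the new assertion pattern evaluates; the tensor shape mirrors the test, but world_size is an assumed illustrative value and the CUDA communicator and synchronization calls are omitted:

    import torch

    world_size = 2  # assumed; the real tests read this from the communicator
    # An all-reduce (sum) of all-ones tensors across world_size ranks yields
    # a tensor filled with world_size everywhere.
    tensor = torch.full((16, 1024, 1024), float(world_size))

    flag = torch.all(tensor == world_size)  # 0-dim boolean tensor
    assert flag.cpu().item()                # .cpu().item() -> Python True

    # Any single mismatching element makes the elementwise check fail.
    tensor[0, 0, 0] = 0.0
    assert not torch.all(tensor == world_size).cpu().item()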