mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
MOTIVATION

Generalize distributed test cases to run on non-CUDA devices.

CHANGES

- test/distributed/checkpoint/test_fsspec.py
- test/distributed/checkpoint/test_state_dict.py
- test/distributed/test_multi_threaded_pg.py
  Replaced hard-coded device names with torch.accelerator.current_accelerator.
- torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
  Added support for the hccl backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156796
Approved by: https://github.com/guangyey, https://github.com/ezyang
343 lines · 12 KiB · Python
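The device-selection idiom at the heart of this change is shown below. Instead of hard-coding "cuda", a test derives its device type from torch.accelerator.current_accelerator(), which returns the active accelerator as a torch.device (or None on a CPU-only build). A minimal before/after sketch of the pattern (the tensor shown is illustrative, not taken from the PR diff):

    import torch

    # Before: hard-coded to CUDA, so the test cannot run on xpu/hpu builds.
    # tensor = torch.ones(3, 3, device="cuda")

    # After: ask the runtime which accelerator is active and fall back to cpu.
    # acc.type yields the device string, e.g. "cuda", "xpu", or "hpu".
    device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    tensor = torch.ones(3, 3, device=device_type)

This is the same expression the module below assigns to device_type at import time.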
# Owner(s): ["oncall: distributed"]

import operator
import os
import sys
import threading
from functools import reduce
from unittest import skip, SkipTest

import torch
import torch.autograd
import torch.distributed as dist
from torch._C._distributed_c10d import ReduceOp


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
    skip_if_lt_x_gpu,
    spawn_threads_and_init_comms,
)
from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase


# Use the current accelerator (e.g. cuda, xpu, hpu) when one is available,
# otherwise fall back to cpu, so tests are not hard-coded to CUDA.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

DEFAULT_WORLD_SIZE = 4


class TestCollectivesWithWrapper(TestCase):
    @spawn_threads_and_init_comms(world_size=4)
    def test_broadcast_object_list(self):
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def test_collective_error_on_rank_zero(self):
        @spawn_threads_and_init_comms(world_size=4)
        def _test_method(self):
            input_tensor = torch.ones(3, 3) * dist.get_rank()  # perform 1st all gather
            output_tensors = [
                torch.empty_like(input_tensor) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(output_tensors, input_tensor)

            if dist.get_rank() == 0:
                raise AssertionError("Mimic real test failure.")  # fail on rank 0

            dist.all_gather(output_tensors, input_tensor)  # perform 2nd all gather

        with self.assertRaises(RuntimeError):
            _test_method(self)

    def test_collective_error_on_rank_non_zero(self):
        @spawn_threads_and_init_comms(world_size=4)
        def _test_method(self):
            input_tensor = torch.ones(3, 3) * dist.get_rank()  # perform 1st all gather
            output_tensors = [
                torch.empty_like(input_tensor) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(output_tensors, input_tensor)

            if dist.get_rank() == 1:
                raise AssertionError("Mimic real test failure.")  # fail on rank 1

            dist.all_gather(output_tensors, input_tensor)  # perform 2nd all gather

        with self.assertRaises(RuntimeError):
            _test_method(self)

    def test_collective_error_on_rank_non_zero_all(self):
        @spawn_threads_and_init_comms(world_size=4)
        def _test_method(self):
            input_tensor = torch.ones(3, 3) * dist.get_rank()  # perform 1st all gather
            output_tensors = [
                torch.empty_like(input_tensor) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(output_tensors, input_tensor)

            if dist.get_rank() > 0:
                raise AssertionError(
                    "Mimic real test failure."
                )  # fail on all non-zero ranks

            dist.all_gather(output_tensors, input_tensor)  # perform 2nd all gather

        with self.assertRaises(RuntimeError):
            _test_method(self)

    def test_skip(self):
        @spawn_threads_and_init_comms(world_size=4)
        @skip("check if skip exception can be captured correctly.")
        def _test_method(self):
            pass

        if not IS_SANDCASTLE:
            with self.assertRaises(SkipTest):
                _test_method(self)

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_tensor(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = torch.ones(world_size, dtype=torch.int64)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_list(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)
        sizes = [1] * world_size

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send, sizes, sizes)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))

    @spawn_threads_and_init_comms(world_size=4)
    def test_all_to_all_single_none(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        send = torch.full((world_size, 2), rank)

        out = torch.zeros(world_size, 2, dtype=send.dtype)
        dist.all_to_all_single(out, send)
        self.assertEqual(out.tolist(), list(zip(range(world_size), range(world_size))))


class TestCollectivesWithBaseClass(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        os.environ["TORCH_DIST_INIT_BARRIER"] = "1"
        super().setUp()
        self._spawn_threads()

    def tearDown(self):
        super().tearDown()
        os.environ["TORCH_DIST_INIT_BARRIER"] = "0"

    def test_allgather(self):
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [
            torch.empty_like(input_tensor) for _ in range(self.world_size)
        ]
        dist.all_gather(output_tensors, input_tensor)
        for rank, out_tensor in enumerate(output_tensors):
            self.assertEqual(out_tensor, torch.ones(3, 3) * rank)

    def test_broadcast(self):
        input_tensor = torch.ones(3, 3) * dist.get_rank()
        for rank in range(self.world_size):
            cloned_input = input_tensor.clone()
            dist.broadcast(cloned_input, src=rank)
            self.assertEqual(cloned_input, torch.ones(3, 3) * rank)

    def test_scatter(self):
        if dist.get_rank() == 0:
            scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
        else:
            scatter_list = None
        output_tensor = torch.empty(3, 3)

        dist.scatter(output_tensor, scatter_list)
        self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank())

    def test_reduce_scatter(self):
        to_reduce_scatter = [
            torch.ones(3, 3) * rank for rank in range(self.world_size)
        ]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size
        self.assertEqual(output_tensor, expected_tensor)

        output_tensor = torch.empty(3, 3)
        dist.reduce_scatter(output_tensor, to_reduce_scatter, op=dist.ReduceOp.AVG)
        expected_tensor = torch.ones(3, 3) * dist.get_rank()
        self.assertEqual(output_tensor, expected_tensor)

    def test_broadcast_object_list(self):
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()
        print(f"{dist.get_rank()} -> {dist.get_world_size()}")

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def test_all_reduce(self):
        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        # default op is SUM, so each element becomes 0 + 1 + ... + (world_size - 1)
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2
        self.assertEqual(output, torch.ones(3, 3) * res_num)

    def test_all_to_all(self):
        rank = self.rank
        world_size = self.world_size
        input_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank * world_size, (rank + 1) * world_size)
        ]
        output_tensor_list = [torch.empty_like(tensor) for tensor in input_tensor_list]
        dist.all_to_all(output_tensor_list, input_tensor_list)
        expected_tensor_list = [
            torch.ones(3, 3) * x
            for x in range(rank, world_size * world_size, world_size)
        ]
        self.assertEqual(expected_tensor_list, output_tensor_list)

    def test_all_reduce_ops(self):
        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.PRODUCT)
        expected = reduce(operator.mul, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MIN)
        self.assertEqual(1, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.MAX)
        self.assertEqual(self.world_size, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BAND)
        expected = reduce(operator.and_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BOR)
        expected = reduce(operator.or_, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

        tensor = torch.tensor([dist.get_rank() + 1])
        dist.all_reduce(tensor, op=ReduceOp.BXOR)
        expected = reduce(operator.xor, range(1, self.world_size + 1))
        self.assertEqual(expected, tensor.item())

    def test_assert_equal_on_rank(self):
        # RNG is shared across threads, so instead of asserting on all threads
        # we only assert on rank 0.
        self_tensor = torch.rand(3, 3)
        rank_0_tensor = self_tensor.clone()
        dist.broadcast(rank_0_tensor, src=0)
        self.assertEqualOnRank(rank_0_tensor, self_tensor, rank=0)
        self.assertNotEqualOnRank(rank_0_tensor, self_tensor, rank=1)

    def test_subpg(self):
        subpg0 = dist.new_group([0, 1])
        subpg1 = dist.new_group([2, 3])
        current_rank = dist.get_rank()
        output = torch.ones(3, 3) * current_rank

        # call all_reduce on subpg0 and subpg1 concurrently
        if current_rank in [0, 1]:
            dist.all_reduce(output, group=subpg0)
        else:
            dist.all_reduce(output, group=subpg1)

        if current_rank in [0, 1]:
            self.assertEqual(output, torch.ones(3, 3) * 1)  # 0 + 1
        else:
            self.assertEqual(output, torch.ones(3, 3) * 5)  # 2 + 3

    def test_using_pg_from_another_thread(self):
        def stuff_in_other_thread(pg):
            x = torch.rand(4, requires_grad=True)
            dist.all_reduce(x, group=pg)

        t = threading.Thread(target=stuff_in_other_thread, args=(dist.group.WORLD,))
        t.start()
        t.join()

    def test_gather(self):
        if dist.get_rank() == 0:
            gather_list = [torch.empty(3, 3) for _ in range(self.world_size)]
        else:
            gather_list = None
        input_tensor = torch.ones(3, 3) * dist.get_rank()

        dist.gather(input_tensor, gather_list)
        if dist.get_rank() == 0:
            for i in range(self.world_size):
                self.assertEqual(gather_list[i], torch.ones(3, 3) * i)

    def test_all_reduce_coalesced(self):
        t0 = torch.ones(3, 3) * dist.get_rank()
        t1 = torch.ones(3, 3) * dist.get_rank() * 2
        dist.all_reduce_coalesced([t0, t1])
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2
        self.assertEqual(t0, torch.ones(3, 3) * res_num)
        self.assertEqual(t1, torch.ones(3, 3) * (res_num * 2))

    @skip_if_lt_x_gpu(1)
    def test_bwd_sees_fwd_pg(self):
        fwd_tid = threading.current_thread().ident

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, rank):
                result = rank * 2

                ctx.save_for_backward(result, rank)
                assert int(rank.item()) == dist.get_rank()
                return result

            @staticmethod
            def backward(ctx, grad_output):
                result, rank = ctx.saved_tensors
                bwd_tid = threading.current_thread().ident

                self.assertEqual(
                    fwd_tid,
                    bwd_tid,
                    f"bwd not running in the same thread as fwd for rank {rank.item()}",
                )
                self.assertTrue(dist.is_initialized())
                self.assertEqual(int(rank.item()), dist.get_rank())
                dist.all_reduce(result)
                self.assertEqual(int(result.item()), 12)  # (0 + 1 + 2 + 3) * 2

                return grad_output * result

        x = torch.tensor(
            [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True
        )
        x = MyFunc.apply(x)
        x.sum().backward()


if __name__ == "__main__":
    run_tests()