# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorConverter,
    DTensorTestBase,
    with_comms,
)
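
# A quick orientation to the DTensor placements exercised below (a hedged
# sketch, not exhaustive). On a 1-D mesh over `world_size` ranks:
#
#   dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])     # row-sharded
#   dt = distribute_tensor(torch.randn(8, 8), mesh, [Replicate()])  # full copy per rank
#   dt = DTensor.from_local(local, mesh, [Partial()])  # pending cross-rank reduction
#
# `full_tensor()` materializes the global tensor (all-gathering shards and
# reducing partials), while `to_local()` returns the raw local shard as-is.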


class DistTensorOpsTest(DTensorTestBase):
    @with_comms
    def test_aten_contiguous(self):
        # this op is not covered by dtensor_ops
        mesh = self.build_device_mesh()
        self._test_op(
            mesh,
            lambda x: torch.ops.aten.contiguous(x),
            torch.randn(16, 32),
        )

    @with_comms
    def test_detach(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor_to_detach = torch.randn(12, 8, requires_grad=True)
        mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec)
        detached_mat = mat.detach()
        self.assertFalse(detached_mat is mat)

    @with_comms
    def test_clone(self):
        device_mesh = self.build_device_mesh()
        specs = [[Replicate()], [Shard(0)]]
        tensor_to_clone = torch.randn(12, 8, requires_grad=True)
        for spec in specs:
            mat = distribute_tensor(tensor_to_clone, device_mesh, spec)
            cloned_mat = mat.clone()
            self.assertFalse(cloned_mat is mat)
            self.assertEqual(cloned_mat.to_local(), mat.to_local())

    @with_comms
    def test_copy_(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # basic test
        src_tensor = torch.randn((12, 12))
        dst_tensor = torch.zeros(12, 12)
        src_specs = [[Replicate()], [Shard(0)]]
        dst_specs = [[Replicate()], [Shard(0)]]
        for dst_spec, src_spec in zip(dst_specs, src_specs):
            src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
            dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
            dst_dtensor.copy_(src_dtensor)
            dst_tensor.copy_(src_tensor)
            self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

        # simple broadcasting
        src_tensor = torch.randn((128,))
        dst_tensor = torch.zeros(128, 128)
        src_specs = [[Replicate()], [Shard(0)]]
        dst_specs = [[Replicate()], [Shard(1)]]
        for dst_spec, src_spec in zip(dst_specs, src_specs):
            src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
            dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
            dst_dtensor.copy_(src_dtensor)
            dst_tensor.copy_(src_tensor)
            self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

        # the src specs here are deliberately incompatible with the dst specs,
        # so a redistribute should happen
        src_tensor = torch.randn((64, 1))
        dst_tensor = torch.zeros(16, 32, 64, 128)
        src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]]
        dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]]
        for dst_spec, src_spec in zip(dst_specs, src_specs):
            src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
            dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
            dst_dtensor.copy_(src_dtensor)
            dst_tensor.copy_(src_tensor)
            self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)

        # copy_ is a pointwise op, so it should keep Partial placements
        # without redistributing
        src_tensor = torch.randn((64, 1))
        dst_tensor = torch.zeros(16, 32, 64, 128)
        src_specs = [[Partial()]]
        dst_specs = [[Partial()]]
        for dst_spec, src_spec in zip(dst_specs, src_specs):
            src_dtensor = DTensor.from_local(src_tensor, device_mesh, src_spec)
            dst_dtensor = DTensor.from_local(dst_tensor, device_mesh, dst_spec)
            dst_dtensor.copy_(src_dtensor)
            dst_tensor.copy_(src_tensor)
            self.assertEqual(dst_dtensor.placements, (Partial(),))
            self.assertEqual(dst_dtensor._local_tensor, dst_tensor)
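
    # Why Partial survives copy_ in the last case above (a sketch): a Partial()
    # DTensor denotes a value that is the pending reduction (sum, by default)
    # of the per-rank local tensors, e.g. on a 2-rank mesh:
    #   full = local_rank0 + local_rank1
    # An elementwise copy of each local shard commutes with that sum, so no
    # redistribution is needed and the Partial placement can be kept.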

    @with_comms
    def test_contiguous(self):
        device_mesh = self.build_device_mesh()
        tensor = torch.rand(3, 5, 6, requires_grad=True)
        sharding = [Shard(0)]
        dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
        self.assertTrue(dist_tensor.is_contiguous())
        # sharding on dim 0 should not change the stride (30, 6, 1)
        self.assertEqual(dist_tensor.stride(), tensor.stride())

        new_dt = dist_tensor.transpose(0, 2)
        self.assertFalse(new_dt.is_contiguous())
        self.assertFalse(new_dt.to_local().is_contiguous())
        # check stride
        self.assertEqual(new_dt.stride(), (1, 6, 30))

        new_dt = new_dt.contiguous()
        self.assertTrue(new_dt.is_contiguous())
        self.assertTrue(new_dt.to_local().is_contiguous())
        # the original dist_tensor's stride should be unaffected
        self.assertEqual(dist_tensor.stride(), tensor.stride())

        # check backward
        new_dt.to_local().sum().backward()
        self.assertEqual(tensor.grad, torch.ones(3, 5, 6))
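
    # Stride arithmetic behind the assertions above (a worked sketch): a
    # contiguous (3, 5, 6) tensor has row-major strides (30, 6, 1);
    # transpose(0, 2) swaps dims 0 and 2, yielding shape (6, 5, 3) with
    # strides (1, 6, 30); .contiguous() then re-materializes row-major
    # strides (15, 3, 1) for the new shape, leaving the original
    # dist_tensor's (30, 6, 1) untouched.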

    @with_comms
    def test_inplace_op(self):
        mesh = self.build_device_mesh()
        input_tensor = torch.randn((12, 3), device=self.device_type)
        dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)])
        dt_to_mul = dt_to_add.clone()
        expected_add_dt = dt_to_add.clone() + 3
        add_res = dt_to_add.add_(3)
        expected_mul_dt = dt_to_mul.clone() * 3
        mul_res = dt_to_mul.mul_(3)
        # an in-place op should return the same instance it was called on
        self.assertTrue(add_res is dt_to_add)
        self.assertEqual(add_res.to_local(), expected_add_dt.to_local())

        self.assertTrue(mul_res is dt_to_mul)
        self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local())

        # test an in-place op where self and other have different specs,
        # and make sure the output spec does not change
        shard_spec = [Shard(0)]
        partial_spec = [Partial()]
        dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec)
        partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec)
        res = dt_to_inplace_add.add_(partial_grad)
        self.assertTrue(res is dt_to_inplace_add)
        self.assertTrue(res.placements == tuple(shard_spec))
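
    # Note on the mixed-placement in-place add above: since `self` owns the
    # storage being written, its Shard(0) spec must be preserved; the Partial
    # operand is the one reconciled to match, which is why the test asserts
    # that the result's placements still equal shard_spec.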

    @with_comms
    def test_op_out_variant(self):
        mesh = self.build_device_mesh()
        input_tensor = torch.randn((12, 3), device=self.device_type)
        sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)])
        expected_dt = sharded_dt_input.clone() + 3
        sharded_dt_out = sharded_dt_input.clone()
        res = torch.add(sharded_dt_input, 3, out=sharded_dt_out)
        # an out= variant should return the same instance passed as out
        self.assertTrue(res is sharded_dt_out)
        self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local())

        # test the out= variant with another spec and make sure the out spec
        # does not change
        replica_spec = [Replicate()]
        replicate_out = distribute_tensor(input_tensor, mesh, replica_spec)
        expected_dt = replicate_out.clone() + 3
        res = torch.add(sharded_dt_input, 3, out=replicate_out)
        self.assertTrue(res is replicate_out)
        self.assertTrue(res.placements == tuple(replica_spec))
        self.assertEqual(replicate_out.to_local(), expected_dt.to_local())

    @with_comms
    def test_empty_like(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        empty_like_dt = torch.empty_like(dist_tensor)
        # empty is not deterministic, so we only check that the shard propagation worked
        self.assertEqual((4, 8), empty_like_dt.to_local().shape)

    @with_comms
    def test_fill_inplace(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        full_like_dt = torch.fill_(dist_tensor, 42.0)
        full_expected = torch.full((4, 8), 42.0)
        self.assertEqual(full_expected, full_like_dt.to_local())
        self.assertEqual(full_expected, dist_tensor.to_local())

    @with_comms
    def test_full_like(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        full_like_dt = torch.full_like(dist_tensor, 42.0)
        full_expected = torch.full((4, 8), 42.0)
        self.assertEqual(full_expected, full_like_dt.to_local())

    @with_comms
    def test_ones_like(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        ones_like_dt = torch.ones_like(dist_tensor)
        ones_expected = torch.ones(4, 8)
        self.assertEqual(ones_expected, ones_like_dt.to_local())

    @with_comms
    def test_ones_like_partial_sum(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Partial()]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        assert dist_tensor.shape == (4, 8)

        ones_like_dt = torch.ones_like(dist_tensor)
        ones_expected = torch.ones(dist_tensor.shape)
        self.assertEqual(ones_expected, ones_like_dt.full_tensor())

    @with_comms
    def test_fill_inplace_partial_sum(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Partial()]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        assert dist_tensor.shape == (4, 8)

        # an in-place fill on a partial sum should stay partial
        torch.fill_(dist_tensor, 8)
        fill_expected = torch.full(
            dist_tensor.shape, 8 * self.world_size, dtype=input_tensor.dtype
        )
        self.assertEqual(fill_expected, dist_tensor.full_tensor())
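
    # Worked example for the fill above: Partial() denotes a pending sum
    # across ranks, so filling every local shard with 8 makes each element of
    # the materialized global tensor 8 * world_size, e.g. 32 on a 4-rank mesh.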

    @with_comms
    def test_zeros_like_partial_sum(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Partial()]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        assert dist_tensor.shape == (4, 8)

        zeros_like_dt = torch.zeros_like(dist_tensor)
        zeros_expected = torch.zeros(dist_tensor.shape)
        self.assertEqual(zeros_expected, zeros_like_dt.full_tensor())

    @with_comms
    def test_zero_inplace(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        zeros_like_dt = torch.zero_(dist_tensor)
        zeros_expected = torch.zeros(4, 8)
        self.assertEqual(zeros_expected, zeros_like_dt.to_local())
        self.assertEqual(zeros_expected, dist_tensor.to_local())

    @with_comms
    def test_zeros_like(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor = torch.randn(4, 8, requires_grad=True)
        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        zeros_like_dt = torch.zeros_like(dist_tensor, dtype=torch.bfloat16)
        zeros_expected = torch.zeros(4, 8, dtype=torch.bfloat16)
        self.assertEqual(zeros_expected, zeros_like_dt.to_local())
        # make sure there is no side effect on the input tensor's dtype
        self.assertEqual(dist_tensor.dtype, torch.float32)
        self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_stack(self):
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        partial_replicate_placement = [Partial(), Replicate()]
        partial_placement = [Partial(), Partial()]

        partial_replicate_dt = DTensor.from_local(
            torch.randn(4, 8), mesh_2d, partial_replicate_placement
        )
        partial_dt = DTensor.from_local(torch.randn(4, 8), mesh_2d, partial_placement)

        stack_dt = torch.stack([partial_replicate_dt, partial_dt])
        self.assertEqual(stack_dt.placements, tuple(partial_placement))
        self.assertEqual(stack_dt.shape, (2, 4, 8))

        mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # stack before/after the shard dim
        global_input = torch.randn(8, 8)
        shard1_input = distribute_tensor(global_input, mesh_1d, [Shard(1)])
        cloned_shard1_input = shard1_input.clone()
        stack_shard1_dt = torch.stack([shard1_input, cloned_shard1_input])
        self.assertEqual(stack_shard1_dt.placements, (Shard(2),))
        self.assertEqual(stack_shard1_dt.shape, (2, 8, 8))
        self.assertEqual(
            stack_shard1_dt.full_tensor(), torch.stack([global_input, global_input])
        )

        stack_dim1_shard1_dt = torch.stack([shard1_input, cloned_shard1_input], dim=1)
        self.assertEqual(stack_dim1_shard1_dt.placements, (Shard(2),))
        self.assertEqual(stack_dim1_shard1_dt.shape, (8, 2, 8))
        self.assertEqual(
            stack_dim1_shard1_dt.full_tensor(),
            torch.stack([global_input, global_input], dim=1),
        )
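
    # Shard-dim bookkeeping for the stacks above: torch.stack inserts a new
    # dimension at `dim`, so inputs sharded on dim 1 come out sharded on
    # dim 2 whenever the inserted dimension lands at or before the original
    # shard dim (true for both dim=0 and dim=1 here).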

    @with_comms
    def test_equal(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        input_tensor_1 = torch.ones(4, 4)
        dist_tensor_1 = DTensor.from_local(input_tensor_1, device_mesh, shard_spec)

        # tensors are equal
        input_tensor_2 = torch.ones(4, 4)
        dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec)

        eq_result = dist_tensor_1.equal(dist_tensor_2)
        self.assertTrue(eq_result)

        # tensors differ on some shards
        if self.rank == 0:
            input_tensor_2 = torch.ones(4, 4)
        else:
            input_tensor_2 = torch.randn(4, 4)
        dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec)

        eq_result = dist_tensor_1.equal(dist_tensor_2)
        # the equal op all-reduces each shard's local result
        self.assertFalse(eq_result)
        self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_2))

        # test when the shardings are different
        replica_spec = [Replicate()]
        global_input = torch.ones(4 * self.world_size, 4)
        dist_tensor_3 = DTensor.from_local(
            global_input, device_mesh, replica_spec, run_check=False
        )

        self.assertTrue(dist_tensor_1.equal(dist_tensor_3))
        self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_3))

        # test different shardings where only some shards' contents differ
        self.assertFalse(dist_tensor_2.equal(dist_tensor_3))
        self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_3))
        self.assertFalse(input_tensor_2.is_same_size(dist_tensor_3))

    def _test_op(self, mesh, op_call, *args, **kwargs):
        out = op_call(*args, **kwargs)
        dtc = DTensorConverter(mesh, args, kwargs)
        for d_args, d_kwargs in dtc:
            self.assertTrue(dtc.successful())
            d_out = op_call(*d_args, **d_kwargs)
            self.assertEqual(d_out.full_tensor(), out)
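
    # _test_op leans on DTensorConverter to enumerate placement combinations
    # for the given args/kwargs, wrapping each combination as DTensor inputs;
    # every combination's result is checked against the plain eager result
    # via full_tensor().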

    @with_comms
    def test_new_full(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        global_tensor = torch.randn(12, 8)
        placements = [[Shard(0)], [Replicate()]]
        for placement in placements:
            input_dt = distribute_tensor(global_tensor, device_mesh, placement)
            with comm_mode:
                new_full_diff_dt = input_dt.new_full((4, 8), 42.0)
            # new_full with a different shape creates a replicated tensor,
            # regardless of input_dt's placement, which should not trigger
            # any communication.
            self.assertEqual(comm_mode.get_total_counts(), 0)
            new_full_diff_expected = torch.full((4, 8), 42.0)
            self.assertTrue(new_full_diff_dt.placements[0].is_replicate())
            self.assertEqual(new_full_diff_expected, new_full_diff_dt.to_local())

            with comm_mode:
                new_full_same_dt = input_dt.new_full((12, 8), 42.0)
            # new_full with the same shape creates a tensor with the same
            # placement as input_dt, which should not trigger any communication.
            self.assertEqual(comm_mode.get_total_counts(), 0)
            new_full_same_expected = torch.full((12, 8), 42.0)
            self.assertEqual(new_full_same_dt.placements, placement)
            self.assertEqual(new_full_same_expected, new_full_same_dt.full_tensor())

    @with_comms
    def test_new_empty_strided(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        shard_dim = 1
        placement = (Shard(shard_dim),)

        # output shape same as input shape, evenly sharded input
        # -> output keeps the same sharding as the input
        global_tensor = torch.randn(12, 8)
        input_dt = distribute_tensor(global_tensor, device_mesh, placement)
        self.assertTrue(input_dt.shape[shard_dim] % self.world_size == 0)
        with comm_mode:
            new_empty_strided_dt = input_dt.new_empty_strided((12, 8), (8, 1))
        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(new_empty_strided_dt.placements, placement)
        self.assertEqual(
            new_empty_strided_dt._local_tensor.size(), (12, 8 // self.world_size)
        )
        self.assertEqual(
            new_empty_strided_dt._local_tensor.stride(), (8 // self.world_size, 1)
        )
        self.assertTrue(new_empty_strided_dt.contiguous() is new_empty_strided_dt)

        # output shape same as input shape, unevenly sharded input -> output replicated
        global_tensor = torch.randn(12, 7)
        input_dt = distribute_tensor(global_tensor, device_mesh, placement)
        self.assertTrue(input_dt.shape[shard_dim] % self.world_size != 0)
        with comm_mode:
            new_empty_strided_dt = input_dt.new_empty_strided((12, 7), (7, 1))
        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(new_empty_strided_dt.placements, (Replicate(),))
        self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 7))
        self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (7, 1))

        # output shape different from input shape -> output replicated
        global_tensor = torch.randn(12, 8)
        input_dt = distribute_tensor(global_tensor, device_mesh, placement)
        with comm_mode:
            new_empty_strided_dt = input_dt.new_empty_strided((12, 4), (4, 1))
        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(new_empty_strided_dt.placements, (Replicate(),))
        self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4))
        self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1))

    @with_comms
    def test_scatter(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        # case 1, all replicate: input replicated, index/src replicated,
        # output replicated
        global_indices = [
            torch.tensor([[0, 1, 2, 0]]),
            torch.tensor([[0, 1, 2], [0, 1, 4]]),
        ]
        for scatter_dim in [0, 1]:
            srcs = [torch.arange(1, 11).reshape((2, 5)), 4]
            for global_src in srcs:
                global_input = torch.zeros(3, 5, dtype=torch.int64)
                global_index = global_indices[scatter_dim]

                input_dt = distribute_tensor(
                    global_input.clone(), device_mesh, [Replicate()]
                )
                index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
                if isinstance(global_src, torch.Tensor):
                    src_dt = distribute_tensor(global_src, device_mesh, [Replicate()])
                else:
                    src_dt = global_src
                global_output = torch.scatter(
                    global_input, scatter_dim, global_index, global_src
                )
                with comm_mode:
                    output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt)

                self.assertEqual(comm_mode.get_total_counts(), 0)
                self.assertEqual(output_dt.placements, [Replicate()])
                self.assertEqual(output_dt.to_local(), global_output)

    @with_comms
    def test_gather(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        # case 1, all replicate: input replicated, index replicated, output replicated
        global_input = torch.randn(12, 8, 16)
        global_index = torch.randint(8, (4, 4, 8))
        input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
        index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
        for gather_dim in [0, 1, 2]:
            global_output = torch.gather(global_input, gather_dim, global_index)
            with comm_mode:
                output_dt = torch.gather(input_dt, gather_dim, index_dt)
            self.assertEqual(comm_mode.get_total_counts(), 0)
            self.assertEqual(output_dt.placements, [Replicate()])
            self.assertEqual(output_dt.to_local(), global_output)

        # case 2, input sharding: input sharded, index replicated, output mask partial;
        # this only works when the index has size 1 on the gather dimension and
        # the input is sharded on the gather dimension
        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        gather_dim = 1
        global_input = torch.randn(12, 8, 16)
        global_index = torch.randint(8, (4, 1, 8))
        global_output = torch.gather(global_input, gather_dim, global_index)
        input_dt = distribute_tensor(global_input, device_mesh, [Shard(gather_dim)])
        index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
        with comm_mode:
            output_dt = torch.gather(input_dt, gather_dim, index_dt)
        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertIsInstance(output_dt.placements[0], _MaskPartial)
        self.assertEqual(output_dt.full_tensor(), global_output)

        # case 3, index sharding: input replicated, index sharded, output sharded;
        # this only works when the sharding dimension is the gather dimension
        global_input = torch.randn(12, 8, 16)
        global_index = torch.randint(8, (4, 4, 8))
        for gather_dim in range(len(global_index.shape)):
            input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
            index_dt = distribute_tensor(global_index, device_mesh, [Shard(gather_dim)])
            global_output = torch.gather(global_input, gather_dim, global_index)
            with comm_mode:
                output_dt = torch.gather(input_dt, gather_dim, index_dt)
            self.assertEqual(comm_mode.get_total_counts(), 0)
            self.assertEqual(output_dt.placements, [Shard(gather_dim)])
            self.assertEqual(output_dt.full_tensor(), global_output)
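
    # _MaskPartial (case 2 above) is an internal placement: roughly, each rank
    # gathers from its local shard while masking out positions whose indices
    # fall outside that shard, leaving per-rank partial results that
    # full_tensor() reduces into the global gather output, with no upfront
    # communication.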

    @with_comms
    def test_index(self):
        meshes = [
            self.build_device_mesh(),  # 1D mesh
            # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh
            # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)),  # 2D mesh
        ]
        for mesh in meshes:
            self._test_op(
                mesh,
                lambda x, y: x[y],
                torch.randn(16, 32, 16),
                torch.randint(5, (4, 8)),
            )
            self._test_op(
                mesh,
                lambda x, y: x.index_select(1, y),
                torch.randn(16, 32, 16),
                torch.randint(5, (4,)),
            )
            self._test_op(
                mesh,
                lambda x, y: x.index_select(0, y),
                torch.randn(16, 32, 16),
                torch.randint(5, (4,)),
            )
            self._test_op(
                mesh,
                lambda x, y: x[y],
                torch.randn(16, 32, 16),
                torch.randint(5, (12,)),
            )
            self._test_op(
                mesh,
                lambda x, y: x[:, y],
                torch.randn(16, 32, 16),
                torch.randint(5, (4, 8)),
            )
            self._test_op(
                mesh,
                lambda x, y: x[..., y],
                torch.randn(16, 32, 16),
                torch.randint(5, (4, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y: x[..., y],
                torch.randn(16, 32, 16),
                torch.randint(5, (4, 8, 16)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[z, y],
                torch.randn(16, 32, 16),
                torch.randint(5, (12, 8, 12)),
                torch.randint(2, (12, 8, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[z, :, y],
                torch.randn(16, 32, 16),
                torch.randint(5, (12, 8, 12)),
                torch.randint(2, (12, 8, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[:, z, :, y],
                torch.randn(16, 32, 16, 12),
                torch.randint(5, (12, 8, 12)),
                torch.randint(2, (12, 8, 12)),
            )
            # broadcast in inner dimensions
            self._test_op(
                mesh,
                lambda x, y, z: x[:, z, :, y],
                torch.randn(16, 32, 16, 12),
                torch.randint(5, (12, 8, 12)),
                torch.randint(2, (12, 1, 12)),
            )
            # implicit (left-padded) broadcast
            self._test_op(
                mesh,
                lambda x, y, z: x[:, z, :, y],
                torch.randn(16, 32, 16, 12),
                torch.randint(5, (12, 8, 12)),
                torch.randint(2, (8, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[z, y, :, :],
                torch.randn(16, 32, 16, 12),
                torch.randint(2, (8, 12)),
                torch.randint(5, (12, 8, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[z, :, y, :],
                torch.randn(16, 32, 16, 12),
                torch.randint(2, (8, 12)),
                torch.randint(5, (12, 8, 12)),
            )
            self._test_op(
                mesh,
                lambda x, y, z: x[z, :, :, y],
                torch.randn(16, 32, 16, 12),
                torch.randint(2, (8, 1)),
                torch.randint(5, (12, 8, 12)),
            )

    @with_comms
    def test_index_put_scalar(self):
        device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        global_input = torch.randn(2, 4, 8, device=self.device_type)
        global_index = [
            torch.randint(global_input.shape[i], size=(), device=self.device_type)
            for i in range(3)
        ]
        global_value = torch.randn(size=(), device=self.device_type)
        value_dt = distribute_tensor(
            global_value, device_mesh, [Replicate(), Replicate()]
        )
        placement_choice_pool = [Shard(0), Shard(1), Replicate()]
        for i in placement_choice_pool:
            for j in placement_choice_pool:
                input_dt = distribute_tensor(global_input, device_mesh, [i, j])
                ref = torch.index_put(global_input, global_index, global_value)
                output_dt = torch.index_put(input_dt, global_index, value_dt)
                assert isinstance(output_dt, DTensor)
                # when the value is a scalar, the output placement must be replicate
                self.assertEqual(output_dt.placements, (Replicate(), Replicate()))
                self.assertEqual(output_dt.full_tensor(), ref)

    @with_comms
    def test_index_put_tensor(self):
        device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        global_input = torch.randn(2, 4, 8, device=self.device_type)
        global_index = [
            torch.randint(global_input.shape[0], size=(), device=self.device_type)
        ]
        global_value = torch.zeros([4, 8], device=self.device_type)
        value_dt = distribute_tensor(global_value, device_mesh, [Shard(1), Replicate()])
        input_dt = distribute_tensor(global_input, device_mesh, [Shard(0), Replicate()])
        for accumulate in [True, False]:
            ref = torch.index_put(global_input, global_index, global_value, accumulate)
            output_dt = torch.index_put(input_dt, global_index, value_dt, accumulate)
            assert isinstance(output_dt, DTensor)
            # the output follows `value_dt`'s Shard(1), shifted by an offset of
            # global_input.ndim - global_value.ndim, which results in Shard(2)
            self.assertEqual(output_dt.placements, (Shard(2), Replicate()))
            self.assertEqual(output_dt.full_tensor(), ref)
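
    # Worked offset for the placement above: global_value is 2-D while
    # global_input is 3-D, so value_dt's Shard(1) maps to
    # Shard(1 + (3 - 2)) == Shard(2) on the 3-D output.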

    @with_comms
    def test_where_type_promotion(self):
        mesh = self.build_device_mesh()  # 1D mesh

        specs = [[Shard(0)], [Replicate()]]
        for spec in specs:
            global_tensor = torch.randn(12, 8)
            mat = distribute_tensor(global_tensor, mesh, spec)
            res = torch.where(mat > 0, 1, 0)
            ref = torch.where(global_tensor > 0, 1, 0)
            self.assertEqual(res.full_tensor(), ref)

    @with_comms
    def test_dtensor_dtype_conversion(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        # by default we start from the bf16 dtype
        local_tensor = torch.randn(2, 8, dtype=torch.bfloat16)
        bf16_sharded_dtensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
        self.assertEqual(bf16_sharded_dtensor.dtype, torch.bfloat16)
        self.assertEqual(bf16_sharded_dtensor.to_local().dtype, torch.bfloat16)

        # convert to float dtype
        fp32_sharded_dtensor = bf16_sharded_dtensor.float()
        self.assertEqual(fp32_sharded_dtensor.dtype, torch.float32)
        self.assertEqual(fp32_sharded_dtensor.to_local().dtype, torch.float32)

        # convert back to bf16 dtype
        bf16_sharded_dtensor1 = fp32_sharded_dtensor.type_as(bf16_sharded_dtensor)
        self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
        self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)

        from torch.distributed.tensor.debug import _get_sharding_prop_cache_info

        # by this point we only have cache misses
        hits, misses, _, _ = _get_sharding_prop_cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 2)

        # convert to fp32 again and check for a cache hit
        bf16_sharded_dtensor1.float()
        hits, misses, _, _ = _get_sharding_prop_cache_info()
        # by now we should have a cache hit
        self.assertEqual(hits, 1)
        self.assertEqual(misses, 2)

    @with_comms
    def test_slice(self):
        mesh = self.build_device_mesh()  # 1D mesh
        comm_mode = CommDebugMode()

        shard_spec = [Shard(1)]
        global_tensor = torch.randn(8, 16, requires_grad=True)
        sharded_dtensor = distribute_tensor(global_tensor, mesh, shard_spec)

        global_out = global_tensor[:, 8:]
        with comm_mode:
            sharded_out = sharded_dtensor[:, 8:]

        self.assertEqual(comm_mode.get_total_counts(), 1)

        global_out.backward(gradient=torch.ones_like(global_out))
        with comm_mode:
            sharded_out_grad = torch.distributed.tensor.ones(
                sharded_out.shape, device_mesh=mesh, placements=shard_spec
            )
            sharded_out.backward(gradient=sharded_out_grad)

        self.assertEqual(comm_mode.get_total_counts(), 1)

        self.assertEqual(sharded_out.full_tensor(), global_out)
        self.assertEqual(sharded_dtensor.grad.full_tensor(), global_tensor.grad)
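
    # The single communication counted in each block above stems from the
    # slice being taken along the sharded dim 1: selecting columns 8: of a
    # dim-1-sharded tensor cannot in general be done purely locally, so one
    # collective is issued (and likewise for the backward pass).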

    @with_comms
    def test_split_on_partial(self):
        self.run_subtests(
            {
                "reduce_op": ["sum", "avg", "product", "min", "max"],
                "split_size": [2, 3, 4],
                "split_dim": [0, 1],
            },
            self._test_split_on_partial,
        )

    def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
        torch.manual_seed(self.rank)
        mesh = self.build_device_mesh()

        partial_tensor = torch.randn(8, 8, device=self.device_type)
        partial_dt = DTensor.from_local(
            local_tensor=partial_tensor,
            device_mesh=mesh,
            placements=[Partial(reduce_op=reduce_op)],
        )
        self._test_op_on_dtensor(
            torch.split,
            partial_dt,
            split_size,
            dim=split_dim,
        )

    @with_comms
    def test_unbind(self):
        device_mesh = self.build_device_mesh()
        shard_dims = [0, 1]
        unbind_dims = [0, 1]
        local_tensor = torch.randn(4, 8, requires_grad=True)
        for shard_dim, unbind_dim in itertools.product(shard_dims, unbind_dims):
            dist_tensor = distribute_tensor(
                local_tensor, device_mesh, (Shard(shard_dim),)
            )

            if shard_dim == unbind_dim:
                with self.assertRaisesRegex(
                    RuntimeError, "Sharding propagation failed"
                ):
                    dist_tensor.unbind(dim=unbind_dim)
            else:
                unbound_dist_tensors = dist_tensor.unbind(dim=unbind_dim)
                new_shard_dim = shard_dim if shard_dim < unbind_dim else shard_dim - 1
                self.assertTrue(
                    all(
                        elem.placements[0].is_shard(dim=new_shard_dim)
                        for elem in unbound_dist_tensors
                    )
                )
                for x, y in zip(
                    unbound_dist_tensors, local_tensor.unbind(dim=unbind_dim)
                ):
                    self.assertEqual(x.full_tensor(), y)
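
    # new_shard_dim above mirrors torch.unbind's dimension removal: unbinding
    # a dim below the shard dim shifts the shard dim down by one, while
    # unbinding along the shard dim itself is rejected by sharding
    # propagation (the RuntimeError branch).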


if __name__ == "__main__":
    run_tests()