pytorch/test/distributed/tensor/test_tensor_ops.py
Tianyu Liu 435c18fb4a [DTensor] add op support for aten.unbind.int (#162560)
As titled.

It seems that `unbind` returns views of the original tensor; see, e.g., https://stackoverflow.com/questions/78910951/does-unbind-return-the-views-of-tensors-in-pytorch

So we error out when `shard_dim == unbind_dim`, similar to how view ops error out:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L544-L546
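
As a rough illustration of the resulting behavior (a sketch only, not part of this PR's code; it assumes a 1D device mesh can be created with `init_device_mesh` and mirrors the `test_unbind` case added below):

```python
import torch
from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard

def demo_unbind(world_size: int):
    # assumes torch.distributed is already initialized with `world_size` ranks
    mesh = init_device_mesh("cpu", (world_size,))
    dt = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)])

    # unbinding along a non-sharded dim is fine; each piece stays Shard(0)
    pieces = dt.unbind(dim=1)
    assert all(p.placements[0].is_shard(dim=0) for p in pieces)

    # unbinding along the sharded dim would return views spanning the shard
    # boundary, so sharding propagation rejects it
    try:
        dt.unbind(dim=0)
    except RuntimeError as e:
        assert "Sharding propagation failed" in str(e)
```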

This PR also refactors some other tensor ops code by introducing two utility functions, `shift_shard_dims_after_insert` and `shift_shard_dims_after_remove`.
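
For context, here is a hypothetical sketch of what the remove-side helper might compute (the names are taken from this description; the actual implementations live under `torch/distributed/tensor/_ops` and their signatures may differ):

```python
from torch.distributed.tensor import Shard

def shift_shard_dims_after_remove(placements, removed_dim):
    # Shard placements on dims after a removed tensor dim shift down by one;
    # other placements (Replicate/Partial) are unaffected.
    return tuple(
        Shard(p.dim - 1) if isinstance(p, Shard) and p.dim > removed_dim else p
        for p in placements
    )

# e.g. unbinding dim 0 of a tensor sharded on dim 1: Shard(1) -> Shard(0)
assert shift_shard_dims_after_remove((Shard(1),), removed_dim=0) == (Shard(0),)
```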

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162560
Approved by: https://github.com/zpcore
2025-09-11 00:58:23 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
import torch
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorConverter,
DTensorTestBase,
with_comms,
)
class DistTensorOpsTest(DTensorTestBase):
@with_comms
def test_aten_contiguous(self):
# this op not covered by dtensor_ops
mesh = self.build_device_mesh()
self._test_op(
mesh,
lambda x: torch.ops.aten.contiguous(x),
torch.randn(16, 32),
)
@with_comms
def test_detach(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
tensor_to_detach = torch.randn(12, 8, requires_grad=True)
mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec)
detached_mat = mat.detach()
self.assertFalse(detached_mat is mat)
@with_comms
def test_clone(self):
device_mesh = self.build_device_mesh()
specs = [[Replicate()], [Shard(0)]]
tensor_to_clone = torch.randn(12, 8, requires_grad=True)
for spec in specs:
mat = distribute_tensor(tensor_to_clone, device_mesh, spec)
cloned_mat = mat.clone()
self.assertFalse(cloned_mat is mat)
self.assertEqual(cloned_mat.to_local(), mat.to_local())
@with_comms
def test_copy_(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
# basic test
src_tensor = torch.randn((12, 12))
dst_tensor = torch.zeros(12, 12)
src_specs = [[Replicate()], [Shard(0)]]
dst_specs = [[Replicate()], [Shard(0)]]
for dst_spec, src_spec in zip(dst_specs, src_specs):
src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
# simple broadcasting
src_tensor = torch.randn((128,))
dst_tensor = torch.zeros(128, 128)
src_specs = [[Replicate()], [Shard(0)]]
dst_specs = [[Replicate()], [Shard(1)]]
for dst_spec, src_spec in zip(dst_specs, src_specs):
src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
# The src specs here are deliberately incompatible with the dst specs, so a redistribute should happen
src_tensor = torch.randn((64, 1))
dst_tensor = torch.zeros(16, 32, 64, 128)
src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]]
dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]]
for dst_spec, src_spec in zip(dst_specs, src_specs):
src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec)
dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec)
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.full_tensor(), dst_tensor)
# copy_ is a pointwise op, so Partial placements should be kept without redistributing
src_tensor = torch.randn((64, 1))
dst_tensor = torch.zeros(16, 32, 64, 128)
src_specs = [[Partial()]]
dst_specs = [[Partial()]]
for dst_spec, src_spec in zip(dst_specs, src_specs):
src_dtensor = DTensor.from_local(src_tensor, device_mesh, src_spec)
dst_dtensor = DTensor.from_local(dst_tensor, device_mesh, dst_spec)
dst_dtensor.copy_(src_dtensor)
dst_tensor.copy_(src_tensor)
self.assertEqual(dst_dtensor.placements, (Partial(),))
self.assertEqual(dst_dtensor._local_tensor, dst_tensor)
@with_comms
def test_contiguous(self):
device_mesh = self.build_device_mesh()
tensor = torch.rand(3, 5, 6, requires_grad=True)
sharding = [Shard(0)]
dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
self.assertTrue(dist_tensor.is_contiguous())
# shard on dim 0 should not change stride (30, 6, 1)
self.assertEqual(dist_tensor.stride(), tensor.stride())
new_dt = dist_tensor.transpose(0, 2)
self.assertFalse(new_dt.is_contiguous())
self.assertFalse(new_dt.to_local().is_contiguous())
# check stride
self.assertEqual(new_dt.stride(), (1, 6, 30))
new_dt = new_dt.contiguous()
self.assertTrue(new_dt.is_contiguous())
self.assertTrue(new_dt.to_local().is_contiguous())
# check stride
self.assertEqual(dist_tensor.stride(), tensor.stride())
# check backward
new_dt.to_local().sum().backward()
self.assertEqual(tensor.grad, torch.ones(3, 5, 6))
@with_comms
def test_inplace_op(self):
mesh = self.build_device_mesh()
input_tensor = torch.randn((12, 3), device=self.device_type)
dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)])
dt_to_mul = dt_to_add.clone()
expected_add_dt = dt_to_add.clone() + 3
add_res = dt_to_add.add_(3)
expected_mul_dt = dt_to_mul.clone() * 3
mul_res = dt_to_mul.mul_(3)
# inplace op should be the same instance before and after
self.assertTrue(add_res is dt_to_add)
self.assertEqual(add_res.to_local(), expected_add_dt.to_local())
self.assertTrue(mul_res is dt_to_mul)
self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local())
# test inplace op self and other dtensor with other specs
# and make sure out spec not change
shard_spec = [Shard(0)]
partial_spec = [Partial()]
dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec)
partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec)
res = dt_to_inplace_add.add_(partial_grad)
self.assertTrue(res is dt_to_inplace_add)
self.assertTrue(res.placements == tuple(shard_spec))
@with_comms
def test_op_out_variant(self):
mesh = self.build_device_mesh()
input_tensor = torch.randn((12, 3), device=self.device_type)
sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)])
expected_dt = sharded_dt_input.clone() + 3
sharded_dt_out = sharded_dt_input.clone()
res = torch.add(sharded_dt_input, 3, out=sharded_dt_out)
# the out= variant should return the same instance that was passed as `out`
self.assertTrue(res is sharded_dt_out)
self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local())
# test op out variant with other spec and make sure out spec not change
replica_spec = [Replicate()]
replicate_out = distribute_tensor(input_tensor, mesh, replica_spec)
expected_dt = replicate_out.clone() + 3
res = torch.add(sharded_dt_input, 3, out=replicate_out)
self.assertTrue(res is replicate_out)
self.assertTrue(res.placements == tuple(replica_spec))
self.assertEqual(replicate_out.to_local(), expected_dt.to_local())
@with_comms
def test_empty_like(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
empty_like_dt = torch.empty_like(dist_tensor)
# empty is not deterministic, so we only check that the shard propagation worked
self.assertEqual((4, 8), empty_like_dt.to_local().shape)
@with_comms
def test_fill_inplace(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
full_like_dt = torch.fill_(dist_tensor, 42.0)
full_expected = torch.full((4, 8), 42.0)
self.assertEqual(full_expected, full_like_dt.to_local())
self.assertEqual(full_expected, dist_tensor.to_local())
@with_comms
def test_full_like(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
full_like_dt = torch.full_like(dist_tensor, 42.0)
full_expected = torch.full((4, 8), 42.0)
self.assertEqual(full_expected, full_like_dt.to_local())
@with_comms
def test_ones_like(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
ones_like_dt = torch.ones_like(dist_tensor)
ones_expected = torch.ones(4, 8)
self.assertEqual(ones_expected, ones_like_dt.to_local())
@with_comms
def test_ones_like_partial_sum(self):
device_mesh = self.build_device_mesh()
shard_spec = [Partial()]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
assert dist_tensor.shape == (4, 8)
ones_like_dt = torch.ones_like(dist_tensor)
ones_expected = torch.ones(dist_tensor.shape)
self.assertEqual(ones_expected, ones_like_dt.full_tensor())
@with_comms
def test_fill_inplace_partial_sum(self):
device_mesh = self.build_device_mesh()
shard_spec = [Partial()]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
assert dist_tensor.shape == (4, 8)
# inplace partial sum should keep partial
torch.fill_(dist_tensor, 8)
fill_expected = torch.full(
dist_tensor.shape, 8 * self.world_size, dtype=input_tensor.dtype
)
self.assertEqual(fill_expected, dist_tensor.full_tensor())
@with_comms
def test_zeros_like_partial_sum(self):
device_mesh = self.build_device_mesh()
shard_spec = [Partial()]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
assert dist_tensor.shape == (4, 8)
zeros_like_dt = torch.zeros_like(dist_tensor)
zeros_expected = torch.zeros(dist_tensor.shape)
self.assertEqual(zeros_expected, zeros_like_dt.full_tensor())
@with_comms
def test_zero_inplace(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
zeros_like_dt = torch.zero_(dist_tensor)
zeros_expected = torch.zeros(4, 8)
self.assertEqual(zeros_expected, zeros_like_dt.to_local())
self.assertEqual(zeros_expected, dist_tensor.to_local())
@with_comms
def test_zeros_like(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor = torch.randn(4, 8, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
zeros_like_dt = torch.zeros_like(dist_tensor, dtype=torch.bfloat16)
zeros_expected = torch.zeros(4, 8, dtype=torch.bfloat16)
self.assertEqual(zeros_expected, zeros_like_dt.to_local())
# make sure there is no side effect on the input tensor dtype
self.assertEqual(dist_tensor.dtype, torch.float32)
self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
@with_comms
@skip_if_lt_x_gpu(4)
def test_stack(self):
mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 2)
)
partial_replicate_placement = [Partial(), Replicate()]
partial_placement = [Partial(), Partial()]
partial_replicate_dt = DTensor.from_local(
torch.randn(4, 8), mesh_2d, partial_replicate_placement
)
partial_dt = DTensor.from_local(torch.randn(4, 8), mesh_2d, partial_placement)
stack_dt = torch.stack([partial_replicate_dt, partial_dt])
self.assertEqual(stack_dt.placements, tuple(partial_placement))
self.assertEqual(stack_dt.shape, (2, 4, 8))
mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
# stack before/after shard dim
global_input = torch.randn(8, 8)
shard1_input = distribute_tensor(global_input, mesh_1d, [Shard(1)])
cloned_shard1_input = shard1_input.clone()
stack_shard1_dt = torch.stack([shard1_input, cloned_shard1_input])
self.assertEqual(stack_shard1_dt.placements, (Shard(2),))
self.assertEqual(stack_shard1_dt.shape, (2, 8, 8))
self.assertEqual(
stack_shard1_dt.full_tensor(), torch.stack([global_input, global_input])
)
stack_dim1_shard1_dt = torch.stack([shard1_input, cloned_shard1_input], dim=1)
self.assertEqual(stack_dim1_shard1_dt.placements, (Shard(2),))
self.assertEqual(stack_dim1_shard1_dt.shape, (8, 2, 8))
self.assertEqual(
stack_dim1_shard1_dt.full_tensor(),
torch.stack([global_input, global_input], dim=1),
)
@with_comms
def test_equal(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_tensor_1 = torch.ones(4, 4)
dist_tensor_1 = DTensor.from_local(input_tensor_1, device_mesh, shard_spec)
# tensors are equal
input_tensor_2 = torch.ones(4, 4)
dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec)
eq_result = dist_tensor_1.equal(dist_tensor_2)
self.assertTrue(eq_result)
# tensors are different on some shards
if self.rank == 0:
input_tensor_2 = torch.ones(4, 4)
else:
input_tensor_2 = torch.randn(4, 4)
dist_tensor_2 = DTensor.from_local(input_tensor_2, device_mesh, shard_spec)
eq_result = dist_tensor_1.equal(dist_tensor_2)
# the equal op all-reduces each shard's local comparison result
self.assertFalse(eq_result)
self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_2))
# test if sharding are different
replica_spec = [Replicate()]
global_input = torch.ones(4 * self.world_size, 4)
dist_tensor_3 = DTensor.from_local(
global_input, device_mesh, replica_spec, run_check=False
)
self.assertTrue(dist_tensor_1.equal(dist_tensor_3))
self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_3))
# test sharding difference with only some shards content difference
self.assertFalse(dist_tensor_2.equal(dist_tensor_3))
self.assertTrue(dist_tensor_1.is_same_size(dist_tensor_3))
self.assertFalse(input_tensor_2.is_same_size(dist_tensor_3))
def _test_op(self, mesh, op_call, *args, **kwargs):
out = op_call(*args, **kwargs)
dtc = DTensorConverter(mesh, args, kwargs)
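# DTensorConverter iterates over placement combinations, yielding args/kwargs with tensor inputs converted to DTensors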
for d_args, d_kwargs in dtc:
self.assertTrue(dtc.successful())
d_out = op_call(*d_args, **d_kwargs)
self.assertEqual(d_out.full_tensor(), out)
@with_comms
def test_new_full(self):
device_mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
global_tensor = torch.randn(12, 8)
placements = [[Shard(0)], [Replicate()]]
for placement in placements:
input_dt = distribute_tensor(global_tensor, device_mesh, placement)
with comm_mode:
new_full_diff_dt = input_dt.new_full((4, 8), 42.0)
# new_full_diff_dt creates a replicated tensor, regardless of input_dt placement,
# which should not trigger any communication.
self.assertEqual(comm_mode.get_total_counts(), 0)
new_full_diff_expected = torch.full((4, 8), 42.0)
self.assertTrue(new_full_diff_dt.placements[0].is_replicate())
self.assertEqual(new_full_diff_expected, new_full_diff_dt.to_local())
with comm_mode:
new_full_same_dt = input_dt.new_full((12, 8), 42.0)
# new_full_same_dt creates a tensor with the same placement as input_dt,
# which should not trigger any communication.
self.assertEqual(comm_mode.get_total_counts(), 0)
new_full_same_expected = torch.full((12, 8), 42.0)
self.assertEqual(new_full_same_dt.placements, placement)
self.assertEqual(new_full_same_expected, new_full_same_dt.full_tensor())
@with_comms
def test_new_empty_strided(self):
device_mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
shard_dim = 1
placement = (Shard(shard_dim),)
# output shape same as input shape, evenly sharded input -> output same sharding as input
global_tensor = torch.randn(12, 8)
input_dt = distribute_tensor(global_tensor, device_mesh, placement)
self.assertTrue(input_dt.shape[shard_dim] % self.world_size == 0)
with comm_mode:
new_empty_strided_dt = input_dt.new_empty_strided((12, 8), (8, 1))
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(new_empty_strided_dt.placements, placement)
self.assertEqual(
new_empty_strided_dt._local_tensor.size(), (12, 8 // self.world_size)
)
self.assertEqual(
new_empty_strided_dt._local_tensor.stride(), (8 // self.world_size, 1)
)
self.assertTrue(new_empty_strided_dt.contiguous() is new_empty_strided_dt)
# output shape same as input shape, unevenly sharded input -> output replicated
global_tensor = torch.randn(12, 7)
input_dt = distribute_tensor(global_tensor, device_mesh, placement)
self.assertTrue(input_dt.shape[shard_dim] % self.world_size != 0)
with comm_mode:
new_empty_strided_dt = input_dt.new_empty_strided((12, 7), (7, 1))
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(new_empty_strided_dt.placements, (Replicate(),))
self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 7))
self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (7, 1))
# output shape different from input shape -> output replicated
global_tensor = torch.randn(12, 8)
input_dt = distribute_tensor(global_tensor, device_mesh, placement)
with comm_mode:
new_empty_strided_dt = input_dt.new_empty_strided((12, 4), (4, 1))
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(new_empty_strided_dt.placements, (Replicate(),))
self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4))
self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1))
@with_comms
def test_scatter(self):
device_mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
# case 1 all replicate: input replicated, index/src replicated, output replicated
global_indexs = [
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[0, 1, 2], [0, 1, 4]]),
]
for scatter_dim in [0, 1]:
srcs = [torch.arange(1, 11).reshape((2, 5)), 4]
for global_src in srcs:
global_input = torch.zeros(3, 5, dtype=torch.int64)
global_index = global_indexs[scatter_dim]
input_dt = distribute_tensor(
global_input.clone(), device_mesh, [Replicate()]
)
index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
if isinstance(global_src, torch.Tensor):
src_dt = distribute_tensor(global_src, device_mesh, [Replicate()])
else:
src_dt = global_src
global_output = torch.scatter(
global_input, scatter_dim, global_index, global_src
)
with comm_mode:
output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(output_dt.placements, [Replicate()])
self.assertEqual(output_dt.to_local(), global_output)
@with_comms
def test_gather(self):
device_mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
# case 1 all replicate: input replicated, index replicated, output replicated
global_input = torch.randn(12, 8, 16)
global_index = torch.randint(8, (4, 4, 8))
input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
for gather_dim in [0, 1, 2]:
global_output = torch.gather(global_input, gather_dim, global_index)
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(output_dt.placements, [Replicate()])
self.assertEqual(output_dt.to_local(), global_output)
# case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
gather_dim = 1
global_input = torch.randn(12, 8, 16)
global_index = torch.randint(8, (4, 1, 8))
global_output = torch.gather(global_input, gather_dim, global_index)
input_dt = distribute_tensor(global_input, device_mesh, [Shard(gather_dim)])
index_dt = distribute_tensor(global_index, device_mesh, [Replicate()])
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertIsInstance(output_dt.placements[0], _MaskPartial)
self.assertEqual(output_dt.full_tensor(), global_output)
# case 3 index sharding: input replicated, index sharded, output sharded
# only works when the sharding dimension is the gather dimension
global_input = torch.randn(12, 8, 16)
global_index = torch.randint(8, (4, 4, 8))
for gather_dim in range(len(global_index.shape)):
input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
index_dt = distribute_tensor(global_index, device_mesh, [Shard(gather_dim)])
global_output = torch.gather(global_input, gather_dim, global_index)
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertEqual(output_dt.placements, [Shard(gather_dim)])
self.assertEqual(output_dt.full_tensor(), global_output)
@skipIfRocm
@with_comms
def test_index(self):
meshes = [
self.build_device_mesh(), # 1D mesh
# TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh
# DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh
]
for mesh in meshes:
self._test_op(
mesh,
lambda x, y: x[y],
torch.randn(16, 32, 16),
torch.randint(5, (4, 8)),
)
self._test_op(
mesh,
lambda x, y: x.index_select(1, y),
torch.randn(16, 32, 16),
torch.randint(5, (4,)),
)
self._test_op(
mesh,
lambda x, y: x.index_select(0, y),
torch.randn(16, 32, 16),
torch.randint(5, (4,)),
)
self._test_op(
mesh,
lambda x, y: x[y],
torch.randn(16, 32, 16),
torch.randint(5, (12,)),
)
self._test_op(
mesh,
lambda x, y: x[:, y],
torch.randn(16, 32, 16),
torch.randint(5, (4, 8)),
)
self._test_op(
mesh,
lambda x, y: x[..., y],
torch.randn(16, 32, 16),
torch.randint(5, (4, 12)),
)
self._test_op(
mesh,
lambda x, y: x[..., y],
torch.randn(16, 32, 16),
torch.randint(5, (4, 8, 16)),
)
self._test_op(
mesh,
lambda x, y, z: x[z, y],
torch.randn(16, 32, 16),
torch.randint(5, (12, 8, 12)),
torch.randint(2, (12, 8, 12)),
)
self._test_op(
mesh,
lambda x, y, z: x[z, :, y],
torch.randn(16, 32, 16),
torch.randint(5, (12, 8, 12)),
torch.randint(2, (12, 8, 12)),
)
self._test_op(
mesh,
lambda x, y, z: x[:, z, :, y],
torch.randn(16, 32, 16, 12),
torch.randint(5, (12, 8, 12)),
torch.randint(2, (12, 8, 12)),
)
# broadcast in inner dimensions
self._test_op(
mesh,
lambda x, y, z: x[:, z, :, y],
torch.randn(16, 32, 16, 12),
torch.randint(5, (12, 8, 12)),
torch.randint(2, (12, 1, 12)),
)
# implicit (left-padded) broadcast
self._test_op(
mesh,
lambda x, y, z: x[:, z, :, y],
torch.randn(16, 32, 16, 12),
torch.randint(5, (12, 8, 12)),
torch.randint(2, (8, 12)),
)
self._test_op(
mesh,
lambda x, y, z: x[z, y, :, :],
torch.randn(16, 32, 16, 12),
torch.randint(2, (8, 12)),
torch.randint(5, (12, 8, 12)),
)
self._test_op(
mesh,
lambda x, y, z: x[z, :, y, :],
torch.randn(16, 32, 16, 12),
torch.randint(2, (8, 12)),
torch.randint(5, (12, 8, 12)),
)
self._test_op(
mesh,
lambda x, y, z: x[z, :, :, y],
torch.randn(16, 32, 16, 12),
torch.randint(2, (8, 1)),
torch.randint(5, (12, 8, 12)),
)
@with_comms
def test_index_put_scalar(self):
device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
global_input = torch.randn(2, 4, 8, device=self.device_type)
global_index = [
torch.randint(global_input.shape[i], size=(), device=self.device_type)
for i in range(3)
]
global_value = torch.randn(size=(), device=self.device_type)
value_dt = distribute_tensor(
global_value, device_mesh, [Replicate(), Replicate()]
)
placement_choice_pool = [Shard(0), Shard(1), Replicate()]
for i in placement_choice_pool:
for j in placement_choice_pool:
input_dt = distribute_tensor(global_input, device_mesh, [i, j])
ref = torch.index_put(global_input, global_index, global_value)
output_dt = torch.index_put(input_dt, global_index, value_dt)
assert isinstance(output_dt, DTensor)
# when value is a scalar, the output placement must be replicate
self.assertEqual(output_dt.placements, (Replicate(), Replicate()))
self.assertEqual(output_dt.full_tensor(), ref)
@with_comms
def test_index_put_tensor(self):
device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
global_input = torch.randn(2, 4, 8, device=self.device_type)
global_index = [
torch.randint(global_input.shape[0], size=(), device=self.device_type)
]
global_value = torch.zeros([4, 8], device=self.device_type)
value_dt = distribute_tensor(global_value, device_mesh, [Shard(1), Replicate()])
input_dt = distribute_tensor(global_input, device_mesh, [Shard(0), Replicate()])
for accumulate in [True, False]:
ref = torch.index_put(global_input, global_index, global_value, accumulate)
output_dt = torch.index_put(input_dt, global_index, value_dt, accumulate)
assert isinstance(output_dt, DTensor)
# the output follows `value_dt`'s Shard(1), shifted by an offset of
# global_input.ndim - global_value.ndim, which results in Shard(2)
self.assertEqual(output_dt.placements, (Shard(2), Replicate()))
self.assertEqual(output_dt.full_tensor(), ref)
@with_comms
def test_where_type_promotion(self):
mesh = self.build_device_mesh() # 1D mesh
specs = [[Shard(0)], [Replicate()]]
for spec in specs:
global_tensor = torch.randn(12, 8)
mat = distribute_tensor(global_tensor, mesh, spec)
res = torch.where(mat > 0, 1, 0)
ref = torch.where(global_tensor > 0, 1, 0)
self.assertEqual(res.full_tensor(), ref)
@with_comms
def test_dtensor_dtype_conversion(self):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
# by default we start from bf16 dtype
local_tensor = torch.randn(2, 8, dtype=torch.bfloat16)
bf16_sharded_dtensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
self.assertEqual(bf16_sharded_dtensor.dtype, torch.bfloat16)
self.assertEqual(bf16_sharded_dtensor.to_local().dtype, torch.bfloat16)
# convert to float dtype
fp32_sharded_dtensor = bf16_sharded_dtensor.float()
self.assertEqual(fp32_sharded_dtensor.dtype, torch.float32)
self.assertEqual(fp32_sharded_dtensor.to_local().dtype, torch.float32)
# convert back to bf16 dtype
bf16_sharded_dtensor1 = fp32_sharded_dtensor.type_as(bf16_sharded_dtensor)
self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
# by this point we only have cache misses
hits, misses, _, _ = _get_sharding_prop_cache_info()
self.assertEqual(hits, 0)
self.assertEqual(misses, 2)
# convert to fp32 again and see if there's cache hit
bf16_sharded_dtensor1.float()
hits, misses, _, _ = _get_sharding_prop_cache_info()
# by now we should have cache hit
self.assertEqual(hits, 1)
self.assertEqual(misses, 2)
@with_comms
def test_slice(self):
mesh = self.build_device_mesh() # 1D mesh
comm_mode = CommDebugMode()
shard_spec = [Shard(1)]
global_tensor = torch.randn(8, 16, requires_grad=True)
sharded_dtensor = distribute_tensor(global_tensor, mesh, shard_spec)
global_out = global_tensor[:, 8:]
with comm_mode:
sharded_out = sharded_dtensor[:, 8:]
self.assertEqual(comm_mode.get_total_counts(), 1)
global_out.backward(gradient=torch.ones_like(global_out))
with comm_mode:
sharded_out_grad = torch.distributed.tensor.ones(
sharded_out.shape, device_mesh=mesh, placements=shard_spec
)
sharded_out.backward(gradient=sharded_out_grad)
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(sharded_out.full_tensor(), global_out)
self.assertEqual(sharded_dtensor.grad.full_tensor(), global_tensor.grad)
@with_comms
def test_split_on_partial(self):
self.run_subtests(
{
"reduce_op": ["sum", "avg", "product", "min", "max"],
"split_size": [2, 3, 4],
"split_dim": [0, 1],
},
self._test_split_on_partial,
)
def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
torch.manual_seed(self.rank)
mesh = self.build_device_mesh()
partial_tensor = torch.randn(8, 8, device=self.device_type)
partial_dt = DTensor.from_local(
local_tensor=partial_tensor,
device_mesh=mesh,
placements=[Partial(reduce_op=reduce_op)],
)
self._test_op_on_dtensor(
torch.split,
partial_dt,
split_size,
dim=split_dim,
)
@with_comms
def test_unbind(self):
device_mesh = self.build_device_mesh()
shard_dims = [0, 1]
unbind_dims = [0, 1]
local_tensor = torch.randn(4, 8, requires_grad=True)
for shard_dim, unbind_dim in itertools.product(shard_dims, unbind_dims):
dist_tensor = distribute_tensor(
local_tensor, device_mesh, (Shard(shard_dim),)
)
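# unbind returns views of the input; a view along the sharded dim cannot be represented, so sharding propagation rejects it (same reason view ops do)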
if shard_dim == unbind_dim:
with self.assertRaisesRegex(
RuntimeError, "Sharding propagation failed"
):
dist_tensor.unbind(dim=unbind_dim)
else:
unbinded_dist_tensors = dist_tensor.unbind(dim=unbind_dim)
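# removing the unbind dim shifts shard dims greater than it down by one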
new_shard_dim = shard_dim if shard_dim < unbind_dim else shard_dim - 1
self.assertTrue(
all(
elem.placements[0].is_shard(dim=new_shard_dim)
for elem in unbinded_dist_tensors
)
)
for x, y in zip(
unbinded_dist_tensors, local_tensor.unbind(dim=unbind_dim)
):
self.assertEqual(x.full_tensor(), y)
if __name__ == "__main__":
run_tests()