mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
reland of https://github.com/pytorch/pytorch/pull/133113 — I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :( ---- Moving DTensor into the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without much content yet (more will be added in the next PRs) * To preserve BC for users still using torch.distributed._tensor, I added a shim script that redirects old-path calls to the new module. BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203 Approved by: https://github.com/tianyu-l
577 lines
19 KiB
Python
577 lines
19 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
# Owner(s): ["oncall: distributed"]
|
|
|
|
import itertools
|
|
from typing import cast, List
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch import rand, randn, Tensor
|
|
from torch.distributed._tensor import (
|
|
DeviceMesh,
|
|
distribute_tensor,
|
|
init_device_mesh,
|
|
Replicate,
|
|
Shard,
|
|
)
|
|
from torch.distributed._tensor.placement_types import Placement
|
|
from torch.distributed.tensor._ops._view_ops import (
|
|
Broadcast,
|
|
dim_maps,
|
|
Flatten,
|
|
InputDim,
|
|
Repeat,
|
|
Singleton,
|
|
Split,
|
|
view_groups,
|
|
)
|
|
from torch.distributed.tensor.debug import CommDebugMode
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
DTensorTestBase,
|
|
with_comms,
|
|
)
|
|
from torch.utils import _pytree as pytree
|
|
|
|
|
|
class TestViewOps(DTensorTestBase):
    """Tests for DTensor sharding propagation through view-like ops.

    ``test_view_groups`` checks the rules built by ``view_groups``.
    ``test_view_ops`` checks the per-op ``dim_maps`` rules. For every
    admissible input sharding, it also checks that the DTensor result matches
    the plain-tensor result without issuing any collective. The remaining
    tests cover ``view_as_complex``/``view_as_real`` and unevenly sharded views.
    """

    @property
    def world_size(self) -> int:
        # Six ranks: the 2-D tests reshape them into a (3, 2) mesh via view(-1, 2).
        return 6

    def _build_2d_mesh(self) -> DeviceMesh:
        """Return a 2-D ``DeviceMesh`` over all ranks, laid out two ranks per row."""
        return DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )

    def test_view_groups(self):
        self.assertEqual(
            view_groups([2, 3], [3, 2]),
            (
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 5], [12, 5]),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [12, 70]),
            (
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    0,
                ),
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    1,
                ),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
            (
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
            (
                Flatten((InputDim(0), InputDim(1))),
                Split(InputDim(2), (4, 2), 0),
                Split(InputDim(2), (4, 2), 1),
                InputDim(3),
            ),
        )
        self.assertEqual(
            view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
            (
                Singleton(),
                InputDim(0),
                Split(InputDim(1), (2, 4, 3), 0),
                Split(InputDim(1), (2, 4, 3), 1),
                Singleton(),
                Split(InputDim(1), (2, 4, 3), 2),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
            (
                Flatten((InputDim(2), InputDim(3))),
                InputDim(4),
                InputDim(5),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
            (
                Split(InputDim(2), (3, 4), 0),
                Split(InputDim(2), (3, 4), 1),
                InputDim(3),
                Flatten((InputDim(6), InputDim(7))),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4], [2, -1, 4]),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
        """Run ``op`` on every admissible sharding of ``args[0]`` and check the result.

        Each candidate placement tuple (one placement per mesh dimension) is
        built from ``Replicate()`` plus ``Shard(i)`` for each input dim ``i``
        that has size > 1 and is not excluded by the op's dim-map rules. For
        each tuple, ``args[0]`` is distributed over ``device_mesh`` and ``op``
        is applied. The test asserts that no communication happened, and on
        rank 0 that the gathered output equals ``op(*args, **kwargs)``.

        Args:
            op: The op under test; must be a key of ``dim_maps``.
            args: Positional args for ``op``; ``args[0]`` is the tensor to shard.
            kwargs: Keyword args for ``op``.
            device_mesh: Mesh to distribute ``args[0]`` over.
        """
        rules = dim_maps[op](*args, **kwargs)
        outputs = op(*args, **kwargs)
        in_shape = pytree.arg_tree_leaves(*args)[0].shape

        # Collect input dims that must not be offered as Shard candidates:
        # - a Repeat'ed input dim,
        # - every flattened input dim except the first (outermost) one,
        #   whether the Flatten stands alone or is wrapped in a Split.
        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        # unbind removes its target dim, so that dim is never a shard candidate.
        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

        sharding_choices = cast(List[Placement], [Replicate()]) + [
            Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
        ]

        # One placement per mesh dimension: the Cartesian power of the choices.
        all_sharding_choices = itertools.product(
            sharding_choices, repeat=device_mesh.ndim
        )

        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            with CommDebugMode() as comm_mode:
                out_dt = op(in_dt, *args[1:], **kwargs)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )

            # full_tensor() is called on every rank (outside the rank check).
            # Only rank 0 compares against its locally computed reference.
            full_out = out_dt.full_tensor()

            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)

    def dimmap_test(self, op, args, expected_rule_output):
        """Check ``dim_maps[op](*args)`` against ``expected_rule_output``.

        Then run the sharded end-to-end check on ``self.device_mesh``, which
        the calling test must have set beforehand.
        """
        rules = dim_maps[op](*args)
        self.assertEqual(rules, expected_rule_output)
        self.call_dt_test(op, args, {}, self.device_mesh)

    @with_comms
    def test_view_ops(self):
        self.device_mesh = self._build_2d_mesh()
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)))

        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)))
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        # Incompatible broadcast target shape: the rule builder must reject it.
        with self.assertRaises(AssertionError):
            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        # Tensor.expand with the target sizes passed as varargs ...
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        # ... and as a single tuple.
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )

        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(3), InputDim(4))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )

    # TODO: Functional collectives on complex numbers are not fully supported
    # yet, so view_as_complex and view_as_real are covered together by the
    # standalone test below. Once complex numbers are supported, the following
    # can be added to the dim_map test above.
    #
    # self.dimmap_test(
    #     torch.view_as_complex,
    #     (randn(24, 13, 2),),
    #     (
    #         InputDim(0),
    #         Flatten((InputDim(1), InputDim(2))),
    #     ),
    # )
    # self.dimmap_test(
    #     torch.view_as_real,
    #     (torch.randn(24, 13, dtype=torch.cfloat),),
    #     (
    #         InputDim(0),
    #         Split(InputDim(1), (13, 2), 0),
    #         Split(InputDim(1), (13, 2), 1),
    #     ),
    # )
    @with_comms
    def test_complex_view_ops(self):
        self.device_mesh = self._build_2d_mesh()
        inp = randn(24, 13, 2)
        intermediate = torch.view_as_complex(inp)
        out = torch.view_as_real(intermediate)

        # test dim_map correctness
        expected_view_as_complex_rule = (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
        # NOTE: For the input to torch.view_as_complex, sharding
        # on the last two dimensions is not supported.
        sharding_choices: List[Placement] = [Replicate(), Shard(0)]
        all_sharding_choices = itertools.product(
            sharding_choices, repeat=self.device_mesh.ndim
        )

        for inp_shard in all_sharding_choices:
            inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard)

            with CommDebugMode() as comm_mode:
                intermediate_dt = torch.view_as_complex(inp_dt)
                out_dt = torch.view_as_real(intermediate_dt)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )
            self.assertEqual(out, out_dt.full_tensor())

    @with_comms
    def test_dtensor_view_op_uneven(self):
        """
        Test two uneven cases for the view op on a tensor sharded along dim 0:
            1) dim 0 has size 1, so only the first rank holds a non-empty shard.
            2) dim 0 has size world_size + 1, so some ranks hold full shards,
               some hold smaller non-empty shards, and some hold empty shards.
        """
        # The 1-D mesh does not depend on the loop variable; build it once.
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        dim0_sizes = [1, self.world_size + 1]
        for dim0_size in dim0_sizes:
            p = torch.randn(dim0_size, 2, 2, 2)
            dtensor = distribute_tensor(p, mesh, [Shard(0)])

            # Every target keeps the sharded dim 0 unchanged and only reshapes
            # the trailing 2 x 2 x 2 block (the last one is the identity view).
            target_shapes = [
                (dim0_size, 2, 4),
                (dim0_size, 4, 2),
                (dim0_size, 8),
                dtensor.shape,
            ]
            with CommDebugMode() as comm_mode:
                for shape in target_shapes:
                    view = dtensor.view(shape)
                    self.assertEqual(len(comm_mode.get_comm_counts()), 0)
                    # when no communication happens, the data pointer should be the same.
                    self.assertEqual(
                        view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to the shared
    # PyTorch test runner imported from torch.testing._internal.common_utils.
    run_tests()
|