Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-19 10:04:58 +08:00
Update on "[3/N] [DTensor device order] Make some placement type class method static"
Some methods of the `Placement` classes can be exposed as static methods. These methods are useful even without constructing a placement object, e.g. when we `distribute_tensor` from a regular tensor we may want:

```
local_tensor = Shard.shard_tensor(tensor_dim, local_tensor, device_mesh, mesh_dim)
```

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
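For illustration only (not part of the diff below): a minimal sketch of what the static helper enables, assuming a 4-rank CPU job launched with torchrun and the private helper with the signature shown in this commit, `Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim, src_data_rank)`. The exact name and defaults may still change in later PRs of this stack.

```python
# Hypothetical usage sketch -- assumes `torchrun --nproc-per-node=4 this_file.py`
# so that a default process group exists and a 1-D CPU mesh can be built.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

mesh = init_device_mesh("cpu", (4,))  # 1-D mesh over the 4 ranks
full = torch.arange(16, dtype=torch.float32).reshape(8, 2)  # same tensor on every rank

# Static call: no Shard(0) placement instance is needed; rank 0 is the data source.
local_shard = Shard._make_shard_tensor(0, full, mesh, 0, 0)
print(local_shard.shape)  # torch.Size([2, 2]) on every rank
```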
@@ -20,7 +20,11 @@ from torch.distributed.tensor import (
     Shard,
 )
 from torch.distributed.tensor._api import _shard_tensor
-from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._dtensor_spec import (
+    DTensorSpec,
+    ShardOrderEntry,
+    TensorMeta,
+)
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.experimental import implicit_replication
 from torch.distributed.tensor.parallel import (
@@ -1065,35 +1069,23 @@ class TestDTensorSpec(DTensorTestBase):
     def test_dtensor_spec_print(self):
         self.assertExpectedInline(
             DTensorSpec.format_shard_order_str((Shard(2), Shard(1), Shard(0)), None),
-            """S(0)[2]S(1)[1]S(2)[0]""",
+            """S(2)S(1)S(0)""",
         )
         self.assertExpectedInline(
             DTensorSpec.format_shard_order_str(
-                (Shard(2), Shard(1), Shard(0)), ((0, 2), (1, 1), (2, 0))
-            ),
-            """S(0)[2]S(1)[1]S(2)[0]""",
-        )
-        self.assertExpectedInline(
-            DTensorSpec.format_shard_order_str(
-                (Shard(1), Shard(1), Shard(1)), ((1, 2, 0, 1),)
-            ),
-            """S(1)[2, 0, 1]""",
-        )
-        self.assertExpectedInline(
-            DTensorSpec.format_shard_order_str(
-                (Shard(2), Shard(1), Shard(0)), None, False
+                (Shard(2), Shard(1), Shard(0)),
+                (
+                    ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+                    ShardOrderEntry(tensor_dim=1, mesh_dims=(1,)),
+                    ShardOrderEntry(tensor_dim=2, mesh_dims=(0,)),
+                ),
             ),
             """S(2)S(1)S(0)""",
         )
         self.assertExpectedInline(
             DTensorSpec.format_shard_order_str(
-                (Shard(2), Shard(1), Shard(0)), ((0, 2), (1, 1), (2, 0)), False
-            ),
-            """S(2)S(1)S(0)""",
-        )
-        self.assertExpectedInline(
-            DTensorSpec.format_shard_order_str(
-                (Shard(1), Shard(1), Shard(1)), ((1, 2, 0, 1),), False
+                (Shard(1), Shard(1), Shard(1)),
+                (ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0, 1)),),
             ),
             """S(1)[1]S(1)[2]S(1)[0]""",
         )
@@ -1101,23 +1093,11 @@ class TestDTensorSpec(DTensorTestBase):
             DTensorSpec.format_shard_order_str(
                 (Replicate(), Replicate(), Replicate()), None
             ),
-            """R*""",
-        )
-        self.assertExpectedInline(
-            DTensorSpec.format_shard_order_str(
-                (Replicate(), Replicate(), Shard(1)), None
-            ),
-            """S(1)[2]""",
-        )
-        self.assertExpectedInline(
-            DTensorSpec.format_shard_order_str(
-                (Replicate(), Replicate(), Replicate()), None, False
-            ),
             """RRR""",
         )
         self.assertExpectedInline(
             DTensorSpec.format_shard_order_str(
-                (Replicate(), Replicate(), Shard(1)), None, False
+                (Replicate(), Replicate(), Shard(1)), None
             ),
             """RRS(1)""",
         )
@@ -1130,28 +1110,46 @@ class TestDTensorSpec(DTensorTestBase):
         tensor_global = DTensor.from_local(
             tensor_local, mesh, [Shard(1), Shard(1), Shard(0)]
         )
-        tensor_global._spec.shard_order = ((0, 2), (1, 1, 0))
+        tensor_global._spec.shard_order = (
+            ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+            ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
+        )
         with self.assertRaisesRegex(
             AssertionError, r"shard_order .* has empty mesh dim"
         ):
-            tensor_global._spec.shard_order = ((1,), (0, 2))
+            tensor_global._spec.shard_order = (
+                ShardOrderEntry(tensor_dim=1, mesh_dims=()),
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+            )
         with self.assertRaisesRegex(
             AssertionError, "tensor dim should be sorted in shard_order"
         ):
-            tensor_global._spec.shard_order = ((1, 1, 0), (0, 2))
+            tensor_global._spec.shard_order = (
+                ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+            )
         with self.assertRaisesRegex(
             AssertionError,
             r"placement\[\d+\] doesn't have a matching shard in shard_order",
         ):
-            tensor_global._spec.shard_order = ((0, 1), (1, 1, 0))
+            tensor_global._spec.shard_order = (
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
+                ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
+            )
         with self.assertRaisesRegex(
-            AssertionError, r"shard_order .* has invalid mesh dim \[\d+\]"
+            AssertionError, r"shard_order .* has invalid mesh dim \([\d,]+\)"
         ):
-            tensor_global._spec.shard_order = ((0, 3), (1, 1, 0))
+            tensor_global._spec.shard_order = (
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(3,)),
+                ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
+            )
         with self.assertRaisesRegex(
             AssertionError, r"shard_order .* has invalid tensor dim -?\d+"
         ):
-            tensor_global._spec.shard_order = ((0, 2), (-1, 1, 0))
+            tensor_global._spec.shard_order = (
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+                ShardOrderEntry(tensor_dim=-1, mesh_dims=(1, 0)),
+            )

     @with_comms
     def test_dtensor_spec_update(self):
@@ -1168,7 +1166,10 @@ class TestDTensorSpec(DTensorTestBase):
         self.assertEqual(hash(tensor_global_1._spec), hash(tensor_global_2._spec))
         self.assertEqual(tensor_global_1._spec, tensor_global_2._spec)
         # not using the default shard_order
-        tensor_global_1._spec.shard_order = ((0, 2), (1, 1, 0))
+        tensor_global_1._spec.shard_order = (
+            ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+            ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
+        )
         # hash should be recomputed in DTensorSpec.__setattr__()
         self.assertNotEqual(hash(tensor_global_1._spec), hash(tensor_global_2._spec))
         self.assertNotEqual(tensor_global_1._spec, tensor_global_2._spec)
@@ -1182,7 +1183,13 @@ class TestDTensorSpec(DTensorTestBase):
         tensor_global = DTensor.from_local(
             tensor_local, mesh, [Shard(1), Shard(1), Shard(0)]
         )
-        self.assertEqual(tensor_global._spec.shard_order, ((0, 2), (1, 0, 1)))
+        self.assertEqual(
+            tensor_global._spec.shard_order,
+            (
+                ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
+                ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1)),
+            ),
+        )

         tensor_global = DTensor.from_local(
             tensor_local, mesh, [Replicate(), Replicate(), Replicate()]
@@ -1213,12 +1220,21 @@ class TestDTensorSpec(DTensorTestBase):
             tensor_local, mesh, [Shard(1), Shard(2), Shard(1)]
         )
         # DTensorSpec automatically builds the default left-to-right order
-        self.assertEqual(tensor_global._spec.shard_order, ((1, 0, 2), (2, 1)))
+        self.assertEqual(
+            tensor_global._spec.shard_order,
+            (
+                ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
+                ShardOrderEntry(tensor_dim=2, mesh_dims=(1,)),
+            ),
+        )
         self.assertTrue(
             DTensorSpec.is_default_device_order(tensor_global._spec.shard_order)
         )
         # manually set the shard_order by exchange mesh dim 0 and 2
-        tensor_global._spec.shard_order = ((1, 2, 0), (2, 1))
+        tensor_global._spec.shard_order = (
+            ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0)),
+            ShardOrderEntry(tensor_dim=2, mesh_dims=(1,)),
+        )
         self.assertFalse(
             DTensorSpec.is_default_device_order(tensor_global._spec.shard_order)
         )
@@ -203,19 +203,19 @@ class DeviceMeshTest(DTensorTestBase):
         mesh_shape = (2, self.world_size // 2)
         mesh_2d = init_device_mesh(self.device_type, mesh_shape)
         with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found"):
-            mesh_2d.get_mesh_dim_by_name("")
+            mesh_2d._get_mesh_dim_by_name("")
         mesh_2d = init_device_mesh(
             self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
         )
-        self.assertEqual(mesh_2d.get_mesh_dim_by_name("dp"), 0)
-        self.assertEqual(mesh_2d.get_mesh_dim_by_name("tp"), 1)
+        self.assertEqual(mesh_2d._get_mesh_dim_by_name("dp"), 0)
+        self.assertEqual(mesh_2d._get_mesh_dim_by_name("tp"), 1)
         tp_mesh = mesh_2d["tp"]
-        self.assertEqual(tp_mesh.get_mesh_dim_by_name("tp"), 0)
+        self.assertEqual(tp_mesh._get_mesh_dim_by_name("tp"), 0)
         non_exist_mesh_name = "dp"
         with self.assertRaisesRegex(
             KeyError, f"Mesh dimension '{non_exist_mesh_name}' does not exist."
         ):
-            tp_mesh.get_mesh_dim_by_name(non_exist_mesh_name)
+            tp_mesh._get_mesh_dim_by_name(non_exist_mesh_name)

     @with_comms
     def test_get_local_rank_raises_exception(self):
@@ -1081,8 +1081,8 @@ class TestMeshEnv(DTensorTestBase):
             self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
         )

-        self.assertEqual(mesh_2d.get_mesh_dim_by_name("DP"), 0)
-        self.assertEqual(mesh_2d.get_mesh_dim_by_name("TP"), 1)
+        self.assertEqual(mesh_2d._get_mesh_dim_by_name("DP"), 0)
+        self.assertEqual(mesh_2d._get_mesh_dim_by_name("TP"), 1)

     @with_comms
     def test_get_all_submeshes(self):
@@ -210,7 +210,7 @@ else:
                     "The submesh can only be a 1D mesh."
                 )
                 child_mesh_dim_name = child_mesh_dim_names[0]
-                return root_mesh.get_mesh_dim_by_name(child_mesh_dim_name)
+                return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
             return None

         @staticmethod
@@ -322,7 +322,7 @@ else:
            """
            Return all the submeshes of a given mesh dimension of the device mesh.
            """
-            mesh_dim = device_mesh.get_mesh_dim_by_name(mesh_dim_name)
+            mesh_dim = device_mesh._get_mesh_dim_by_name(mesh_dim_name)
            layout = device_mesh._layout[mesh_dim]
            pg_ranks_by_dim = layout.remap_to_tensor(
                device_mesh.mesh,
@@ -811,7 +811,7 @@ else:
                return not_none(_resolve_process_group(dim_group_name))
            else:
                mesh_dim = (
-                    self.get_mesh_dim_by_name(mesh_dim)
+                    self._get_mesh_dim_by_name(mesh_dim)
                    if isinstance(mesh_dim, str)
                    else mesh_dim
                )
@@ -1075,7 +1075,7 @@ else:
                    self, mesh_dim_name, backend_override_tuple
                )

-        def get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
+        def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
            if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0:
                raise KeyError(
                    "No `mesh_dim_names` found.",
@@ -13,34 +13,40 @@ from torch.distributed.tensor.placement_types import (
 )


-MeshDimTuple = tuple[int, ...]
-TensorDimTuple = tuple[MeshDimTuple, ...]
+class ShardOrderEntry(NamedTuple):
+    """
+    Represents how a single tensor dimension is sharded across mesh dimensions.
+
+    Attributes:
+        tensor_dim: The tensor dimension being sharded (e.g., 0, 1, 2 for a 3D tensor).
+        mesh_dims: Tuple of mesh dimensions across which this tensor dimension is sharded,
+            in execution order. The first mesh dim is applied first, second is applied
+            second, etc.
+
+    Examples:
+        >>> # Tensor dim 1 sharded across mesh dim 2, then mesh dim 0
+        >>> ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0))
+
+        >>> # Tensor dim 0 sharded only on mesh dim 1
+        >>> ShardOrderEntry(tensor_dim=0, mesh_dims=(1,))
+    """
+
+    tensor_dim: int
+    mesh_dims: tuple[int, ...]


-# Controls the print format for tensor distribution visualization.
-# When True: Shows tensor-centric format mapping tensor dimensions to mesh dimensions
-# When False: Shows standard DTensor mesh-centric format mapping mesh dimensions to tensor dimensions
+# Type alias for the complete shard order specification
+# A tuple of ShardOrderEntry, one per sharded tensor dimension
 #
-# Example with a 3D tensor on a 2x2x2x2 mesh (16 devices) with
-# ``placements``: [Partial(), Shard(1), Shard(1), Replicate()],
-# ``shard_order``: {1: [2, 1]} (tensor dim 1 shard on mesh dimension 2 first then 1)
-# - mesh_dim_0: Partial reduction (sum)
-# - mesh_dim_1: Shard tensor dimension 1 (executed second)
-# - mesh_dim_2: Shard tensor dimension 1 (executed first)
-# - mesh_dim_3: Replicate
-#
-# When True (tensor-centric): "S(1)[1, 2]P(sum)[0]"
-# - S(1)[1, 2]: tensor dimension 1 sharded on mesh dimension 1 and then 2
-# - P(sum)[0]: partial reduction on mesh dimension 0
-# - Clearly shows which tensor dims map to which mesh dims
-#
-# When False (mesh-centric): "P(sum)S(1)[1]S(1)[0]R"
-# - P(sum): mesh dimension 0 has partial reduction
-# - S(1): mesh dimension 1 shards tensor dimension 1 (with order 1)
-# - S(1): mesh dimension 2 shards tensor dimension 1 (with order 0)
-# - R: Replicated on mesh dimension 3
-# - Standard DTensor placement with concise format following mesh dimension order
-tensor_centric_format = True
+# Example:
+# shard_order = (
+#     ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
+#     ShardOrderEntry(tensor_dim=2, mesh_dims=(0, 3)),
+# )
+# This means:
+# - Tensor dimension 0 is sharded on mesh dimension 1
+# - Tensor dimension 2 is sharded on mesh dimension 0 first, then mesh dimension 3
+ShardOrder = tuple[ShardOrderEntry, ...]


 class TensorMeta(NamedTuple):
@@ -64,16 +70,17 @@ class DTensorSpec:
     # When a tensor dimension is sharded across multiple mesh axes,
     # `shard_order` specifies the sequence in which these shardings are applied.
     # This order determines how tensor shards are mapped and distributed across
-    # devices. `shard_order` is a tuple of tuples of integers: in each inner
-    # tuple, the first element is the tensor dimension being sharded, and the
-    # remaining elements are the device mesh dimensions (in order) over which
-    # that tensor dimension is sharded.
+    # devices.
     #
     # Example:
     # For a tensor of shape [8, 16] and a 3D device mesh, if dim 0 is sharded over
     # mesh dim 1, and dim 1 is sharded over mesh dim 0 and then mesh dim 2,
-    # the shard_order would be: shard_order = ((0, 1), (1, 0, 2))
-    shard_order: TensorDimTuple = None  # type: ignore[assignment]
+    # the shard_order would be:
+    # shard_order = (
+    #     ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
+    #     ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
+    # )
+    shard_order: ShardOrder = None  # type: ignore[assignment]

     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
@@ -87,7 +94,13 @@ class DTensorSpec:
     @staticmethod
     def compute_default_sparse_shard_order(
         placements: tuple[Placement, ...],
-    ) -> TensorDimTuple:
+    ) -> ShardOrder:
+        """
+        Compute the default shard order from placements.
+
+        Returns a ShardOrder where each ShardOrderEntry maps a tensor dimension
+        to the mesh dimensions it's sharded on, in left-to-right order.
+        """
         # follow default left-to-right device order if shard_order is not specified
         tensor_dim_to_mesh_dims: dict[int, list[int]] = {}
         mesh_ndim = len(placements)
@@ -104,23 +117,24 @@ class DTensorSpec:
             if shard_dim not in tensor_dim_to_mesh_dims:
                 tensor_dim_to_mesh_dims[shard_dim] = []
             tensor_dim_to_mesh_dims[shard_dim].append(mesh_dim)
-        # convert dict into the tuple of tuple that is hashable
+
+        # Convert dict into ShardOrderEntry tuples
         default_sparse_shard_order = tuple(
-            tuple(item)
-            for item in (
-                [key] + value if isinstance(value, list) else [key, value]
-                for key, value in sorted(tensor_dim_to_mesh_dims.items())
-                if value
-            )
+            ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value))
+            for key, value in sorted(tensor_dim_to_mesh_dims.items())
+            if value
         )
         return default_sparse_shard_order

-    def _verify_shard_order(self, shard_order: TensorDimTuple) -> None:
+    def _verify_shard_order(self, shard_order: ShardOrder) -> None:
         """Verify that the shard_order is valid and matches the placements."""
         total_shard = 0
         if any(isinstance(p, _StridedShard) for p in self.placements):
             return
         prev_tensor_dim = -1
-        for tensor_dim, *mesh_dims in shard_order:
+        for entry in shard_order:
+            tensor_dim = entry.tensor_dim
+            mesh_dims = entry.mesh_dims
             assert len(mesh_dims) > 0, f"shard_order {shard_order} has empty mesh dim"
             assert tensor_dim >= 0, (
                 f"shard_order {shard_order} has invalid tensor dim {tensor_dim}"
@@ -227,12 +241,12 @@ class DTensorSpec:
         return f"Spec({placement_str} on {tensor_shape})"

     @staticmethod
-    def is_default_device_order(shard_order: TensorDimTuple) -> bool:
+    def is_default_device_order(shard_order: ShardOrder) -> bool:
         """
         Check if the device order is the default left-to-right order.
         """
-        for tensor_dim_and_mesh_dims in shard_order:
-            tensor_dim, *mesh_dims = tensor_dim_and_mesh_dims
+        for entry in shard_order:
+            mesh_dims = entry.mesh_dims
             is_increasing = all(
                 prev < nxt for prev, nxt in itertools.pairwise(mesh_dims)
             )
@@ -243,75 +257,69 @@ class DTensorSpec:
     @staticmethod
     def format_shard_order_str(
         placements: tuple[Placement, ...],
-        shard_order: Optional[TensorDimTuple] = None,
-        tensor_centric_format: Optional[bool] = None,
+        shard_order: Optional[ShardOrder] = None,
     ) -> str:
         """
-        Format DTensor sharding information as a string.
+        Format DTensor sharding information as a human-readable string.

+        This method formats the sharding pattern in mesh-centric order, showing the placement
+        for each mesh dimension sequentially. When a tensor dimension is sharded across multiple
+        mesh dimensions, the order index indicates the execution sequence of the sharding operations.
+
         Args:
-            placements: Tuple of placement objects for each mesh dimension
-            shard_order: Optional tensor dimension to mesh dimension mapping
-            tensor_centric_format: Controls output format
-                - When True: Shows tensor-centric format mapping tensor dims to mesh dims
-                - When False: Shows standard DTensor mesh-centric format
+            placements: Tuple of placement objects for each mesh dimension.
+            shard_order: Optional ShardOrder specifying the sharding order.

         Returns:
-            String representation of the sharding pattern
+            String representation of the sharding pattern in mesh-centric format.
+
+        Example:
+            For a 3D tensor on a 2x2x2x2 mesh (16 devices) with::
+
+                placements = [Partial(), Shard(1), Shard(1), Replicate()]
+                shard_order = (ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 1)),)
+
+            Mesh configuration:
+                - mesh_dim_0: Partial reduction (sum)
+                - mesh_dim_1: Shard tensor dimension 1 (executed second, order index 1)
+                - mesh_dim_2: Shard tensor dimension 1 (executed first, order index 0)
+                - mesh_dim_3: Replicate
+
+            Output: ``"PS(1)[1]S(1)[0]R"``
+
+            Explanation:
+                - ``P``: mesh dimension 0 has partial reduction
+                - ``S(1)[1]``: mesh dimension 1 shards tensor dimension 1 (order index 1 means second)
+                - ``S(1)[0]``: mesh dimension 2 shards tensor dimension 1 (order index 0 means first)
+                - ``R``: mesh dimension 3 replicates
+
+            The format follows mesh dimension order (0, 1, 2, 3), and when a tensor dimension
+            is sharded across multiple mesh dimensions, the bracketed index shows the execution
+            order: ``[0]`` is executed first, ``[1]`` is executed second, etc.
         """
         out_str = ""
-        # print mapping from tensor dim to mesh dim
-        is_tensor_centric_format = (
-            tensor_centric_format
-            if tensor_centric_format is not None
-            else globals().get("tensor_centric_format", False)
-        )
-        if is_tensor_centric_format:
-            if shard_order is None:
-                shard_order = DTensorSpec.compute_default_sparse_shard_order(placements)
-            for tensor_dim, *mesh_dims in shard_order:
-                if len(mesh_dims) > 0:
-                    out_str += f"S({tensor_dim})"
-                    out_str += f"[{', '.join([str(m) for m in mesh_dims])}]"
-            # in addition, add the partial placement
-            partial_to_mesh_dim: dict[Partial, list[int]] = {}
-            for mesh_dim, p in enumerate(placements):
-                if isinstance(p, Partial):
-                    if p not in partial_to_mesh_dim:
-                        partial_to_mesh_dim[p] = []
-                    partial_to_mesh_dim[p].append(mesh_dim)
-            for p, mesh_dims in partial_to_mesh_dim.items():
-                out_str += f"P({p.reduce_op})"
-                out_str += f"[{', '.join([str(m) for m in mesh_dims])}]"
-            # case when no dim get sharded, we use "R*" to represent
-            if out_str == "":
-                out_str = "R*"
+        # native dtensor-style sharding representation: map from mesh
+        # dim to tensor dim
+        for mesh_dim, placement in enumerate(placements):
+            if isinstance(placement, Shard):
+                if shard_order is not None:
+                    for entry in shard_order:
+                        tensor_dim = entry.tensor_dim
+                        mesh_dims = entry.mesh_dims
+
-        else:
-            # native dtensor-style sharding representation: map from mesh
-            # dim to tensor dim
-            for mesh_dim, placement in enumerate(placements):
-                if isinstance(placement, Replicate):
-                    out_str += "R"
-                elif isinstance(placement, Shard):
-                    if shard_order is not None:
-                        for tensor_dim, *mesh_dims in shard_order:
-                            if placement.dim == tensor_dim:
-                                assert mesh_dim in mesh_dims
-                                if len(mesh_dims) > 1:
-                                    out_str += (
-                                        f"S({tensor_dim})[{mesh_dims.index(mesh_dim)}]"
-                                    )
-                                else:
-                                    # no need to show device order if the tensor dim is
-                                    # only sharded in one mesh dim
-                                    out_str += f"S({tensor_dim})"
-                                break
-                    else:
-                        out_str += f"S({placement.dim})"
+                        if placement.dim == tensor_dim:
+                            assert mesh_dim in mesh_dims
+                            if len(mesh_dims) > 1:
+                                out_str += f"{placement}[{mesh_dims.index(mesh_dim)}]"
+                            else:
+                                # no need to show device order if the tensor dim is
+                                # only sharded in one mesh dim
+                                out_str += str(placement)
+                            break
+                else:
-                else:
-                    assert isinstance(placement, Partial)
-                    out_str += f"P({placement.reduce_op})"
+                    out_str += str(placement)
+            else:
+                out_str += str(placement)
         return out_str

     @property
@@ -69,7 +69,7 @@ class Shard(Placement):
        return True

    @staticmethod
-    def split_tensor(
+    def _make_split_tensor(
        dim: int,
        tensor: torch.Tensor,
        num_chunks: int,
@@ -119,7 +119,7 @@ class Shard(Placement):
        with_padding: bool = True,
        contiguous: bool = True,
    ) -> tuple[list[torch.Tensor], list[int]]:
-        return Shard.split_tensor(
+        return Shard._make_split_tensor(
            self.dim,
            tensor,
            num_chunks,
@@ -171,7 +171,7 @@ class Shard(Placement):
        return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)

    @staticmethod
-    def shard_tensor(
+    def _make_shard_tensor(
        dim: int,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
@@ -194,13 +194,13 @@ class Shard(Placement):
        if src_data_rank is None:
            # src_data_rank specified as None explicitly means to skip the
            # communications, simply split
-            scatter_list, _ = Shard.split_tensor(
+            scatter_list, _ = Shard._make_split_tensor(
                dim, tensor, num_chunks, with_padding=False, contiguous=True
            )

            return scatter_list[mesh_dim_local_rank]

-        scatter_list, pad_sizes = Shard.split_tensor(
+        scatter_list, pad_sizes = Shard._make_split_tensor(
            dim, tensor, num_chunks, with_padding=True, contiguous=True
        )
        output = torch.empty_like(scatter_list[mesh_dim_local_rank])
@@ -224,7 +224,7 @@ class Shard(Placement):
        mesh_dim: int,
        src_data_rank: Optional[int] = 0,
    ) -> torch.Tensor:
-        return Shard.shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)
+        return Shard._make_shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)

    def _reduce_shard_tensor(
        self,
@@ -246,7 +246,7 @@ class Shard(Placement):

        is_padded = tensor.size(self.dim) % num_chunks != 0
        if is_padded:
-            scattered_list, pad_sizes = Shard.split_tensor(
+            scattered_list, pad_sizes = Shard._make_split_tensor(
                self.dim, tensor, num_chunks, with_padding=True, contiguous=True
            )
            tensor = torch.cat(scattered_list, dim=self.dim)
@@ -641,7 +641,7 @@ class Replicate(Placement):
        return "R"

    @staticmethod
-    def replicate_tensor(
+    def _make_replicate_tensor(
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
@@ -670,7 +670,7 @@ class Replicate(Placement):
        mesh_dim: int,
        src_data_rank: Optional[int] = 0,
    ) -> torch.Tensor:
-        return Replicate.replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)
+        return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)

    def is_replicate(self) -> bool:
        return True