Update on "[3/N] [DTensor device order] Make some placement type class method static"

Some methods in the `Placement` class can be exposed as static methods.

These methods are useful without first initializing a placement object. For example, when we `distribute_tensor` from a regular tensor, we may want:
```
local_tensor = Shard.shard_tensor(tensor_dim, local_tensor, device_mesh, mesh_dim)
```
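
The diff below lands these helpers as private static methods (`Shard._make_split_tensor`, `Shard._make_shard_tensor`, `Replicate._make_replicate_tensor`). A minimal sketch of calling one without constructing a placement instance, assuming an already-initialized process group; the `shard_param_rowwise` wrapper and the tensor shapes are hypothetical, for illustration only:

```
import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor import Shard


def shard_param_rowwise(param: torch.Tensor, mesh: DeviceMesh, mesh_dim: int = 0) -> torch.Tensor:
    # Hypothetical wrapper: shard `param` along tensor dim 0 over `mesh_dim`
    # without instantiating Shard(0); the static method takes the tensor dim
    # explicitly instead of reading it from `self.dim`.
    return Shard._make_shard_tensor(0, param, mesh, mesh_dim, src_data_rank=0)


if torch.distributed.is_initialized():
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
    local_shard = shard_param_rowwise(torch.randn(8, 16), mesh)
```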




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
zpcore, 2025-10-08 16:54:23 -07:00
5 changed files with 193 additions and 169 deletions
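
The main data-structure change in the diff: `DTensorSpec.shard_order` switches from nested int tuples to a tuple of `ShardOrderEntry` named tuples. A standalone sketch of the new representation, reusing the illustrative values from the `_dtensor_spec.py` comments (not tied to any particular mesh):

```
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry

# Tensor dim 0 sharded on mesh dim 1; tensor dim 2 sharded on mesh dim 0, then 3.
shard_order = (
    ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
    ShardOrderEntry(tensor_dim=2, mesh_dims=(0, 3)),
)

# Entries are plain NamedTuples: fields are addressable by name and hashable,
# so DTensorSpec can recompute its hash when shard_order is reassigned.
assert shard_order[1].mesh_dims == (0, 3)
```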

View File

@ -20,7 +20,11 @@ from torch.distributed.tensor import (
Shard,
)
from torch.distributed.tensor._api import _shard_tensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._dtensor_spec import (
DTensorSpec,
ShardOrderEntry,
TensorMeta,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import implicit_replication
from torch.distributed.tensor.parallel import (
@ -1065,35 +1069,23 @@ class TestDTensorSpec(DTensorTestBase):
def test_dtensor_spec_print(self):
self.assertExpectedInline(
DTensorSpec.format_shard_order_str((Shard(2), Shard(1), Shard(0)), None),
"""S(0)[2]S(1)[1]S(2)[0]""",
"""S(2)S(1)S(0)""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Shard(2), Shard(1), Shard(0)), ((0, 2), (1, 1), (2, 0))
),
"""S(0)[2]S(1)[1]S(2)[0]""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Shard(1), Shard(1), Shard(1)), ((1, 2, 0, 1),)
),
"""S(1)[2, 0, 1]""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Shard(2), Shard(1), Shard(0)), None, False
(Shard(2), Shard(1), Shard(0)),
(
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=2, mesh_dims=(0,)),
),
),
"""S(2)S(1)S(0)""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Shard(2), Shard(1), Shard(0)), ((0, 2), (1, 1), (2, 0)), False
),
"""S(2)S(1)S(0)""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Shard(1), Shard(1), Shard(1)), ((1, 2, 0, 1),), False
(Shard(1), Shard(1), Shard(1)),
(ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0, 1)),),
),
"""S(1)[1]S(1)[2]S(1)[0]""",
)
@ -1101,23 +1093,11 @@ class TestDTensorSpec(DTensorTestBase):
DTensorSpec.format_shard_order_str(
(Replicate(), Replicate(), Replicate()), None
),
"""R*""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Replicate(), Replicate(), Shard(1)), None
),
"""S(1)[2]""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Replicate(), Replicate(), Replicate()), None, False
),
"""RRR""",
)
self.assertExpectedInline(
DTensorSpec.format_shard_order_str(
(Replicate(), Replicate(), Shard(1)), None, False
(Replicate(), Replicate(), Shard(1)), None
),
"""RRS(1)""",
)
@ -1130,28 +1110,46 @@ class TestDTensorSpec(DTensorTestBase):
tensor_global = DTensor.from_local(
tensor_local, mesh, [Shard(1), Shard(1), Shard(0)]
)
tensor_global._spec.shard_order = ((0, 2), (1, 1, 0))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
)
with self.assertRaisesRegex(
AssertionError, r"shard_order .* has empty mesh dim"
):
tensor_global._spec.shard_order = ((1,), (0, 2))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=1, mesh_dims=()),
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
)
with self.assertRaisesRegex(
AssertionError, "tensor dim should be sorted in shard_order"
):
tensor_global._spec.shard_order = ((1, 1, 0), (0, 2))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
)
with self.assertRaisesRegex(
AssertionError,
r"placement\[\d+\] doesn't have a matching shard in shard_order",
):
tensor_global._spec.shard_order = ((0, 1), (1, 1, 0))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
)
with self.assertRaisesRegex(
AssertionError, r"shard_order .* has invalid mesh dim \[\d+\]"
AssertionError, r"shard_order .* has invalid mesh dim \([\d,]+\)"
):
tensor_global._spec.shard_order = ((0, 3), (1, 1, 0))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(3,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
)
with self.assertRaisesRegex(
AssertionError, r"shard_order .* has invalid tensor dim -?\d+"
):
tensor_global._spec.shard_order = ((0, 2), (-1, 1, 0))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
ShardOrderEntry(tensor_dim=-1, mesh_dims=(1, 0)),
)
@with_comms
def test_dtensor_spec_update(self):
@ -1168,7 +1166,10 @@ class TestDTensorSpec(DTensorTestBase):
self.assertEqual(hash(tensor_global_1._spec), hash(tensor_global_2._spec))
self.assertEqual(tensor_global_1._spec, tensor_global_2._spec)
# not using the default shard_order
tensor_global_1._spec.shard_order = ((0, 2), (1, 1, 0))
tensor_global_1._spec.shard_order = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(1, 0)),
)
# hash should be recomputed in DTensorSpec.__setattr__()
self.assertNotEqual(hash(tensor_global_1._spec), hash(tensor_global_2._spec))
self.assertNotEqual(tensor_global_1._spec, tensor_global_2._spec)
@ -1182,7 +1183,13 @@ class TestDTensorSpec(DTensorTestBase):
tensor_global = DTensor.from_local(
tensor_local, mesh, [Shard(1), Shard(1), Shard(0)]
)
self.assertEqual(tensor_global._spec.shard_order, ((0, 2), (1, 0, 1)))
self.assertEqual(
tensor_global._spec.shard_order,
(
ShardOrderEntry(tensor_dim=0, mesh_dims=(2,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1)),
),
)
tensor_global = DTensor.from_local(
tensor_local, mesh, [Replicate(), Replicate(), Replicate()]
@ -1213,12 +1220,21 @@ class TestDTensorSpec(DTensorTestBase):
tensor_local, mesh, [Shard(1), Shard(2), Shard(1)]
)
# DTensorSpec automatically builds the default left-to-right order
self.assertEqual(tensor_global._spec.shard_order, ((1, 0, 2), (2, 1)))
self.assertEqual(
tensor_global._spec.shard_order,
(
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
ShardOrderEntry(tensor_dim=2, mesh_dims=(1,)),
),
)
self.assertTrue(
DTensorSpec.is_default_device_order(tensor_global._spec.shard_order)
)
# manually set the shard_order by exchange mesh dim 0 and 2
tensor_global._spec.shard_order = ((1, 2, 0), (2, 1))
tensor_global._spec.shard_order = (
ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0)),
ShardOrderEntry(tensor_dim=2, mesh_dims=(1,)),
)
self.assertFalse(
DTensorSpec.is_default_device_order(tensor_global._spec.shard_order)
)

View File

@ -203,19 +203,19 @@ class DeviceMeshTest(DTensorTestBase):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found"):
mesh_2d.get_mesh_dim_by_name("")
mesh_2d._get_mesh_dim_by_name("")
mesh_2d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
)
self.assertEqual(mesh_2d.get_mesh_dim_by_name("dp"), 0)
self.assertEqual(mesh_2d.get_mesh_dim_by_name("tp"), 1)
self.assertEqual(mesh_2d._get_mesh_dim_by_name("dp"), 0)
self.assertEqual(mesh_2d._get_mesh_dim_by_name("tp"), 1)
tp_mesh = mesh_2d["tp"]
self.assertEqual(tp_mesh.get_mesh_dim_by_name("tp"), 0)
self.assertEqual(tp_mesh._get_mesh_dim_by_name("tp"), 0)
non_exist_mesh_name = "dp"
with self.assertRaisesRegex(
KeyError, f"Mesh dimension '{non_exist_mesh_name}' does not exist."
):
tp_mesh.get_mesh_dim_by_name(non_exist_mesh_name)
tp_mesh._get_mesh_dim_by_name(non_exist_mesh_name)
@with_comms
def test_get_local_rank_raises_exception(self):
@ -1081,8 +1081,8 @@ class TestMeshEnv(DTensorTestBase):
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
self.assertEqual(mesh_2d.get_mesh_dim_by_name("DP"), 0)
self.assertEqual(mesh_2d.get_mesh_dim_by_name("TP"), 1)
self.assertEqual(mesh_2d._get_mesh_dim_by_name("DP"), 0)
self.assertEqual(mesh_2d._get_mesh_dim_by_name("TP"), 1)
@with_comms
def test_get_all_submeshes(self):

View File

@ -210,7 +210,7 @@ else:
"The submesh can only be a 1D mesh."
)
child_mesh_dim_name = child_mesh_dim_names[0]
return root_mesh.get_mesh_dim_by_name(child_mesh_dim_name)
return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
return None
@staticmethod
@ -322,7 +322,7 @@ else:
"""
Return all the submeshes of a given mesh dimension of the device mesh.
"""
mesh_dim = device_mesh.get_mesh_dim_by_name(mesh_dim_name)
mesh_dim = device_mesh._get_mesh_dim_by_name(mesh_dim_name)
layout = device_mesh._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(
device_mesh.mesh,
@ -811,7 +811,7 @@ else:
return not_none(_resolve_process_group(dim_group_name))
else:
mesh_dim = (
self.get_mesh_dim_by_name(mesh_dim)
self._get_mesh_dim_by_name(mesh_dim)
if isinstance(mesh_dim, str)
else mesh_dim
)
@ -1075,7 +1075,7 @@ else:
self, mesh_dim_name, backend_override_tuple
)
def get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0:
raise KeyError(
"No `mesh_dim_names` found.",

View File

@ -13,34 +13,40 @@ from torch.distributed.tensor.placement_types import (
)
MeshDimTuple = tuple[int, ...]
TensorDimTuple = tuple[MeshDimTuple, ...]
class ShardOrderEntry(NamedTuple):
"""
Represents how a single tensor dimension is sharded across mesh dimensions.
Attributes:
tensor_dim: The tensor dimension being sharded (e.g., 0, 1, 2 for a 3D tensor).
mesh_dims: Tuple of mesh dimensions across which this tensor dimension is sharded,
in execution order. The first mesh dim is applied first, second is applied
second, etc.
Examples:
>>> # Tensor dim 1 sharded across mesh dim 2, then mesh dim 0
>>> ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0))
>>> # Tensor dim 0 sharded only on mesh dim 1
>>> ShardOrderEntry(tensor_dim=0, mesh_dims=(1,))
"""
tensor_dim: int
mesh_dims: tuple[int, ...]
# Controls the print format for tensor distribution visualization.
# When True: Shows tensor-centric format mapping tensor dimensions to mesh dimensions
# When False: Shows standard DTensor mesh-centric format mapping mesh dimensions to tensor dimensions
# Type alias for the complete shard order specification
# A tuple of ShardOrderEntry, one per sharded tensor dimension
#
# Example with a 3D tensor on a 2x2x2x2 mesh (16 devices) with
# ``placements``: [Partial(), Shard(1), Shard(1), Replicate()],
# ``shard_order``: {1: [2, 1]} (tensor dim 1 shard on mesh dimension 2 first then 1)
# - mesh_dim_0: Partial reduction (sum)
# - mesh_dim_1: Shard tensor dimension 1 (executed second)
# - mesh_dim_2: Shard tensor dimension 1 (executed first)
# - mesh_dim_3: Replicate
#
# When True (tensor-centric): "S(1)[1, 2]P(sum)[0]"
# - S(1)[1, 2]: tensor dimension 1 sharded on mesh dimension 1 and then 2
# - P(sum)[0]: partial reduction on mesh dimension 0
# - Clearly shows which tensor dims map to which mesh dims
#
# When False (mesh-centric): "P(sum)S(1)[1]S(1)[0]R"
# - P(sum): mesh dimension 0 has partial reduction
# - S(1): mesh dimension 1 shards tensor dimension 1 (with order 1)
# - S(1): mesh dimension 2 shards tensor dimension 1 (with order 0)
# - R: Replicated on mesh dimension 3
# - Standard DTensor placement with concise format following mesh dimension order
tensor_centric_format = True
# Example:
# shard_order = (
# ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
# ShardOrderEntry(tensor_dim=2, mesh_dims=(0, 3)),
# )
# This means:
# - Tensor dimension 0 is sharded on mesh dimension 1
# - Tensor dimension 2 is sharded on mesh dimension 0 first, then mesh dimension 3
ShardOrder = tuple[ShardOrderEntry, ...]
class TensorMeta(NamedTuple):
@ -64,16 +70,17 @@ class DTensorSpec:
# When a tensor dimension is sharded across multiple mesh axes,
# `shard_order` specifies the sequence in which these shardings are applied.
# This order determines how tensor shards are mapped and distributed across
# devices. `shard_order` is a tuple of tuples of integers: in each inner
# tuple, the first element is the tensor dimension being sharded, and the
# remaining elements are the device mesh dimensions (in order) over which
# that tensor dimension is sharded.
# devices.
#
# Example:
# For a tensor of shape [8, 16] and a 3D device mesh, if dim 0 is sharded over
# mesh dim 1, and dim 1 is sharded over mesh dim 0 and then mesh dim 2,
# the shard_order would be: shard_order = ((0, 1), (1, 0, 2))
shard_order: TensorDimTuple = None # type: ignore[assignment]
# the shard_order would be:
# shard_order = (
# ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
# ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
# )
shard_order: ShardOrder = None # type: ignore[assignment]
def __post_init__(self) -> None:
if not isinstance(self.placements, tuple):
@ -87,7 +94,13 @@ class DTensorSpec:
@staticmethod
def compute_default_sparse_shard_order(
placements: tuple[Placement, ...],
) -> TensorDimTuple:
) -> ShardOrder:
"""
Compute the default shard order from placements.
Returns a ShardOrder where each ShardOrderEntry maps a tensor dimension
to the mesh dimensions it's sharded on, in left-to-right order.
"""
# follow default left-to-right device order if shard_order is not specified
tensor_dim_to_mesh_dims: dict[int, list[int]] = {}
mesh_ndim = len(placements)
@ -104,23 +117,24 @@ class DTensorSpec:
if shard_dim not in tensor_dim_to_mesh_dims:
tensor_dim_to_mesh_dims[shard_dim] = []
tensor_dim_to_mesh_dims[shard_dim].append(mesh_dim)
# convert dict into the tuple of tuple that is hashable
# Convert dict into ShardOrderEntry tuples
default_sparse_shard_order = tuple(
tuple(item)
for item in (
[key] + value if isinstance(value, list) else [key, value]
for key, value in sorted(tensor_dim_to_mesh_dims.items())
if value
)
ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value))
for key, value in sorted(tensor_dim_to_mesh_dims.items())
if value
)
return default_sparse_shard_order
def _verify_shard_order(self, shard_order: TensorDimTuple) -> None:
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
"""Verify that the shard_order is valid and matches the placements."""
total_shard = 0
if any(isinstance(p, _StridedShard) for p in self.placements):
return
prev_tensor_dim = -1
for tensor_dim, *mesh_dims in shard_order:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
assert len(mesh_dims) > 0, f"shard_order {shard_order} has empty mesh dim"
assert tensor_dim >= 0, (
f"shard_order {shard_order} has invalid tensor dim {tensor_dim}"
@ -227,12 +241,12 @@ class DTensorSpec:
return f"Spec({placement_str} on {tensor_shape})"
@staticmethod
def is_default_device_order(shard_order: TensorDimTuple) -> bool:
def is_default_device_order(shard_order: ShardOrder) -> bool:
"""
Check if the device order is the default left-to-right order.
"""
for tensor_dim_and_mesh_dims in shard_order:
tensor_dim, *mesh_dims = tensor_dim_and_mesh_dims
for entry in shard_order:
mesh_dims = entry.mesh_dims
is_increasing = all(
prev < nxt for prev, nxt in itertools.pairwise(mesh_dims)
)
@ -243,75 +257,69 @@ class DTensorSpec:
@staticmethod
def format_shard_order_str(
placements: tuple[Placement, ...],
shard_order: Optional[TensorDimTuple] = None,
tensor_centric_format: Optional[bool] = None,
shard_order: Optional[ShardOrder] = None,
) -> str:
"""
Format DTensor sharding information as a string.
Format DTensor sharding information as a human-readable string.
This method formats the sharding pattern in mesh-centric order, showing the placement
for each mesh dimension sequentially. When a tensor dimension is sharded across multiple
mesh dimensions, the order index indicates the execution sequence of the sharding operations.
Args:
placements: Tuple of placement objects for each mesh dimension
shard_order: Optional tensor dimension to mesh dimension mapping
tensor_centric_format: Controls output format
- When True: Shows tensor-centric format mapping tensor dims to mesh dims
- When False: Shows standard DTensor mesh-centric format
placements: Tuple of placement objects for each mesh dimension.
shard_order: Optional ShardOrder specifying the sharding order.
Returns:
String representation of the sharding pattern
String representation of the sharding pattern in mesh-centric format.
Example:
For a 3D tensor on a 2x2x2x2 mesh (16 devices) with::
placements = [Partial(), Shard(1), Shard(1), Replicate()]
shard_order = (ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 1)),)
Mesh configuration:
- mesh_dim_0: Partial reduction (sum)
- mesh_dim_1: Shard tensor dimension 1 (executed second, order index 1)
- mesh_dim_2: Shard tensor dimension 1 (executed first, order index 0)
- mesh_dim_3: Replicate
Output: ``"PS(1)[1]S(1)[0]R"``
Explanation:
- ``P``: mesh dimension 0 has partial reduction
- ``S(1)[1]``: mesh dimension 1 shards tensor dimension 1 (order index 1 means second)
- ``S(1)[0]``: mesh dimension 2 shards tensor dimension 1 (order index 0 means first)
- ``R``: mesh dimension 3 replicates
The format follows mesh dimension order (0, 1, 2, 3), and when a tensor dimension
is sharded across multiple mesh dimensions, the bracketed index shows the execution
order: ``[0]`` is executed first, ``[1]`` is executed second, etc.
"""
out_str = ""
# print mapping from tensor dim to mesh dim
is_tensor_centric_format = (
tensor_centric_format
if tensor_centric_format is not None
else globals().get("tensor_centric_format", False)
)
if is_tensor_centric_format:
if shard_order is None:
shard_order = DTensorSpec.compute_default_sparse_shard_order(placements)
for tensor_dim, *mesh_dims in shard_order:
if len(mesh_dims) > 0:
out_str += f"S({tensor_dim})"
out_str += f"[{', '.join([str(m) for m in mesh_dims])}]"
# in addition, add the partial placement
partial_to_mesh_dim: dict[Partial, list[int]] = {}
for mesh_dim, p in enumerate(placements):
if isinstance(p, Partial):
if p not in partial_to_mesh_dim:
partial_to_mesh_dim[p] = []
partial_to_mesh_dim[p].append(mesh_dim)
for p, mesh_dims in partial_to_mesh_dim.items():
out_str += f"P({p.reduce_op})"
out_str += f"[{', '.join([str(m) for m in mesh_dims])}]"
# case when no dim get sharded, we use "R*" to represent
if out_str == "":
out_str = "R*"
# native dtensor-style sharding representation: map from mesh
# dim to tensor dim
for mesh_dim, placement in enumerate(placements):
if isinstance(placement, Shard):
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
else:
# native dtensor-style sharding representation: map from mesh
# dim to tensor dim
for mesh_dim, placement in enumerate(placements):
if isinstance(placement, Replicate):
out_str += "R"
elif isinstance(placement, Shard):
if shard_order is not None:
for tensor_dim, *mesh_dims in shard_order:
if placement.dim == tensor_dim:
assert mesh_dim in mesh_dims
if len(mesh_dims) > 1:
out_str += (
f"S({tensor_dim})[{mesh_dims.index(mesh_dim)}]"
)
else:
# no need to show device order if the tensor dim is
# only sharded in one mesh dim
out_str += f"S({tensor_dim})"
break
else:
out_str += f"S({placement.dim})"
if placement.dim == tensor_dim:
assert mesh_dim in mesh_dims
if len(mesh_dims) > 1:
out_str += f"{placement}[{mesh_dims.index(mesh_dim)}]"
else:
# no need to show device order if the tensor dim is
# only sharded in one mesh dim
out_str += str(placement)
break
else:
assert isinstance(placement, Partial)
out_str += f"P({placement.reduce_op})"
out_str += str(placement)
else:
out_str += str(placement)
return out_str
@property

View File

@ -69,7 +69,7 @@ class Shard(Placement):
return True
@staticmethod
def split_tensor(
def _make_split_tensor(
dim: int,
tensor: torch.Tensor,
num_chunks: int,
@ -119,7 +119,7 @@ class Shard(Placement):
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[list[torch.Tensor], list[int]]:
return Shard.split_tensor(
return Shard._make_split_tensor(
self.dim,
tensor,
num_chunks,
@ -171,7 +171,7 @@ class Shard(Placement):
return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)
@staticmethod
def shard_tensor(
def _make_shard_tensor(
dim: int,
tensor: torch.Tensor,
mesh: DeviceMesh,
@ -194,13 +194,13 @@ class Shard(Placement):
if src_data_rank is None:
# src_data_rank specified as None explicitly means to skip the
# communications, simply split
scatter_list, _ = Shard.split_tensor(
scatter_list, _ = Shard._make_split_tensor(
dim, tensor, num_chunks, with_padding=False, contiguous=True
)
return scatter_list[mesh_dim_local_rank]
scatter_list, pad_sizes = Shard.split_tensor(
scatter_list, pad_sizes = Shard._make_split_tensor(
dim, tensor, num_chunks, with_padding=True, contiguous=True
)
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
@ -224,7 +224,7 @@ class Shard(Placement):
mesh_dim: int,
src_data_rank: Optional[int] = 0,
) -> torch.Tensor:
return Shard.shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)
return Shard._make_shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)
def _reduce_shard_tensor(
self,
@ -246,7 +246,7 @@ class Shard(Placement):
is_padded = tensor.size(self.dim) % num_chunks != 0
if is_padded:
scattered_list, pad_sizes = Shard.split_tensor(
scattered_list, pad_sizes = Shard._make_split_tensor(
self.dim, tensor, num_chunks, with_padding=True, contiguous=True
)
tensor = torch.cat(scattered_list, dim=self.dim)
@ -641,7 +641,7 @@ class Replicate(Placement):
return "R"
@staticmethod
def replicate_tensor(
def _make_replicate_tensor(
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
@ -670,7 +670,7 @@ class Replicate(Placement):
mesh_dim: int,
src_data_rank: Optional[int] = 0,
) -> torch.Tensor:
return Replicate.replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)
return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank)
def is_replicate(self) -> bool:
return True