Compare commits

...

7 Commits

Author SHA1 Message Date
907d20f41c Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where every $x_i$ is a shard on the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop over the placements from right to left and construct the split_factor under each assumed shard order. Starting with $x_4$: it is always a normal shard.

Next, $x_3$. There are two possibilities: if $x_3$'s order is before $x_4$'s, then $x_3$'s sf is 1, because $x_3$ already precedes $x_4$ in the placements; if $x_3$'s order is after $x_4$'s, then $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same reasoning to decide the split factor for $x_2$, $x_1$, and so on.
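
To make the rule concrete, here is a minimal standalone sketch (not the PR's implementation; the helper `split_factors` and the flat `mesh_sizes` tuple are simplifications for illustration only). A placement's sf is the product of the mesh sizes of all placements on the same tensor dim that come earlier in the shard order but later in the placement tuple:

```python
import math

def split_factors(mesh_sizes, shard_order):
    # mesh_sizes[i]: size of mesh dim i; shard_order: the mesh dims (all
    # sharding the same tensor dim) in the order the shards are applied.
    # A mesh dim's sf is the product of the sizes of the mesh dims that come
    # earlier in the shard order but later (larger index) in the placements.
    return {
        mesh_dim: math.prod(
            mesh_sizes[earlier]
            for earlier in shard_order[:idx]
            if earlier > mesh_dim
        )
        for idx, mesh_dim in enumerate(shard_order)
    }

# Mesh sizes (4, 3, 2); tensor dim 0 sharded on mesh dims 2, then 0, then 1.
print(split_factors((4, 3, 2), (2, 0, 1)))  # {2: 1, 0: 2, 1: 2}
```

With those split factors the placements become `(_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))`, matching the docstring example in **_dtensor_spec.py**.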

### How to convert StridedShard to shard order:
This follows the method above in reverse. We enumerate all possible orderings (paths) and check which one reproduces the actual split_factor. If no path matches, the StridedShard cannot be converted to a shard order.
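
Below is a matching sketch of that search, simplified to a single sharded tensor dim (the helper `recover_order` and its plain-dict input are hypothetical, not this PR's API). It walks the placements from right to left and inserts each mesh dim at the position whose accumulated size product equals the observed split_factor; if no position works, the conversion fails:

```python
def recover_order(mesh_sizes, placement_sf):
    # placement_sf maps each mesh dim sharding the tensor dim to the
    # split_factor seen in its Shard/_StridedShard placement (Shard -> 1).
    order = []
    for mesh_dim in sorted(placement_sf, reverse=True):  # right to left
        sf, acc, placed = placement_sf[mesh_dim], 1, False
        for pos in range(len(order) + 1):
            if acc == sf:
                order.insert(pos, mesh_dim)
                placed = True
                break
            if pos < len(order):
                acc *= mesh_sizes[order[pos]]
        if not placed:
            return None  # no ordering reproduces this split_factor
    return tuple(order)

print(recover_order((4, 3, 2), {2: 1, 0: 2, 1: 2}))  # (2, 0, 1)
print(recover_order((4, 8), {0: 2, 1: 1}))           # None: not convertible
```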

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-06 14:00:38 -08:00
a3a278f422 Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where every $x_i$ is a shard on the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop over the placements from right to left and construct the split_factor under each assumed shard order. Starting with $x_4$: it is always a normal shard.

Next, $x_3$. There are two possibilities: if $x_3$'s order is before $x_4$'s, then $x_3$'s sf is 1, because $x_3$ already precedes $x_4$ in the placements; if $x_3$'s order is after $x_4$'s, then $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />

We can apply the same reasoning to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the method above in reverse. We enumerate all possible orderings (paths) and check which one reproduces the actual split_factor. If no path matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-06 13:34:16 -08:00
28376ac499 Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where every $x_i$ is a shard on the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop over the placements from right to left and construct the split_factor under each assumed shard order. Starting with $x_4$: it is always a normal shard.

Next, $x_3$. There are two possibilities: if $x_3$'s order is before $x_4$'s, then $x_3$'s sf is 1, because $x_3$ already precedes $x_4$ in the placements; if $x_3$'s order is after $x_4$'s, then $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="811" height="434" alt="image" src="https://github.com/user-attachments/assets/8acf6b3e-5f97-448d-990b-a6f3de1d1077" />

We can apply the same reasoning to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the method above in reverse. We enumerate all possible orderings (paths) and check which one reproduces the actual split_factor. If no path matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-03 15:38:25 -08:00
2dfa23dd11 Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where every $x_i$ is a shard on the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop over the placements from right to left and construct the split_factor under each assumed shard order. Starting with $x_4$: it is always a normal shard.

Next, $x_3$. There are two possibilities: if $x_3$'s order is before $x_4$'s, then $x_3$'s sf is 1, because $x_3$ already precedes $x_4$ in the placements; if $x_3$'s order is after $x_4$'s, then $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="811" height="434" alt="image" src="https://github.com/user-attachments/assets/8acf6b3e-5f97-448d-990b-a6f3de1d1077" />

We can apply the same reasoning to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the method above in reverse. We enumerate all possible orderings (paths) and check which one reproduces the actual split_factor. If no path matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-03 13:58:37 -08:00
800feff3ec Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.





cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-31 12:01:22 -07:00
e8a4f23538 Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions that convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For the review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.





cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-31 11:33:16 -07:00
6b638be0e9 Support convert StridedShard to shard order and vice versa
[ghstack-poisoned]
2025-10-31 11:02:00 -07:00
4 changed files with 391 additions and 145 deletions

View File

@@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import itertools
import unittest
@@ -22,12 +21,15 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
distribute_tensor as _distribute_tensor,
generate_shard_orders,
instantiate_parametrized_tests,
make_full_tensor,
parametrize,
redistribute,
run_tests,
TEST_CUDA,
TEST_HPU,
@@ -785,88 +787,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
else:
return ""
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
self,
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
self,
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return self.redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def full_tensor(self, dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return self.redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def _shard_order_to_placement(self, shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
@with_comms
def test_ordered_redistribute(self):
"""Test ordered redistribution with various sharding syntaxes"""
@@ -927,13 +847,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
sharding_src_dst_pairs_with_expected_trace
):
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(), mesh, src_placement, shard_order=src_order
)
with DebugMode(record_torchfunction=False) as debug_mode:
sharded_dt = self.redistribute(
sharded_dt, mesh, dst_placement, dst_order
)
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
@@ -957,49 +875,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
trace_str,
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
)
expected_dt = self.distribute_tensor(
expected_dt = _distribute_tensor(
input_data.clone(), mesh, dst_placement, shard_order=dst_order
)
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
def generate_shard_orders(self, mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n, k):
if k == 1:
yield [n]
else:
for i in range(1, n - k + 2):
for tail in compositions(n - i, k - 1):
yield [i] + tail
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
@with_comms
def test_generate_shard_orders(self):
"""Check if `generate_shard_orders` generates unique sharding combinations"""
@@ -1012,7 +892,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
]
for test_input in test_inputs:
all_combinations = []
for shard_order in self.generate_shard_orders(
for shard_order in generate_shard_orders(
test_input["mesh"], test_input["tensor_rank"]
):
all_combinations.append(shard_order) # noqa: PERF402
@@ -1062,12 +942,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
shard_orders = generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
# 2. Verify the correctness of redistribution from DTensor to DTensor.
# This test repeatedly redistributes a DTensor to various ordered
@@ -1078,20 +958,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tensor_rank = input_data.ndim
prev_sharded_dt = None
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
shard_orders = generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = self.distribute_tensor(
prev_sharded_dt = _distribute_tensor(
input_data.clone(),
mesh,
placements=None,
shard_order=shard_order,
)
else:
sharded_dt = self.redistribute(
sharded_dt = redistribute(
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
prev_sharded_dt = sharded_dt
@with_comms
@@ -1136,13 +1016,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, len(shape))
shard_orders = generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = self.redistribute(
sharded_dt = redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
)
@unittest.skip(
@@ -1161,15 +1041,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tgt_placement = [
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
]
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(),
mesh,
src_placement,
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
)
sharded_dt = self.redistribute(
sharded_dt, mesh, tgt_placement, shard_order=None
)
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
@@ -1179,7 +1057,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = self.distribute_tensor(
x_ordered_dt = _distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],

View File

@@ -4,8 +4,9 @@ import itertools
from typing import Any
import torch
from torch.distributed._local_tensor import LocalTensorMode
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
@@ -23,9 +24,16 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
distribute_tensor as _distribute_tensor,
generate_shard_orders,
run_tests,
shard_order_to_placement,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
LocalDTensorTestBase,
with_comms,
)
@@ -766,6 +774,63 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32
@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to complete total 2520 shard order combinations here
# using LocalTensor. So we only randomly pick 25 shard orders to test.
all_shard_order = list(shard_iter)
import random
random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)
x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)
placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())
@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):

View File

@@ -1,4 +1,5 @@
import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
@@ -7,6 +8,7 @@ import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
_StridedShard,
MaskPartial,
Partial,
Placement,
Replicate,
@@ -127,6 +129,185 @@ class DTensorSpec:
)
return default_shard_order
@staticmethod
def _convert_shard_order_to_StridedShard(
shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
) -> tuple[Placement, ...]:
"""
Convert ShardOrder to placements with _StridedShard.
This function converts a ShardOrder specification into a tuple of Placement objects,
using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
in a non-default order. The split_factor of each _StridedShard is determined by the
product of mesh dimension sizes that appear earlier in the shard order but later in
the placement tuple.
Args:
shard_order: ShardOrder specification indicating which tensor dimensions are
sharded on which mesh dimensions and in what execution order.
placements: Tuple of Placement objects that does not contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
Updated tuple of Placement objects with Shard or _StridedShard placements.
Algorithm:
For each ShardOrderEntry in shard_order:
- For each mesh dimension in the entry's mesh_dims (in order):
- Calculate split_factor as the product of mesh sizes for all mesh dimensions
that appear:
1. Earlier in the shard order (lower index in mesh_dims), and
2. Later in the placement tuple (higher mesh dimension index)
- If split_factor == 1: use normal Shard
- Otherwise: use _StridedShard with the calculated split_factor
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
>>> placements = (Shard(0), Shard(0), Shard(0))
>>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
>>> # -> placements[2] = Shard(0)
>>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[0] = _StridedShard(0, split_factor=2)
>>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[1] = _StridedShard(0, split_factor=2)
>>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
"""
placements_list = list(placements)
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for idx in range(len(mesh_dims)):
# TODO(zpcore): split_factor from `view` and `shard order`
# should be able to be multiplied into one. Need to loosen the
# condition here.
if type(placements[idx]) is not Shard:
raise ValueError(
f"Only Shard placement can be converted to _StridedShard, "
f"found {placements[idx]} in {placements=}."
)
mesh_dim = mesh_dims[idx]
split_factor = math.prod(
mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
)
if split_factor == 1:
# use normal Shard
placements_list[mesh_dim] = Shard(tensor_dim)
else:
placements_list[mesh_dim] = _StridedShard(
tensor_dim, split_factor=split_factor
)
return tuple(placements_list)
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
"""
Try to convert _StridedShard placements to ShardOrder.
This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
order by examining the split_factor of each _StridedShard and determining its position
in the execution order. If the _StridedShard configuration cannot be represented as a
valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
this function returns None.
Args:
placements: Tuple of Placement objects that may contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
ShardOrder if conversion is possible, None otherwise. For placements without
_StridedShard, returns the default shard order.
Algorithm:
1. If no _StridedShard in placements, return default shard order
2. Create an empty list for each tensor dimension to represent mesh dim ordering
3. Iterate through placements in reverse order (right to left):
- For each Shard/_StridedShard on a tensor dimension:
- Extract its split_factor (1 for Shard, split_factor for _StridedShard)
- Find the position in mesh_dims_order where accumulated_sf equals split_factor
- accumulated_sf is the product of mesh sizes of mesh dimensions that appear
earlier in mesh_dims_order (lower indices)
- Insert mesh_dim at the found position
4. If no valid position found for any split_factor, return None (unable to convert)
5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
>>> # Process tensor_dim=0 from right to left:
>>> # - mesh_dim=2: Shard(0) with sf=1
>>> # Try position 0: accumulated_sf=1, matches! Insert at position 0
>>> # Current mesh_dims_order order: [2]
>>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Current mesh_dims_order order: [2, 1]
>>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Final mesh_dims_order order: [2, 0, 1]
>>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
>>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1
Note:
This function validates that _StridedShard can be represented as a ShardOrder.
Not all _StridedShard configurations are valid - the split_factor must match
the product of mesh sizes in some execution order.
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return DTensorSpec.compute_default_shard_order(placements)
max_tensor_dim = (
max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
)
shard_order = []
tensor_dim_to_mesh_dims_order: list[list[int]] = [
[] for i in range(max_tensor_dim)
]
for mesh_dim in reversed(range(len(placements))):
cur_placement = placements[mesh_dim]
# _StridedShard may not be a subclass of Shard in the future, so write in this way:
if isinstance(cur_placement, Shard | _StridedShard):
tensor_dim = cur_placement.dim
mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
cur_sf = 1
if isinstance(cur_placement, _StridedShard):
cur_sf = cur_placement.split_factor
accumulated_sf = 1
find_order = False
for i in range(len(mesh_dims_order) + 1):
if accumulated_sf == cur_sf:
mesh_dims_order.insert(i, mesh_dim)
find_order = True
break
if i < len(mesh_dims_order):
accumulated_sf *= mesh.size(mesh_dims_order[i])
if not find_order:
# _StridedShard is not convertible to ShardOrder
return None
else:
if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
raise ValueError(
f"Unsupported placement type {type(cur_placement)} encountered in "
f"{placements}; expected Replicate, Partial, or MaskPartial."
)
for tensor_dim in range(max_tensor_dim):
if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
shard_order.append(
ShardOrderEntry(
tensor_dim=tensor_dim,
mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
)
)
return tuple(shard_order)
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
"""Verify that the shard_order is valid and matches the placements."""
total_shard = 0

View File

@@ -5878,3 +5878,125 @@ def patch_test_members(updates: dict[str, Any]):
return wrapper
return decorator
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor as native_distribute_tensor
from torch.distributed.tensor._redistribute import redistribute_local_tensor
import itertools
def _convert_shard_order_dict_to_ShardOrder(shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = native_distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def make_full_tensor(dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def shard_order_to_placement(shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def generate_shard_orders(mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n: int, k: int):
# yields lists of length k, positive ints summing to n
for cuts in itertools.combinations(range(1, n), k - 1):
# add 0 and n as sentinels, then take consecutive differences
yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield _convert_shard_order_dict_to_ShardOrder(shard_order)