Compare commits


6 Commits

Author SHA1 Message Date
771008b97d Update on "[DTensor] support flatten/unfllaten with _StridedSharding"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
2025-10-31 11:59:01 -07:00
43e4a9e07c Update on "[DTensor] support flatten/unfllaten with _StridedSharding"
2025-10-30 14:12:21 -07:00
f4a5296a73 Update on "[DTensor] support flatten/unfllaten with _StridedSharding"
2025-10-30 00:36:44 -07:00
7d73419996 Update on "[DTensor] support flatten/unfllaten with _StridedSharding"
2025-10-29 17:46:08 -07:00
bb4030c675 Update on "[DTensor] support flatten/unfllaten with _StridedSharding"
2025-10-29 15:56:48 -07:00
fc07679dd9 [DTensor] support flatten/unfllaten with _StridedSharding
2025-10-28 18:05:10 -07:00
5 changed files with 227 additions and 37 deletions

View File

@ -30,6 +30,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
skip_unless_torch_gpu,
with_comms,
)
from torch.distributed.tensor.placement_types import (
_StridedShard,
)
funcol = torch.ops.c10d_functional
@ -239,6 +242,19 @@ class DistMatrixOpsTest(DTensorTestBase):
self.assertEqual(y, dy.full_tensor())
@with_comms
def test_mm_with_strided_input(self):
mesh = self.build_device_mesh()
batch_size, seq_len, contract_dim, out_dim = 2, 4, 3, 7
global_inps = torch.arange(batch_size * seq_len * contract_dim, device="cuda").float().view(batch_size, seq_len, contract_dim)
inps = distribute_tensor(global_inps, mesh, (Shard(1), ))
inps_viewed = inps.view(batch_size * seq_len, contract_dim)
global_weight = torch.arange(contract_dim * out_dim).float().view(contract_dim, out_dim)
weight = distribute_tensor(global_weight, mesh, (Replicate(), ))
out = torch.mm(inps_viewed, weight)
expected_placements = (_StridedShard(dim=0, split_factor=2),)
self.assertEqual(out.placements, expected_placements)
@with_comms
def test_t(self):
device_mesh = self.build_device_mesh()

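The assertion above expects the mm output to keep _StridedShard(dim=0, split_factor=2) rather than fall back to a plain Shard(0). The snippet below is a minimal, torch-free sketch of what that placement means, assuming the usual reading of _StridedShard (interleaved ownership along the dim) and a hypothetical 4-rank mesh; the helper names are made up for illustration.

def contiguous_shard_indices(size: int, mesh_size: int, rank: int) -> list[int]:
    """Rows a rank owns under a plain Shard(0) placement."""
    chunk = size // mesh_size
    return list(range(rank * chunk, (rank + 1) * chunk))

def strided_shard_indices(size: int, split_factor: int, mesh_size: int, rank: int) -> list[int]:
    """Rows a rank owns under _StridedShard(0, split_factor=...): the dim is cut into
    split_factor contiguous blocks, each block is sharded across the mesh dim, and a
    rank's shard is the concatenation of its piece from every block."""
    block = size // split_factor
    piece = block // mesh_size
    return [
        b * block + rank * piece + i
        for b in range(split_factor)
        for i in range(piece)
    ]

# With batch_size=2 and seq_len=4 the flattened dim 0 has 8 rows. On the assumed
# 4-rank mesh, rank 1 would own rows [2, 3] under Shard(0) but rows [1, 5] under
# _StridedShard(0, split_factor=2) -- exactly the rows it already held as Shard(1)
# before the view, so no data movement is implied.
print(contiguous_shard_indices(8, 4, 1))   # [2, 3]
print(strided_shard_indices(8, 2, 4, 1))   # [1, 5]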
View File

@ -632,6 +632,94 @@ class TestViewOps(DTensorTestBase):
)
self.assertEqual(len(comm_mode.get_comm_counts()), 0)
@with_comms
def test_dtensor_flatten(self):
# 1D case
mesh = init_device_mesh(self.device_type, (self.world_size,))
batch_size, seq_len, dim = 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim).view(
batch_size, seq_len, dim
)
inps = distribute_tensor(global_inps, mesh, (Shard(1),))
inps_viewed = inps.view(batch_size * seq_len, dim)
expected_placements = (_StridedShard(dim=0, split_factor=6),)
self.assertEqual(inps_viewed.placements, expected_placements)
# 2D case: S1, S2
mesh = init_device_mesh(self.device_type, (self.world_size // 2, 2))
batch_size, seq_len, dim1, dim2 = 6, 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim1 * dim2).view(
batch_size, seq_len, dim1, dim2
)
inps = distribute_tensor(global_inps, mesh, (Shard(1), Shard(2)))
inps_viewed = inps.view(batch_size * seq_len * dim1, dim2)
expected_placements = (
_StridedShard(dim=0, split_factor=6),
_StridedShard(dim=0, split_factor=12),
)
self.assertEqual(inps_viewed.placements, expected_placements)
# 2D case: R, S2
mesh = init_device_mesh(self.device_type, (self.world_size // 2, 2))
batch_size, seq_len, dim1, dim2 = 6, 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim1 * dim2).view(
batch_size, seq_len, dim1, dim2
)
inps = distribute_tensor(global_inps, mesh, (Replicate(), Shard(2)))
inps_viewed = inps.view(batch_size * seq_len * dim1, dim2)
expected_placements = (
Replicate(),
_StridedShard(dim=0, split_factor=36),
)
self.assertEqual(inps_viewed.placements, expected_placements)
@with_comms
def test_dtensor_unflatten(self):
# 1D case
mesh = init_device_mesh(self.device_type, (self.world_size,))
batch_size, seq_len, dim = 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim).view(
batch_size * seq_len, dim
)
inps = distribute_tensor(global_inps, mesh, (_StridedShard(0, split_factor=6),))
inps_viewed = inps.view(batch_size, seq_len, dim)
expected_placements = (Shard(1),)
self.assertEqual(inps_viewed.placements, expected_placements)
# 2D case: S1, S2
mesh = init_device_mesh(self.device_type, (self.world_size // 2, 2))
batch_size, seq_len, dim1, dim2 = 6, 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim1 * dim2).view(
batch_size * seq_len * dim1, dim2
)
inps = distribute_tensor(
global_inps,
mesh,
(
_StridedShard(dim=0, split_factor=6),
_StridedShard(dim=0, split_factor=12),
),
)
inps_viewed = inps.view(batch_size, seq_len, dim1, dim2)
expected_placements = (Shard(1), Shard(2))
self.assertEqual(inps_viewed.placements, expected_placements)
# 2D case: R, S2
mesh = init_device_mesh(self.device_type, (self.world_size // 2, 2))
batch_size, seq_len, dim1, dim2 = 6, 6, 6, 3
global_inps = torch.arange(batch_size * seq_len * dim1 * dim2).view(
batch_size * seq_len * dim1, dim2
)
inps = distribute_tensor(
global_inps, mesh, (Replicate(), _StridedShard(dim=0, split_factor=36))
)
inps_viewed = inps.view(batch_size, seq_len, dim1, dim2)
expected_placements = (
Replicate(),
Shard(2),
)
self.assertEqual(inps_viewed.placements, expected_placements)
@with_comms
def test_view_redistribution(self):
"""
@ -670,6 +758,7 @@ TestViewOpsWithLocalTensor = create_local_tensor_test_class(
skipped_tests=[
# Comparing data pointers is not supported for local tensor
"test_dtensor_view_op_uneven",
"test_dtensor_flatten",
],
)

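test_dtensor_flatten and test_dtensor_unflatten assert that view() converts Shard placements into _StridedShard placements and back without redistribution. A small torch-free sketch of why that round-trip is free, assuming the interleaved-ownership reading of _StridedShard and a hypothetical 3-rank mesh (both helper names are made up for illustration):

def rows_owned_strided(flat_size: int, split_factor: int, mesh_size: int, rank: int) -> set[int]:
    """Rows of the flattened dim a rank owns under _StridedShard(0, split_factor=...)."""
    block = flat_size // split_factor
    piece = block // mesh_size
    return {
        b * block + rank * piece + i
        for b in range(split_factor)
        for i in range(piece)
    }

def rows_owned_after_shard1(batch: int, seq: int, mesh_size: int, rank: int) -> set[int]:
    """Rows of the flattened batch * seq dim a rank owns when the unflattened tensor
    was sharded as Shard(1), i.e. contiguously along seq."""
    piece = seq // mesh_size
    return {
        b * seq + s
        for b in range(batch)
        for s in range(rank * piece, (rank + 1) * piece)
    }

# With batch = seq = 6 flattened into 36 rows on the assumed 3-rank mesh, every rank
# owns exactly the same rows either way, which is why flatten can rewrite Shard(1) as
# _StridedShard(0, split_factor=6) and unflatten can rewrite it back, with no
# communication in either direction.
assert all(
    rows_owned_strided(36, 6, 3, rank) == rows_owned_after_shard1(6, 6, 3, rank)
    for rank in range(3)
)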
View File

@ -33,6 +33,7 @@ from torch.distributed.tensor.placement_types import (
Placement,
Replicate,
Shard,
_StridedShard,
)
@ -93,6 +94,9 @@ def _mm_like_strategy(
generate_redistribute_costs(mat2_strategy, mat2_spec),
]
strtg.redistribute_cost = redistribute_cost
# Keep the input's existing _StridedShard placement (instead of forcing a plain
# Shard on the same tensor dim) so the candidate strategy does not require a
# redistribution; the output spec follows the input spec.
if (
    len(self_strategy.strategies) == 1
    and len(self_strategy.strategies[0].output_specs.placements) == 1
    and len(self_spec.placements) == 1
    and self_spec.placements[0].is_shard()
    and isinstance(
        self_strategy.strategies[0].output_specs.placements[0], _StridedShard
    )
    and self_spec.placements[0].dim
    == self_strategy.strategies[0].output_specs.placements[0].dim
    and strtg.output_specs == strtg.input_specs[0]
):
    strtg.input_specs[0] = self_strategy.strategies[0].output_specs
    strtg.output_specs = strtg.input_specs[0]
filtered_strategies.append(strtg)
mm_strategy.strategies = filtered_strategies

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import cast, Optional, Union
@ -502,7 +503,7 @@ dim_maps: dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
def propagate_shape_and_sharding(
input_src_placements: Sequence[Placement],
input_src_spec,
global_input_shape: Shape,
rule: DimMap,
mesh_sizes: Shape,
@ -519,6 +520,7 @@ def propagate_shape_and_sharding(
- An output dimension that is a split of the input dimension can only be sharded
if the leftmost split size is divisible by the mesh dimension
"""
input_src_placements: Sequence[Placement] = input_src_spec.placements
if not len(input_src_placements) == len(mesh_sizes):
raise AssertionError(f"{input_src_placements} != {mesh_sizes}")
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
@ -541,12 +543,17 @@ def propagate_shape_and_sharding(
shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim
def maybe_get_shard_mesh_dim_and_placement(
input_dim: InputDim,
input_dim: InputDim, ith_shard: Optional[int] = None
) -> tuple[Optional[int], Optional[Shard]]:
# if input_dim is sharded, return the mesh_dim and shard placement
for i, placement in enumerate(input_src_placements):
if isinstance(placement, Shard) and placement.dim == input_dim.input_dim:
return i, placement
num_shard_placements = 0
for mesh_dim, placement in enumerate(input_src_placements):
if isinstance(placement, Shard):
num_shard_placements += 1
if placement.dim == input_dim.input_dim:
if ith_shard is None:
return mesh_dim, placement
elif (num_shard_placements - 1) == ith_shard:
return mesh_dim, placement
return None, None
# NOTE: This function has three responsibilities:
@ -556,10 +563,13 @@ def propagate_shape_and_sharding(
# 1 and 2 doesn't require the info of whether current input is sharded.
# 3 requires that info, to decide whether we can error out. Maybe we can refactor
# to make this function purely "theoretical".
def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
def get_in_dim_to_shard(
cmd: DimSpec, shard_dim_map
) -> Optional[InputDim | list[InputDim]]:
if isinstance(cmd, InputDim):
return cmd
elif isinstance(cmd, Flatten):
sharded_dims = []
for i, dim in enumerate(cmd.input_dims):
# so far all Flatten is always composed of InputDims; revisit this if needed
if not isinstance(dim, InputDim):
@ -570,12 +580,10 @@ def propagate_shape_and_sharding(
)
input_sharded = shard_mesh_dim is not None
if i > 0:
can_shard_dim = False
if strict_view and input_sharded:
raise RuntimeError(
f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. ",
"It cannot be performed without redistribution, which is disallowed by the current operator.",
)
for x in range(0, dim.input_dim + 1):
shardable_dims[x] = [True] * mesh_ndim
sharded_dims.append(dim)
elif input_sharded:
if not (shard_placement is not None and shard_mesh_dim is not None):
raise AssertionError(
@ -593,14 +601,39 @@ def propagate_shape_and_sharding(
)
shardable_dims[dim.input_dim] = [can_shard_dim] * mesh_ndim
if not isinstance(cmd.input_dims[0], InputDim):
raise AssertionError(
f"Expected InputDim, got {type(cmd.input_dims[0])}"
)
return cmd.input_dims[0]
if len(sharded_dims) > 0:
return sharded_dims
else:
if not isinstance(cmd.input_dims[0], InputDim):
raise AssertionError(
f"Expected InputDim, got {type(cmd.input_dims[0])}"
)
return cmd.input_dims[0]
elif isinstance(cmd, Split):
in_dim = get_in_dim_to_shard(cmd.input_dim)
in_dim = get_in_dim_to_shard(cmd.input_dim, shard_dim_map)
out_size = cmd.group_shape[cmd.split_id]
if in_dim is not None:
# handle inputs that are already strided-sharded, e.g.
# (_StridedShard(dim=0, split_factor=6), _StridedShard(dim=0, split_factor=12)):
# bail out if the existing split_factor does not match the one implied by this split
num_of_shard_placements = len(shard_dim_map.values())
shard_mesh_dim, input_src_placement = (
maybe_get_shard_mesh_dim_and_placement(
in_dim,
num_of_shard_placements,
)
)
if isinstance(input_src_placement, _StridedShard):
split_factor = math.prod(
[1] + list(cmd.group_shape[0 : cmd.split_id])
)
split_factor = split_factor // math.prod(
[1] + list(mesh_sizes[0:num_of_shard_placements])
)
if input_src_placement.split_factor != split_factor:
return None
else:
return in_dim
if cmd.split_id == 0 and in_dim is not None:
# we need to check that the input dimension is divisible
# by the size of the submesh we're sharding it on
@ -637,7 +670,7 @@ def propagate_shape_and_sharding(
# we will only shard our first component of the split
return in_dim if cmd.split_id == 0 else None
elif isinstance(cmd, Repeat):
in_dim = get_in_dim_to_shard(cmd.input_dim)
in_dim = get_in_dim_to_shard(cmd.input_dim, shard_dim_map)
if in_dim is not None:
shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
return None
@ -647,20 +680,42 @@ def propagate_shape_and_sharding(
# for each output dim, find the corresponding input dim in terms of sharding prop
shard_dim_map = {}
for dim, cmd in enumerate(rule):
in_dim = get_in_dim_to_shard(cmd)
if in_dim is not None:
shard_dim_map[in_dim.input_dim] = dim
in_dims = get_in_dim_to_shard(cmd, shard_dim_map)
if isinstance(in_dims, list) and len(in_dims) > 0:
for in_dim in in_dims:
if in_dim is not None:
assert in_dim.input_dim not in shard_dim_map
shard_dim_map[in_dim.input_dim] = [dim]
else:
if in_dims is not None:
if in_dims.input_dim not in shard_dim_map:
shard_dim_map[in_dims.input_dim] = [dim]
else:
shard_dim_map[in_dims.input_dim].append(dim)
input_tgt_placements = [
(
Replicate()
if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]
else p
)
for mesh_dim, p in enumerate(input_src_placements)
]
input_tgt_placements: list[Placement] = []
for mesh_dim, p in enumerate(input_src_placements):
if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]:
input_tgt_placements.append(Replicate())
else:
input_tgt_placements.append(p)
def _rewrite_shard_dim(p: Shard):
def _get_split_factor(input_src_spec, shard_dim):
split_factor = 1
for tensor_dim, global_tensor_size in enumerate(input_src_spec.shape):
if tensor_dim >= shard_dim:
return split_factor
mesh_dim = input_src_spec.dim_map[tensor_dim]
if mesh_dim >= 0:
local_tensor_size = math.ceil(
global_tensor_size / input_src_spec.mesh.shape[mesh_dim]
)
else:
local_tensor_size = global_tensor_size
split_factor = split_factor * local_tensor_size
return split_factor
def _rewrite_shard_dim(p: Shard, input_src_spec):
"""
Rewrite the shard dim to the corresponding tensor dim in output.
For ``_StridedShard``, we can safely keep the placement type and
@ -677,12 +732,24 @@ def propagate_shape_and_sharding(
inner ``dim`` attribute of ``Shard`` or ``_StridedShard``.
"""
if isinstance(p, _StridedShard):
return _StridedShard(shard_dim_map[p.dim], split_factor=p.split_factor)
tgt_shard_dims = shard_dim_map[p.dim]
tgt_shard_dim = tgt_shard_dims.pop(0)
if tgt_shard_dim > 0:
return Shard(tgt_shard_dim)
else:
return _StridedShard(tgt_shard_dim, split_factor=p.split_factor)
else:
return Shard(shard_dim_map[p.dim])
tgt_shard_dim = shard_dim_map[p.dim]
assert len(tgt_shard_dim) == 1
tgt_shard_dim = tgt_shard_dim[0]
if p.dim == 0:
return Shard(tgt_shard_dim)
else:
split_factor = _get_split_factor(input_src_spec, p.dim)
return _StridedShard(tgt_shard_dim, split_factor=split_factor)
output_placements = [
_rewrite_shard_dim(p) if isinstance(p, Shard) else p
_rewrite_shard_dim(p, input_src_spec) if isinstance(p, Shard) else p
for p in input_tgt_placements
]
@ -721,7 +788,7 @@ def register_op_strategy_map(
input_src_spec = input_placement_strategy.output_spec
input_tgt_placements, output_placements = propagate_shape_and_sharding(
input_src_spec.placements,
input_src_spec,
tuple(global_in_shape),
rules,
mesh.shape,

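The _get_split_factor helper added above derives the split factor for a rewritten strided shard from the input spec: the product of the local sizes of every tensor dim before the shard dim, where a dim that is itself sharded (dim_map >= 0) contributes its ceil-divided local size. Below is a standalone, torch-free rendering of that arithmetic, with plain tuples standing in for a DTensorSpec; the mesh shapes in the examples are assumptions, not values taken from the tests' environment.

import math

def get_split_factor(global_shape: tuple[int, ...], dim_map: tuple[int, ...],
                     mesh_shape: tuple[int, ...], shard_dim: int) -> int:
    """Sketch: product of local sizes of all tensor dims strictly before shard_dim.
    dim_map[d] is the mesh dim sharding tensor dim d, or -1 if it is replicated."""
    split_factor = 1
    for tensor_dim in range(min(shard_dim, len(global_shape))):
        mesh_dim = dim_map[tensor_dim]
        if mesh_dim >= 0:
            local_size = math.ceil(global_shape[tensor_dim] / mesh_shape[mesh_dim])
        else:
            local_size = global_shape[tensor_dim]
        split_factor *= local_size
    return split_factor

# The mm test earlier in this diff: input (2, 4, 3) sharded as Shard(1) on a 1D mesh
# (size assumed 4); the shard on tensor dim 1 flattens into dim 0 with split factor 2,
# matching the expected _StridedShard(dim=0, split_factor=2).
print(get_split_factor((2, 4, 3), (-1, 0, -1), (4,), shard_dim=1))           # 2
# The "R, S2" flatten case: (6, 6, 6, 3) with dims 0 and 1 replicated and dim 2
# sharded on the second mesh dim; the factor is 6 * 6 = 36 for any mesh shape.
print(get_split_factor((6, 6, 6, 3), (-1, -1, 1, -1), (3, 2), shard_dim=2))  # 36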
View File

@ -146,14 +146,24 @@ def _compute_local_shape_and_global_offset(
# StridedShard implies a non-standard order to apply shards; get the
# correct order to start applying splits
ordered_placements = _explicit_order_placements(mesh_shape, placements)
all_shards_are_strided = all(
isinstance(p, _StridedShard) for p in placements if isinstance(p, Shard)
)
if all_shards_are_strided:
ordered_placements: Sequence[tuple[int, Placement]] = [
(mesh_dim, p)
for (mesh_dim, p) in enumerate(placements)
if isinstance(p, Shard)
]
else:
ordered_placements = _explicit_order_placements(mesh_shape, placements)
local_shape = list(global_shape)
# We'll compute the data for where the shard begins on a per-dim basis.
# However, a single dim can be sharded multiple times, so we will end up
# doing a Sum(size*stride) like computation to determine the location of our
# shard for each of the shardings on that dim.
global_offset = [0] * len(global_shape)
global_offset: list[int | None] = [0] * len(global_shape)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh_shape[mesh_dim]
@ -170,6 +180,10 @@ def _compute_local_shape_and_global_offset(
local_shape[shard_dim] = shard_size
if isinstance(placement, _StridedShard):
global_offset[shard_dim] = None
continue
shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)
zero_global_offset = global_shape[shard_dim]
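The hunk above special-cases placements in which every shard is a _StridedShard: they are applied in the order given rather than through _explicit_order_placements, and the per-dim global offset is recorded as None, presumably because interleaved ownership has no single contiguous starting offset. A short torch-free sketch of the local-shape arithmetic for two strided shards on the same dim; the 3 x 2 mesh is an assumption, not a value taken from the tests.

import math

# Numbers modeled on the "S1, S2" flatten case: the flattened dim 0 has
# 6 * 6 * 6 = 216 rows and is sharded twice, both times on tensor dim 0.
global_shape = [216, 3]
mesh_shape = (3, 2)
strided_shard_dims = [0, 0]   # tensor dim targeted by each _StridedShard placement

local_shape = list(global_shape)
for mesh_dim, shard_dim in enumerate(strided_shard_dims):
    # A strided shard still splits the dim across its mesh dimension like a regular
    # shard; only *which* rows a rank holds is interleaved, not *how many*.
    local_shape[shard_dim] = math.ceil(local_shape[shard_dim] / mesh_shape[mesh_dim])

print(local_shape)   # [36, 3]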