Compare commits

...

1 Commits

Author SHA1 Message Date
7249ff1f47 [DTensor] StridedShard support uneven sharding
This enables using FSDP+TP on parameters with dimensions that aren't
evenly divisible by the DP/TP mesh sizes.

- this may not support all possible combinations of strided and regular
  shardings, but the support before this PR was not complete anyway

This contains several fixes for different aspects of DTensor behavior
relating to uneven strided sharding:
- original creation of the strided tensor requires fixes in
  StridedShard._split_tensor
- full_tensor() reconstruction requires fixes in
  StridedShard._to_replicate_tensor to correctly reshuffle the data into
  the original pre-sharded order
- Distributed Checkpointing support requires correct results from the
  compute_local_shape_and_global_offset util

Example:  (copied from _StridedShard._to_replicate_tensor docstring)
-------
mesh = (DP=2, TP=2)
original = torch.arange(5)

tp sharded tensor
-----------------
`tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])`

local_tensors:
rank0: [0,1,2]    rank1: [3,4]
rank2: [0,1,2]    rank3: [3,4]

fsdp+tp sharded tensor
----------------------
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated #TODO put an example somewhere and ref to it)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1]  rank1: [3]
rank2: [2]    rank3: [4]

Now, say someone wants to reconstruct dp_tp's full tensor. This invokes 'redistribute' to replicate.
redistribute first replicates the "Shard(0)" placement on the rightmost mesh dim, and then replicates the
StridedShard placement, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.

Note the discrepancy with the 'tp sharded tensor' line above!  We'll fix it by locally shuffling data.

local_tensors:
rank0: [0,1,3]  rank1: [0,1,3]
rank2: [2,4]    rank3: [2,4]

Step 1: replicate over the DP dimension.  Afterwards, each rank can locally sort the values.
  note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4]    rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4]    rank3: [0,1,3,2,4]

Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back

01324#       <- our allgather includes padding, if padding was applied in step 1
01324        <- Remove the padding
013, 24      <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4  <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34      <- interleave with stride=TP mesh dim size
01234        <- concatenate
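
A minimal single-process sketch of Steps 1-2 above (plain PyTorch, no process groups): the DP
all-gather is simulated by concatenating the two local tensors of one DP group,
torch.nn.functional.pad stands in for the internal pad_tensor helper, and the values and mesh
sizes come from the example; this is only an illustration, not the real DTensor code path.

import torch
import torch.nn.functional as F

TP, DP = 2, 2
logical_len = 5                        # original = torch.arange(5)

# after replicating Shard(0) over the TP dim, the two ranks of one DP group hold:
dp_locals = [torch.tensor([0, 1, 3]), torch.tensor([2, 4])]

# Step 1: pad to a common length and "all-gather" over DP, then drop the padding
full_chunk = (logical_len + DP - 1) // DP                        # 3, same rule torch.chunk uses
padded = [F.pad(t, (0, full_chunk - t.numel())) for t in dp_locals]
gathered = torch.cat(padded)[:logical_len]                       # [0, 1, 3, 2, 4]

# Step 2: undo the DP all-gather, undo the earlier (wrong-order) TP all-gather,
# then interleave with stride = split_factor (= TP mesh dim size here) and concatenate
dp_shards = torch.chunk(gathered, DP)                            # [0,1,3], [2,4]
tp_shards = [c for s in dp_shards for c in torch.chunk(s, TP)]   # [0,1], [3], [2], [4]
reordered = [s for i in range(TP) for s in tp_shards[i::TP]]     # [0,1], [2], [3], [4]
print(torch.cat(reordered))                                      # tensor([0, 1, 2, 3, 4])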

ghstack-source-id: 4cd7f9b93f2c91e55c59af177a47b8a01c8559f6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150146

[DTensor] Fix compute_local_shape_and_global_offset for uneven sharding

This fix is needed for cases where distributed checkpointing (DCP) is used to
save a local state dict.  That's because DCP relies on the local shape / global offset
for each rank being correct in order to save files that can be correctly
resharded (or indeed loaded at all).

(If saving a 'full state dict' instead of a local one, DCP would convert
to 'full tensors' before saving, and that logic got fixed in the
previous PR in this stack.)

Also add a util `_explicit_order_placements` which converts a list of
placements containing strided sharding into a list of placements with only
regular sharding, reordered so that it is equivalent.
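
As a sanity check of the offset math this enables, here is a torch-free sketch of the local
shape / global offset each rank should report for the length-5 example in the commit above,
with placements (_StridedShard(0, split_factor=2), Shard(0)) on a row-major (DP=2, TP=2) mesh.
`shard_sizes` mimics torch.chunk's ceil-division rule and is only an illustration, not a real
DTensor API.

def shard_sizes(n: int, chunks: int) -> list[int]:
    full = (n + chunks - 1) // chunks
    return [min(full, max(0, n - i * full)) for i in range(chunks)]

n = 5                                       # global size of dim 0
tp_sizes = shard_sizes(n, 2)                # explicit order shards over TP first: [3, 2]
tp_offsets = [sum(tp_sizes[:i]) for i in range(2)]   # [0, 3]
for dp in range(2):
    for tp in range(2):
        dp_sizes = shard_sizes(tp_sizes[tp], 2)      # then shard that piece over DP
        local = dp_sizes[dp]
        offset = tp_offsets[tp] + sum(dp_sizes[:dp])
        print(f"rank{dp * 2 + tp}: local_shape=({local},) global_offset=({offset},)")
# rank0: (2,) at 0   rank1: (1,) at 3   rank2: (1,) at 2   rank3: (1,) at 4
# which matches the dp_tp local tensors [0,1], [3], [2], [4] shown above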

ghstack-source-id: e387336aedbcfdacfbef9a882c1fe1b0080be996
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150393
2025-04-01 15:29:31 -07:00
10 changed files with 253 additions and 212 deletions

View File

@ -433,9 +433,7 @@ class TestFullyShardShardedParameterDTensor(FSDPTestMultiThread):
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# Use odd dim sizes to test uneven shards
# TODO: change "mlp_dim" back to 9 when uneven sharding
# is supported for FSDP+TP
model = MLP(8, dim_multiplier=3)
model = MLP(9, dim_multiplier=3)
orig_params = [param.detach().clone() for param in model.parameters()]
orig_param_names = [param_name for param_name, _ in model.named_parameters()]
parallelize_module(

View File

@ -1155,9 +1155,7 @@ class TestFullyShardNDTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_2d_mlp_with_nd_mesh, global_mesh),
@ -1230,9 +1228,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_3d_mlp_with_nd_mesh, global_mesh),
@ -1261,6 +1257,12 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
# Checking parameters match the orig model is critical to validate that .full_tensor correctly replicates the
# strided-sharded layers.
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
self.assertIsInstance(p, DTensor)
self.assertEqual(ref_p, p.full_tensor())
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)

View File

@ -115,15 +115,6 @@ class TestFullyShard2DTraining(FSDPTest):
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
# TODO: remove this test when uneven sharding is supported for FSDP+TP
@skip_if_lt_x_gpu(2)
def test_2d_uneven_shard_raise_error(self):
global_mesh = self.init_global_mesh()
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
model = MLPStack(3)
with self.assertRaisesRegex(NotImplementedError, "uneven sharding"):
model.parallelize(tp_mesh, dp_mesh, False)
@skip_if_lt_x_gpu(2)
@skipIfRocm
def test_train_parity_2d_mlp(self):
@ -132,9 +123,7 @@ class TestFullyShard2DTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [3, 16, 17],
},
functools.partial(self._test_train_parity_2d_mlp, global_mesh),
)

View File

@ -458,9 +458,14 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
"""
Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP.
"""
mlp_dim = 5
def _get_base_model(mlp_dim: int = 2):
base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
def _get_base_model(mlp_dim: int = mlp_dim):
base_model = nn.Sequential(
MLP(mlp_dim, dim_multiplier=1),
MLP(mlp_dim, dim_multiplier=1),
MLP(mlp_dim, dim_multiplier=1),
)
return base_model
cm = (
@ -468,13 +473,15 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
if allow_implicit_replication
else contextlib.nullcontext()
)
# Must set 'use_local_output=False' in order to test uneven-sharding case
# see https://github.com/pytorch/pytorch/issues/150336
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
"0.in_proj": ColwiseParallel(use_local_output=False),
"0.out_proj": RowwiseParallel(use_local_output=False),
"1.in_proj": ColwiseParallel(use_local_output=False),
"1.out_proj": RowwiseParallel(use_local_output=False),
"2.in_proj": ColwiseParallel(use_local_output=False),
"2.out_proj": RowwiseParallel(use_local_output=False),
}
if allow_implicit_replication:
# intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
@ -482,15 +489,6 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
tp_parallelize_plan.pop("0.out_proj")
with cm:
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
}
# init device mesh
dp_size = 2
global_mesh_1d = init_device_mesh(
@ -521,7 +519,13 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
# one-step training to modify state dict
inp = torch.randn((2,), device=self.rank)
inp = torch.randn(
(
1,
mlp_dim,
),
device=self.rank,
)
base_model(inp).sum().backward()
base_optim.step()
fsdp2_tp_model(inp).sum().backward()

View File

@ -956,7 +956,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
)
if size == 0:
# when tensor size is 0, there is no padding needed for all the ranks.
expected_pad_sizes = []
expected_pad_sizes = [0] * self.world_size
assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [

View File

@ -114,6 +114,31 @@ class UtilTest(DTensorTestBase):
self.assertEqual(local_shape, expected_local_shape)
self.assertEqual(global_offset, expected_global_offset)
@with_comms
def test_uneven_fsdp_tp_meta_compute(self):
# FSDP + TP uneven sharding
tp_size = 2
dp_size = self.world_size // tp_size
global_mesh = init_device_mesh(
self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
global_tensor_shape = torch.Size([15, 5])
placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
rank = global_mesh.get_rank()
if dp_size == 2:
expected_shapes = [4, 4, 4, 3]
expected_offsets = [0, 8, 4, 12]
elif dp_size == 4:
expected_shapes = [2, 2, 2, 2, 2, 2, 2, 1]
expected_offsets = [0, 8, 2, 10, 4, 12, 6, 14]
else:
raise RuntimeError("Expected dp_size 2 or 4")
self.assertEqual(local_shape[0], expected_shapes[rank])
self.assertEqual(global_offset[0], expected_offsets[rank])
@with_comms
def test_hsdp_tp_meta_compute(self):
# HSDP + TP sharding

View File

@ -327,16 +327,6 @@ class FSDPParam:
self._spmd_placements,
tensor_meta=self._tp_spec.tensor_meta,
)
# TODO: Enable uneven sharding for FSDP+TP.
if split_factor > 1: # FSDP has strided sharding on tensor dim 0
num_shards = self._sharding_spec.num_shards_map[0]
tensor_size_dim_0 = self._sharding_spec.shape[0]
if tensor_size_dim_0 % num_shards != 0:
raise NotImplementedError(
"FSDP+TP sharding does not support uneven sharding for now: "
f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
f"evenly sharded into {num_shards} shards."
)
param_data = cast(DTensor, param)._local_tensor
else:
self._spmd_mesh = self.mesh_info.mesh

View File

@ -200,9 +200,7 @@ def fill_empty_tensor_to_shards(
if num_empty_tensors == 0:
return shards
tensor_size = list(shards[0].size())
tensor_size = [
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
tensor_size[shard_dim] = 0
tensor = shards[0].new_zeros(tensor_size)
shards.extend(tensor for _ in range(num_empty_tensors))
return shards

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Sequence
from typing import cast
@ -15,6 +16,67 @@ from torch.distributed.tensor.placement_types import (
)
# TODO(whc) add tests for this util
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:
"""
Replace Strided Shards with regular shards in an adjusted order.
Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order.
ex.
[Shard(0), _StridedShard(0, split_factor=2), Shard(0)] ->
[(0, Shard(0)), (2, Shard(0)), (1, Shard(0))]
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return list(enumerate(placements))
ordered = []
deferred_strided_placements = defaultdict(list)
strided_part_ended_for_dim = set()
for mesh_dim, p in enumerate(placements):
if isinstance(p, _StridedShard):
# validate the stride is the correct multiple of the meshdim and the earlier shard
deferred_strided_placements[p.dim].append((mesh_dim, p))
else:
ordered.append((mesh_dim, p))
if isinstance(p, Shard):
if p.dim in strided_part_ended_for_dim:
raise NotImplementedError(
f"Strided sharding does not allow Shard() to appear after "
f"the strided part has ended. {p} at mesh dim {mesh_dim} in "
f"{placements} violates this assumption."
)
if p.dim in deferred_strided_placements:
strided_part_ended_for_dim.add(p.dim)
strided_placements = deferred_strided_placements.pop(p.dim)
aggregate_size = mesh_shape[mesh_dim]
while len(strided_placements) > 0:
# We can process multiple strided shardings in reverse-order
# (e.g. [_StridedShard(0, split_factor=4), _StridedShard(0, split_factor=2), Shard(0)])
# TODO- validate this logic and enable it (mainly, validate aggregate_size part)
if len(strided_placements) > 1:
raise NotImplementedError(
"NYI nested strided sharding conversion to ordered"
)
strided_dim, strided = strided_placements.pop()
if not strided.split_factor == aggregate_size:
raise RuntimeError(
f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})"
f" == aggregate mesh size ({aggregate_size})"
)
aggregate_size *= mesh_shape[strided_dim]
ordered.append((strided_dim, Shard(p.dim)))
return ordered
# TODO(whc) the big huge NOTE below- can we change it now or is it all still relevant?
def compute_local_shape_and_global_offset(
global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[tuple[int, ...], tuple[int, ...]]:
@ -22,39 +84,9 @@ def compute_local_shape_and_global_offset(
Compute the local tensor shape and the global offsets into the original tensor
of a DTensor on its current global rank. This is useful for checkpointing purpose.
Example (2 host with 4GPUs each):
# Below is a DeviceMesh with mesh_shape of (2, 4)
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
],
)
Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh
with a placements of [Shard(0), Shard(0)].
The local shape and global offset will be as follows:
rank0 -- local_shape:[1, 4], global_offset:[0, 0]
rank1 -- local_shape:[1, 4], global_offset:[1, 0]
rank2 -- local_shape:[1, 4], global_offset:[2, 0]
rank5 -- local_shape:[1, 4], global_offset:[5, 0]
rank3 -- local_shape:[1, 4], global_offset:[3, 0]
rank4 -- local_shape:[1, 4], global_offset:[4, 0]
rank6 -- local_shape:[1, 4], global_offset:[6, 0]
rank7 -- local_shape:[1, 4], global_offset:[7, 0]
Let's say we distribute a global_tensor of shape (2) over the above DeviceMesh with
a placements of [Shard(0)]. We will not have non-empty local tensor for all the ranks.
The local shape and global offset will be as follows:
rank0 -- local_shape:[1,], global_offset:[0,]
rank1 -- local_shape:[1,], global_offset:[1,]
rank2 -- local_shape:[0,], global_offset:[2,]
rank5 -- local_shape:[0,], global_offset:[2,]
rank3 -- local_shape:[0,], global_offset:[2,]
rank4 -- local_shape:[0,], global_offset:[2,]
rank6 -- local_shape:[0,], global_offset:[2,]
rank7 -- local_shape:[0,], global_offset:[2,]
"""
ordered_placements = _explicit_order_placements(mesh.shape, placements)
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
@ -63,13 +95,8 @@ def compute_local_shape_and_global_offset(
else:
local_shape = list(global_shape)
global_offset = [0] * len(global_shape)
shard_idx_stride_by_mesh_dim = [
[0] * mesh.ndim for _ in range(len(global_shape))
] # index by (shard_dim, mesh_dim)
num_shards_by_tensor_dim = [1] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh.size(mesh_dim)
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
@ -79,7 +106,7 @@ def compute_local_shape_and_global_offset(
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[idx],
my_coordinate[mesh_dim],
return_offset=True,
)
@ -95,8 +122,6 @@ def compute_local_shape_and_global_offset(
else:
global_offset[shard_dim] += local_offset[shard_dim]
num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
@ -120,46 +145,6 @@ def compute_local_shape_and_global_offset(
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
if strided_sharding:
strided_part_seen = [False] * len(global_shape)
strided_part_end = [False] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
if strided_part_end[shard_dim]:
raise NotImplementedError(
f"Strided sharding does not allow Shard() to appear after "
f"the strided part has ended. {placement} at idx {idx} in "
f"{placements} violates this assumption."
)
if strided_part_seen[shard_dim]:
strided_part_end[shard_dim] = True
if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
// (placement.split_factor * mesh_dim_size)
)
else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
)
shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
for shard_dim, shard_idx_stride in enumerate(
shard_idx_stride_by_mesh_dim
)
]
global_offset = [x * y for x, y in zip(local_shape, shard_idx)]
return tuple(local_shape), tuple(global_offset)

View File

@ -89,39 +89,22 @@ class Shard(Placement):
# chunk tensor over dimension `dim` into n slices
tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
num_empty_tensors = num_chunks - len(tensor_list)
# if no need to have padding or tensor dim size is evenly sharded already
# we can return early.
if not with_padding or tensor.size(self.dim) % num_chunks == 0:
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return (
fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors),
[],
)
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, num_chunks - len(tensor_list)
)
# compute the chunk size inline with ``torch.chunk`` to calculate padding
full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
# Compute chunk size for each chunk for ``self.dim``
chunk_sizes = [
tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0
for idx in range(num_chunks)
]
# Compute pad size on each chunk
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
# Reuse tensor to fill empty chunk with empty tensor
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, num_empty_tensors
)
shard_list = []
for shard, pad_size in zip(tensor_list, pad_sizes):
# Fill the empty tensor with zeroes with padding.
if with_padding and pad_size > 0:
shard_list: list[torch.Tensor] = []
pad_sizes: list[int] = []
for shard in tensor_list:
if with_padding:
pad_size = full_chunk_size - shard.size(self.dim)
shard = pad_tensor(shard, self.dim, pad_size)
shard = shard.contiguous() if contiguous else shard
pad_sizes.append(pad_size)
if contiguous:
shard = shard.contiguous()
shard_list.append(shard)
return shard_list, pad_sizes
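
A small sketch of the padding bookkeeping in the rewritten Shard._split_tensor above, using plain
torch.chunk and torch.nn.functional.pad in place of the internal fill_empty_tensor_to_shards and
pad_tensor helpers; the tensor size and chunk count are made up for illustration.

import torch
import torch.nn.functional as F

x = torch.arange(5)
num_chunks = 4
chunks = list(torch.chunk(x, num_chunks, dim=0))                 # sizes [2, 2, 1]
chunks += [x.new_zeros(0)] * (num_chunks - len(chunks))          # fill the missing shard with an empty tensor
full_chunk_size = (x.size(0) + num_chunks - 1) // num_chunks     # 2, same rule torch.chunk uses
pad_sizes = [full_chunk_size - c.size(0) for c in chunks]        # [0, 0, 1, 2]
shards = [F.pad(c, (0, p)) for c, p in zip(chunks, pad_sizes)]   # every shard now has length 2
# pad_sizes is returned alongside the shards so the padding can be stripped after the collective
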
@ -134,6 +117,11 @@ class Shard(Placement):
) -> tuple[int, int]:
"""
returns the local shard size and offset on a given tensor dim
# TODO(whc)
- `size_on_dim` arg has a pretty confusing name, would 'global_dim_size' be better?
- `local_shard_size_on_dim` could be renamed to `local_shard_size`?
- why do we have the `return_offset` bool? if we always return a tuple anyway, why not always return the real offset?
"""
# Compute the chunk size inline with ``torch.chunk``
if size_on_dim % num_chunks == 0:
@ -186,14 +174,17 @@ class Shard(Placement):
tensor, num_chunks, with_padding=True, contiguous=True
)
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
# perform scatter from the src_data_rank as data source when it is not None
mesh_scatter(
output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
)
# Only unpad if the local_tensor was padded on the dimension.
if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0:
if pad_sizes[mesh_dim_local_rank] > 0:
output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank])
# Unpad might return a view, hence we need to remake it contiguous
output = output.contiguous()
return output
def _reduce_shard_tensor(
@ -243,15 +234,13 @@ class Shard(Placement):
is replicated on the previously sharded mesh dimension
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
# check if it's uneven, so we need to pad input tensor before all_gather
local_shape = list(local_tensor.size())
logical_dim_size = current_logical_shape[self.dim]
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
pad_size = full_chunk_size - local_shape[self.dim]
pad_size = full_chunk_size - local_tensor.size(self.dim)
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
if not local_tensor.is_contiguous():
@ -427,8 +416,6 @@ class _StridedShard(Shard):
dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the
`split_factor` of the _StridedShard placement on "dp" dim is 2.
TODO: strided sharding needs to work fine with uneven sharding. Now it forbids
resharding if the tensor is unevenly sharded.
TODO: we should remove _StridedShard placement once we can unify it with Shard
"""
@ -457,6 +444,8 @@ class _StridedShard(Shard):
"""human readable representation of the _StridedShard placement"""
return f"_S({self.dim}, {self.split_factor})"
# TODO(whc) we should update this to match the uneven shard behavior in `_to_replicate_tensor`,
# but this only matters for unit tests
def _split_tensor(
self,
tensor: torch.Tensor,
@ -465,37 +454,35 @@ class _StridedShard(Shard):
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[list[torch.Tensor], list[int]]:
"""
TODO: currently _StridedShard does not support padding
"""
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)
total_split = num_chunks * self.split_factor
assert tensor.size(self.dim) % total_split == 0, (
"_StridedShard currently only allows even sharding but got tensor size"
f" {tensor.size(self.dim)} on dim {self.dim} and total split"
f" {total_split}={num_chunks} * {self.split_factor}"
tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, total_split - len(tensor_list)
)
group_size = self.split_factor
total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
tensor_list = [
torch.cat(
[
total_split_tensor_list[i + j * num_chunks] # stride is num_chunks
for j in range(group_size)
],
# compute the chunk size inline with ``torch.chunk`` to calculate padding
full_chunk_size = (tensor.size(self.dim) + total_split - 1) // total_split
shard_list: list[torch.Tensor] = []
pad_sizes: list[int] = []
for i in range(num_chunks):
shard = torch.cat(
[tensor_list[i + j * num_chunks] for j in range(self.split_factor)],
dim=self.dim,
)
for i in range(num_chunks)
]
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return tensor_list, []
if with_padding:
pad_size = full_chunk_size * self.split_factor - shard.size(self.dim)
shard = pad_tensor(shard, self.dim, pad_size)
pad_sizes.append(pad_size)
if contiguous:
shard = shard.contiguous()
shard_list.append(shard)
return shard_list, pad_sizes
def _to_replicate_tensor(
self,
@ -505,17 +492,77 @@ class _StridedShard(Shard):
current_logical_shape: list[int],
) -> torch.Tensor:
"""
Note: currently _StridedShard does not support padding
Given a tensor with strided sharding (e.g. [StridedShard(d), Shard(d)]),
this function is called during the process of converting to [Replicate(), Replicate()],
and `local_tensor` represents the portion of the tensor on this rank after the intermediate step of
converting to [StridedShard(d), Replicate()] in right-to-left unsharding order.
note: this conversion logic is pretty specialized to this 2D case. It could be generalized further. This
is a common enough case to be worth fixing (since it occurs when applying TP and then FSDP to a model).
note: this does not support 'reduce_scatter' for StridedShard.
Example
-------
mesh = (DP=2, TP=2)
# single-gpu "weight" of size 5, will be 'uneven' for sharding
original = torch.arange(5)
tp sharded tensor
-----------------
`tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])`
local_tensors:
rank0: [0,1,2] rank1: [3,4]
rank2: [0,1,2] rank3: [3,4]
fsdp+tp sharded tensor
----------------------
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated
#TODO put an example somewhere and ref to it)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1] rank1: [3]
rank2: [2] rank3: [4]
Now, say someone wants to reconstruct dp_tp's full tensor. This invokes 'redistribute' to replicate.
redistribute first replicates the "Shard(0)" placement on the rightmost mesh dim, and then replicates the
StridedShard placement, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.
Note the discrepancy with the 'tp sharded tensor' line above! We'll fix it by locally shuffling data.
local_tensors:
rank0: [0,1,3] rank1: [0,1,3]
rank2: [2,4] rank3: [2,4]
Step 1: replicate over the DP dimension. Afterwards, each rank can locally sort the values.
note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4] rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4] rank3: [0,1,3,2,4]
Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back
01324# <- our allgather includes padding, if padding was applied in step 1
01324 <- Remove the padding
013, 24 <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4 <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34 <- interleave with stride=TP mesh dim size
01234 <- concatenate
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
total_split = num_chunks * self.split_factor
# NOTE: we require Strided Sharding to be even for now
assert current_logical_shape[self.dim] % total_split == 0, (
"_StridedShard requires even sharding but got tensor size "
f"{current_logical_shape[self.dim]} on dim {self.dim} and "
f"total split {total_split}=num_chunks {num_chunks} "
f"* split_factor {self.split_factor}"
)
logical_dim_size = current_logical_shape[self.dim]
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
local_pad_size = full_chunk_size - local_tensor.size(self.dim)
if local_pad_size > 0:
local_tensor = pad_tensor(local_tensor, self.dim, local_pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
result = funcol.all_gather_tensor(
local_tensor,
@ -525,19 +572,22 @@ class _StridedShard(Shard):
if isinstance(result, funcol.AsyncCollectiveTensor):
result = result.wait()
tensor_shard_list = torch.chunk(result, total_split, dim=self.dim)
# rearrange the order
new_tensor_shard_list = []
for idx in range(len(tensor_shard_list)):
# the shard split of index `idx` is assigned a new index within
# _StridedShard._split_tensor:
# the original tensor was split into `total_split` chunks,
# all chunks with the same `idx % num_chunks` are merged into one
# new shard and placed on mesh's local rank `idx % num_chunks`
idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks
new_tensor_shard_list.append(tensor_shard_list[idx_after_split])
if result.shape[self.dim] > logical_dim_size:
result = unpad_tensor(
result, self.dim, result.shape[self.dim] - logical_dim_size
)
return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous()
# this reverses our 'all_gather' but gives every rank a copy
dp_shards = torch.chunk(result, num_chunks, dim=self.dim)
# this undoes the 'Shard(0)' -> Replicate() that happened over the wrong mesh dim in the first place
tp_shards: list[torch.Tensor] = []
for p in dp_shards:
tp_shards.extend(torch.chunk(p, self.split_factor, dim=self.dim))
# now we just have to correctly stride the shards
reordered_shards = []
for i in range(self.split_factor):
reordered_shards.extend(tp_shards[i :: self.split_factor])
return torch.cat(reordered_shards, dim=self.dim).contiguous()
@dataclass(frozen=True)