Compare commits

...

3 Commits

Author SHA1 Message Date
50beed91ce WIP repro issue and fix compute_local_shape_and_global_offset 2025-03-31 19:10:11 -07:00
05e1fda920 part 2 - fix strided sharding for uneven padding
This builds on the previous PR and corrects the full_tensor
reconstruction to account for padding in the case of strided
sharding with an uneven tensor shape.

Example:  (copied from _StridedShard._to_replicate_tensor docstring)
-------
mesh = (DP=2, TP=2)
original = torch.arange(5)

tp sharded tensor
-----------------
`tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])`

local_tensors:
rank0: [0,1,2]    rank1: [3,4]
rank2: [0,1,2]    rank3: [3,4]

fsdp+tp sharded tensor
----------------------
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated #TODO put an example somewhere and reference it)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1]  rank1: [3]
rank2: [2]    rank3: [4]

Now, say someone wants to reconstruct dp_tp's full tensor. This will invoke 'redistribute' to replicate.
redistribute will first replicate the "Shard(0)" placement on the rightmost mesh dim, then replicate the
StridedShard placement second, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.

Note the discrepancy with the 'tp sharded tensor' line above!  We'll fix it by locally shuffling data.

local_tensors:
rank0: [0,1,3]  rank1: [0,1,3]
rank2: [2,4]    rank3: [2,4]

Step 1: replicate over the DP dimension.  Afterwards, each rank can locally sort the values.
  note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4]    rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4]    rank3: [0,1,3,2,4]

Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back

01324#       <- our allgather includes padding, if padding was applied in step 1
01324        <- Remove the padding
013, 24      <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4  <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34      <- interleave with stride=TP mesh dim size
01234        <- concatenate
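
A minimal standalone sketch of step 2 (illustrative only; it uses plain tensors instead of DTensor,
with the DP=2 / TP=2 sizes and values taken from the example above):

    import torch

    dp_size, tp_size = 2, 2
    logical_size = 5
    # every rank holds this after the padded DP all-gather; -1 stands in for the pad element
    gathered = torch.tensor([0, 1, 3, 2, 4, -1])

    gathered = gathered[:logical_size]            # remove the padding -> [0,1,3,2,4]
    dp_chunks = torch.chunk(gathered, dp_size)    # undo the DP allgather -> [0,1,3], [2,4]
    tp_chunks = []
    for c in dp_chunks:                           # undo the (wrong) TP unshard -> [0,1], [3], [2], [4]
        tp_chunks.extend(torch.chunk(c, tp_size))
    reordered = []
    for i in range(tp_size):                      # interleave with stride = TP size -> [0,1], [2], [3], [4]
        reordered.extend(tp_chunks[i::tp_size])
    print(torch.cat(reordered))                   # tensor([0, 1, 2, 3, 4])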

ghstack-source-id: 646a190c216521077b8f20dae22201ba49f86f54
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150146
2025-03-27 16:30:20 -07:00
fd983305ee Support uneven sharding for FSDP2 + TP
ghstack-source-id: e64108627909533b09b05a88927a8e5f25d213cf
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148894
2025-03-11 12:54:06 +00:00
11 changed files with 389 additions and 242 deletions

View File

@ -420,9 +420,7 @@ class TestFullyShardShardedParameterDTensor(FSDPTestMultiThread):
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# Use odd dim sizes to test uneven shards
# TODO: change "mlp_dim" back to 9 when uneven sharding
# is supported for FSDP+TP
model = MLP(8, dim_multiplier=3)
model = MLP(9, dim_multiplier=3)
orig_params = [param.detach().clone() for param in model.parameters()]
orig_param_names = [param_name for param_name, _ in model.named_parameters()]
parallelize_module(

View File

@ -1155,9 +1155,7 @@ class TestFullyShardNDTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_2d_mlp_with_nd_mesh, global_mesh),
@ -1230,9 +1228,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
},
functools.partial(self._test_3d_mlp_with_nd_mesh, global_mesh),
@ -1261,6 +1257,12 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
# Checking that parameters match the original model is critical to validate that .full_tensor()
# correctly replicates the strided-sharded layers.
for ref_p, p in zip(ref_model.parameters(), model.parameters()):
self.assertIsInstance(p, DTensor)
self.assertEqual(ref_p, p.full_tensor())
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)

View File

@ -115,15 +115,6 @@ class TestFullyShard2DTraining(FSDPTest):
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
# TODO: remove this test when uneven sharding is supported for FSDP+TP
@skip_if_lt_x_gpu(2)
def test_2d_uneven_shard_raise_error(self):
global_mesh = self.init_global_mesh()
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
model = MLPStack(3)
with self.assertRaisesRegex(NotImplementedError, "uneven sharding"):
model.parallelize(tp_mesh, dp_mesh, False)
@skip_if_lt_x_gpu(2)
@skipIfRocm
def test_train_parity_2d_mlp(self):
@ -132,9 +123,7 @@ class TestFullyShard2DTraining(FSDPTest):
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
# TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
# is supported for FSDP+TP
"mlp_dim": [4, 16, 20],
"mlp_dim": [3, 16, 17],
},
functools.partial(self._test_train_parity_2d_mlp, global_mesh),
)

View File

@ -446,7 +446,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
@skip_if_lt_x_gpu(4)
def test_save_with_fsdp2_tp_and_load_with_tp(self):
self.run_subtests(
{"allow_implicit_replication": [True, False]},
{"allow_implicit_replication": [False]},
self._test_save_with_fsdp2_tp_and_load_with_tp,
)
@ -458,9 +458,9 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
"""
Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP.
"""
def _get_base_model(mlp_dim: int = 2):
base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
mlp_dim = 5
def _get_base_model(mlp_dim: int = mlp_dim):
base_model = nn.Sequential(MLP(mlp_dim, dim_multiplier=1), MLP(mlp_dim, dim_multiplier=1), MLP(mlp_dim, dim_multiplier=1))
return base_model
cm = (
@ -468,27 +468,30 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
if allow_implicit_replication
else contextlib.nullcontext()
)
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
}
if allow_implicit_replication:
# intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
tp_parallelize_plan.pop("0.in_proj")
tp_parallelize_plan.pop("0.out_proj")
# TODO(whc): this code seems broken on main? it is overwritten by the code below
# tp_parallelize_plan = {
# "0.in_proj": ColwiseParallel(),
# "0.out_proj": RowwiseParallel(),
# "1.in_proj": ColwiseParallel(),
# "1.out_proj": RowwiseParallel(),
# "2.in_proj": ColwiseParallel(),
# "2.out_proj": RowwiseParallel(),
# }
# if allow_implicit_replication:
# # intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
# tp_parallelize_plan.pop("0.in_proj")
# tp_parallelize_plan.pop("0.out_proj")
with cm:
# Must set 'use_local_output=False' in order to test uneven-sharding case
# see https://github.com/pytorch/pytorch/issues/150336
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
"0.in_proj": ColwiseParallel(use_local_output=False),
"0.out_proj": RowwiseParallel(use_local_output=False),
"1.in_proj": ColwiseParallel(use_local_output=False),
"1.out_proj": RowwiseParallel(use_local_output=False),
"2.in_proj": ColwiseParallel(use_local_output=False),
"2.out_proj": RowwiseParallel(use_local_output=False),
}
# init device mesh
@ -503,7 +506,8 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
)
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
for save_full_state_dict in [True, False]:
for save_full_state_dict in [False]:
# for save_full_state_dict in [True]:
# Save state dict with original model
base_model = _get_base_model().cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
@ -520,8 +524,22 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
fully_shard(fsdp2_tp_model, mesh=dp_mesh)
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
import torch.distributed.distributed_c10d as dist
rank = dist.get_rank()
par = dict(fsdp2_tp_model.named_parameters())["1.in_proj.weight"]
"""
rank=0 par.shape=torch.Size([5, 5]), (_StridedShard(dim=0, sf=2), Shard(dim=0)) par._local_tensor.shape=torch.Size([2, 5]) shapes=(2, 5) offsets=(0, 0)
rank=1 par.shape=torch.Size([5, 5]), (_StridedShard(dim=0, sf=2), Shard(dim=0)) par._local_tensor.shape=torch.Size([1, 5]) shapes=(1, 5) offsets=(2, 0)
rank=2 par.shape=torch.Size([5, 5]), (_StridedShard(dim=0, sf=2), Shard(dim=0)) par._local_tensor.shape=torch.Size([1, 5]) shapes=(1, 5) offsets=(1, 0)
rank=3 par.shape=torch.Size([5, 5]), (_StridedShard(dim=0, sf=2), Shard(dim=0)) par._local_tensor.shape=torch.Size([1, 5]) shapes=(1, 5) offsets=(3, 0)
"""
assert isinstance(par, DTensor)
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
shapes, offsets = compute_local_shape_and_global_offset(par.shape, global_mesh_2d, par.placements)
print(f"{rank=} {par.shape=}, {par.placements} {par._local_tensor.shape=} {shapes=} {offsets=}")
self.assertEqual(base_model[1].in_proj.weight, fsdp2_tp_model[1].in_proj.weight.full_tensor(), "bingo")
# one-step training to modify state dict
inp = torch.randn((2,), device=self.rank)
inp = torch.randn((1, mlp_dim,), device=self.rank)
base_model(inp).sum().backward()
base_optim.step()
fsdp2_tp_model(inp).sum().backward()
@ -589,11 +607,35 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# Load state dict into another 'base model'
noparallel_model = _get_base_model()
noparallel_optim = torch.optim.AdamW(noparallel_model.parameters(), lr=0.1)
noparallel_state_dict = {
"model": get_model_state_dict(noparallel_model),
"optim": get_optimizer_state_dict(noparallel_model, noparallel_optim),
}
dcp.load(noparallel_state_dict, checkpoint_id=self.temp_dir)
noparallel_model.load_state_dict(noparallel_state_dict["model"])
noparallel_optim.load_state_dict(noparallel_state_dict["optim"])
noparallel_full_msd = get_model_state_dict(
noparallel_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
noparallel_full_osd = get_optimizer_state_dict(
noparallel_model,
noparallel_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
self.assertEqual(base_model[1].in_proj.weight, noparallel_model[1].in_proj.weight, "bingo")
# Compare full state dict to make sure they are the same.
self.assertEqual(base_msd, tp_full_msd)
self.assertEqual(base_osd, tp_full_osd)
self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)
self.assertEqual(base_msd, noparallel_full_msd)
self.assertEqual(base_osd, noparallel_full_osd)
if __name__ == "__main__":

View File

@ -956,7 +956,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
)
if size == 0:
# when tensor size is 0, there is no padding needed for all the ranks.
expected_pad_sizes = []
expected_pad_sizes = [0] * self.world_size
assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [

View File

@ -114,6 +114,29 @@ class UtilTest(DTensorTestBase):
self.assertEqual(local_shape, expected_local_shape)
self.assertEqual(global_offset, expected_global_offset)
@with_comms
def test_uneven_fsdp_tp_meta_compute(self):
# FSDP + TP uneven sharding
tp_size = 2
dp_size = self.world_size // tp_size
global_mesh = init_device_mesh(
self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
global_tensor_shape = torch.Size([15, 5])
placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
assert global_mesh.get_coordinate is not None
dp_rank = global_mesh.get_local_rank("dp")
tp_rank = global_mesh.get_local_rank("tp")
expected_local_shape = (1, 5) if dp_rank == dp_size - 1 and tp_rank == tp_size - 1 else (2, 5)
shard_idx_on_dim_0 = tp_rank * dp_size + dp_rank
expected_global_offset = (shard_idx_on_dim_0, 0) if dp_rank == 0 and tp_rank == 0 else (shard_idx_on_dim_0 + 1, 0)
self.assertEqual(local_shape, expected_local_shape)
self.assertEqual(global_offset, expected_global_offset)
@with_comms
def test_hsdp_tp_meta_compute(self):
# HSDP + TP sharding

View File

@ -327,16 +327,6 @@ class FSDPParam:
self._spmd_placements,
tensor_meta=self._tp_spec.tensor_meta,
)
# TODO: Enable uneven sharding for FSDP+TP.
if split_factor > 1: # FSDP has strided sharding on tensor dim 0
num_shards = self._sharding_spec.num_shards_map[0]
tensor_size_dim_0 = self._sharding_spec.shape[0]
if tensor_size_dim_0 % num_shards != 0:
raise NotImplementedError(
"FSDP+TP sharding does not support uneven sharding for now: "
f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
f"evenly sharded into {num_shards} shards."
)
param_data = cast(DTensor, param)._local_tensor
else:
self._spmd_mesh = self.mesh_info.mesh

View File

@ -200,9 +200,7 @@ def fill_empty_tensor_to_shards(
if num_empty_tensors == 0:
return shards
tensor_size = list(shards[0].size())
tensor_size = [
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
tensor_size[shard_dim] = 0
tensor = shards[0].new_zeros(tensor_size)
shards.extend(tensor for _ in range(num_empty_tensors))
return shards

View File

@ -231,9 +231,9 @@ def redistribute_local_tensor(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
assert (
current.is_shard()
), f"Current placement should be shard but found {current}"
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Sequence
from typing import cast
@ -14,6 +15,59 @@ from torch.distributed.tensor.placement_types import (
Shard,
)
import logging
logger = logging.getLogger(__name__)
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:
"""
Replace Strided Shards with regular shards in an adjusted order.
Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order.
ex.
[Shard(0), _StridedShard(0, split_factor=2), Shard(0)] ->
[(0, Shard(0)), (2, Shard(0)), (1, Shard(0))]
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return list(enumerate(placements))
ordered = []
_strided_tmp = defaultdict(list)
for mesh_dim, p in enumerate(placements):
if isinstance(p, _StridedShard):
# collect strided shards here; their split_factor is validated against the later shards below
_strided_tmp[p.dim].append((mesh_dim, p))
else:
ordered.append((mesh_dim, p))
if isinstance(p, Shard):
# we can only convert a strided shard to an ordered shard if its split_factor equals the
# aggregate size of the mesh dims that shard this tensor dim later in the placement order (validated below)
aggregate_size = mesh_shape[mesh_dim]
while p.dim in _strided_tmp and len(_strided_tmp[p.dim]) > 0:
# We can process multiple strided shardings in reverse-order
# (e.g. [_StridedShard(0, split_factor=4), _StridedShard(0, split_factor=2), Shard(0)])
# TODO- validate this logic and enable it (mainly, validate aggregate_size part)
if len(_strided_tmp[p.dim]) > 1:
raise NotImplementedError(
"NYI nested strided sharding conversion to ordered"
)
strided_dim, strided = _strided_tmp[p.dim].pop(-1)
if strided.split_factor != aggregate_size:
raise RuntimeError(
f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})"
f" == aggregate mesh size ({aggregate_size})"
)
aggregate_size *= mesh_shape[strided_dim]
ordered.append((strided_dim, Shard(p.dim)))
return ordered
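# Illustrative example (example mesh shape assumed, not part of the function above): for the
# FSDP2+TP placements on a (dp=2, tp=2) mesh, the tp mesh dim is ordered before the dp mesh dim:
#   _explicit_order_placements((2, 2), [_StridedShard(0, split_factor=2), Shard(0)])
#   # -> [(1, Shard(dim=0)), (0, Shard(dim=0))]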
def compute_local_shape_and_global_offset(
global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
@ -22,39 +76,10 @@ def compute_local_shape_and_global_offset(
Compute the local tensor shape and the global offsets into the original tensor
of a DTensor on its current global rank. This is useful for checkpointing purposes.
Example (2 hosts with 4 GPUs each):
# Below is a DeviceMesh with mesh_shape of (2, 4)
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
],
)
Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh
with placements of [Shard(0), Shard(0)].
The local shape and global offset will be as follows:
rank0 -- local_shape:[1, 4], global_offset:[0, 0]
rank1 -- local_shape:[1, 4], global_offset:[1, 0]
rank2 -- local_shape:[1, 4], global_offset:[2, 0]
rank3 -- local_shape:[1, 4], global_offset:[3, 0]
rank4 -- local_shape:[1, 4], global_offset:[4, 0]
rank5 -- local_shape:[1, 4], global_offset:[5, 0]
rank6 -- local_shape:[1, 4], global_offset:[6, 0]
rank7 -- local_shape:[1, 4], global_offset:[7, 0]
Let's say we distribute a global_tensor of shape (2,) over the above DeviceMesh with
placements of [Shard(0)]. Not all ranks will have a non-empty local tensor.
The local shape and global offset will be as follows:
rank0 -- local_shape:[1,], global_offset:[0,]
rank1 -- local_shape:[1,], global_offset:[1,]
rank2 -- local_shape:[0,], global_offset:[2,]
rank3 -- local_shape:[0,], global_offset:[2,]
rank4 -- local_shape:[0,], global_offset:[2,]
rank5 -- local_shape:[0,], global_offset:[2,]
rank6 -- local_shape:[0,], global_offset:[2,]
rank7 -- local_shape:[0,], global_offset:[2,]
"""
ordered_placements = _explicit_order_placements(mesh.shape, placements)
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
@ -68,18 +93,18 @@ def compute_local_shape_and_global_offset(
] # index by (shard_dim, mesh_dim)
num_shards_by_tensor_dim = [1] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh.size(mesh_dim)
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
assert (
shard_dim < len(local_shape)
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[idx],
my_coordinate[mesh_dim],
return_offset=True,
)
@ -97,68 +122,99 @@ def compute_local_shape_and_global_offset(
num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
# each local tensor will be further split into when sharding on higher mesh
# dimensions. However, this number is only correct if the DTensor is not
# sharded after the strided sharding completes. For example,
# [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# device mesh dim-2, and last on mesh dim-1. We define the
# "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# part because strided sharding happens on mesh dim-1 and it was caused by
# the fact that sharding on dim-2 occurred ahead. In this case, there's no
# further sharding after this strided sharding part and ``split_factor``
# correctly encodes the number. Another example is
# [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# dim-2. This violates our assumption that no further sharding shall occur
# after the strided sharding part and ``split_factor`` won't correctly
# encode the number of further split. So far, the only case where _StridedShard
# placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
if strided_sharding:
strided_part_seen = [False] * len(global_shape)
strided_part_end = [False] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
logger.warning(f"{mesh_dim=} {local_shape[shard_dim]=} {local_offset[shard_dim]=} {global_offset[shard_dim]=}")
"""
if strided_part_end[shard_dim]:
raise NotImplementedError(
f"Strided sharding does not allow Shard() to appear after "
f"the strided part has ended. {placement} at idx {idx} in "
f"{placements} violates this assumption."
)
# # 01234
# # Rank0: 012 Rank1: 34
# # Rank2: 012 Rank3: 34
# # Rank0: 01 Rank1: 3
# # Rank2: 2 Rank3: 4
if strided_part_seen[shard_dim]:
strided_part_end[shard_dim] = True
# [rank0]:W0331 18:18:10.874000 3331413 torch/distributed/tensor/_utils.py:101] idx=0 local_shape[shard_dim]=3 local_offset[shard_dim]=0 global_offset[shard_dim]=0
# [rank0]:W0331 18:18:10.874000 3331413 torch/distributed/tensor/_utils.py:101] idx=1 local_shape[shard_dim]=2 local_offset[shard_dim]=0 global_offset[shard_dim]=0
if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
// (placement.split_factor * mesh_dim_size)
)
else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
num_shards_by_tensor_dim[shard_dim]
)
# [rank1]:W0331 18:18:10.825000 3331414 torch/distributed/tensor/_utils.py:101] idx=0 local_shape[shard_dim]=3 local_offset[shard_dim]=0 global_offset[shard_dim]=0
# [rank1]:W0331 18:18:10.825000 3331414 torch/distributed/tensor/_utils.py:101] idx=1 local_shape[shard_dim]=1 local_offset[shard_dim]=2 global_offset[shard_dim]=2
shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
for shard_dim, shard_idx_stride in enumerate(
shard_idx_stride_by_mesh_dim
)
]
# [rank2]:W0331 18:18:10.863000 3331415 torch/distributed/tensor/_utils.py:101] idx=0 local_shape[shard_dim]=2 local_offset[shard_dim]=3 global_offset[shard_dim]=3
# [rank2]:W0331 18:18:10.864000 3331415 torch/distributed/tensor/_utils.py:101] idx=1 local_shape[shard_dim]=1 local_offset[shard_dim]=0 global_offset[shard_dim]=3
global_offset = [x * y for x, y in zip(local_shape, shard_idx)]
# [rank3]:W0331 18:18:10.825000 3331416 torch/distributed/tensor/_utils.py:101] idx=0 local_shape[shard_dim]=2 local_offset[shard_dim]=3 global_offset[shard_dim]=3
# [rank3]:W0331 18:18:10.825000 3331416 torch/distributed/tensor/_utils.py:101] idx=1 local_shape[shard_dim]=1 local_offset[shard_dim]=1 global_offset[shard_dim]=4
# """
# logger.warning(f"{num_shards_by_tensor_dim=}")
# # NOTE: the offset compute relies on the local shard index and it has no
# # problem when strided sharding is not present. To correctly compute, we assume
# # that the ``_StridedShard.split_factor`` field encodes how many partitions
# # each local tensor will be further split into when sharding on higher mesh
# # dimensions. However, this number is only correct if the DTensor is not
# # sharded after the strided sharding completes. For example,
# # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# # device mesh dim-2, and last on mesh dim-1. We define the
# # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# # part because strided sharding happens on mesh dim-1 and it was caused by
# # the fact that sharding on dim-2 occurred ahead. In this case, there's no
# # further sharding after this strided sharding part and ``split_factor``
# # correctly encodes the number. Another example is
# # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# # dim-2. This violates our assumption that no further sharding shall occur
# # after the strided sharding part and ``split_factor`` won't correctly
# # encode the number of further split. So far, the only case where _StridedShard
# # placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# # happen on mesh of 3 or more dimensions.
# # TODO: change this function to correctly address this.
# # TODO: this logic can be applied to contiguous sharding as well
# strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
# if strided_sharding:
# # TODO(whc): Fix this for uneven padding case
# # - local_shapes should not include padding
# strided_part_seen = [False] * len(global_shape)
# strided_part_end = [False] * len(global_shape)
# for idx, placement in enumerate(placements):
# mesh_dim_size = mesh.size(idx)
# if isinstance(placement, Shard):
# shard_dim = placement.dim
# if strided_part_end[shard_dim]:
# raise NotImplementedError(
# f"Strided sharding does not allow Shard() to appear after "
# f"the strided part has ended. {placement} at idx {idx} in "
# f"{placements} violates this assumption."
# )
# if strided_part_seen[shard_dim]:
# logger.warning(f"ending strided part on {idx=} {shard_dim=}")
# strided_part_end[shard_dim] = True
# if isinstance(placement, _StridedShard):
# strided_part_seen[shard_dim] = True
# shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
# num_shards_by_tensor_dim[shard_dim]
# // (placement.split_factor * mesh_dim_size)
# )
# logger.warning(f"Starting strided part on {idx=} {shard_dim=}, {shard_idx_stride_by_mesh_dim[shard_dim][idx]=} ")
# else:
# num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
# shard_idx_stride_by_mesh_dim[shard_dim][idx] = (
# num_shards_by_tensor_dim[shard_dim]
# )
# logger.warning(f"non-strided part on {idx=} {shard_dim=}, {shard_idx_stride_by_mesh_dim[shard_dim][idx]=} ")
# shard_idx = [
# sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
# for shard_dim, shard_idx_stride in enumerate(
# shard_idx_stride_by_mesh_dim
# )
# ]
# logger.warning(f"{my_coordinate=} {shard_idx_stride_by_mesh_dim=} {shard_idx=}")
# global_offset = [x * y for x, y in zip(local_shape, shard_idx)]
# logger.warning(f"{local_shape=} {global_offset=}")
return tuple(local_shape), tuple(global_offset)
@ -204,9 +260,9 @@ def compute_global_tensor_info(
)
shard_dim = shard_placement.dim
assert shard_dim < tensor.ndim, (
f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
)
assert (
shard_dim < tensor.ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

View File

@ -83,45 +83,28 @@ class Shard(Placement):
few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
This is because collectives usually require equal size tensor inputs
"""
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
# chunk tensor over dimension `dim` into n slices
tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
num_empty_tensors = num_chunks - len(tensor_list)
# if no need to have padding or tensor dim size is evenly sharded already
# we can return early.
if not with_padding or tensor.size(self.dim) % num_chunks == 0:
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return (
fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors),
[],
)
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, num_chunks - len(tensor_list)
)
# compute the chunk size inline with ``torch.chunk`` to calculate padding
full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
# Compute chunk size for each chunk for ``self.dim``
chunk_sizes = [
tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0
for idx in range(num_chunks)
]
# Compute pad size on each chunk
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
# Reuse tensor to fill empty chunk with empty tensor
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, num_empty_tensors
)
shard_list = []
for shard, pad_size in zip(tensor_list, pad_sizes):
# Fill the empty tensor with zeroes with padding.
if with_padding and pad_size > 0:
shard_list: list[torch.Tensor] = []
pad_sizes: list[int] = []
for shard in tensor_list:
if with_padding:
pad_size = full_chunk_size - shard.size(self.dim)
shard = pad_tensor(shard, self.dim, pad_size)
shard = shard.contiguous() if contiguous else shard
pad_sizes.append(pad_size)
if contiguous:
shard = shard.contiguous()
shard_list.append(shard)
return shard_list, pad_sizes
@ -134,6 +117,11 @@ class Shard(Placement):
) -> tuple[int, int]:
"""
returns the local shard size and offset on a given tensor dim
# TODO(whc)
- the `size_on_dim` arg has a pretty confusing name; would `global_dim_size` be better?
- `local_shard_size_on_dim` could be renamed to `local_shard_size`?
- why do we have a `return_offset` bool? if we always return a tuple anyway, why not always return the real offset?
"""
# Compute the chunk size inline with ``torch.chunk``
if size_on_dim % num_chunks == 0:
@ -186,14 +174,17 @@ class Shard(Placement):
tensor, num_chunks, with_padding=True, contiguous=True
)
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
# perform scatter from the src_data_rank as data source when it is not None
mesh_scatter(
output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
)
# Only unpad if the local_tensor was padded on the dimension.
if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0:
if pad_sizes[mesh_dim_local_rank] > 0:
output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank])
# Unpad might return a view, hence we need to remake it contiguous
output = output.contiguous()
return output
def _reduce_shard_tensor(
@ -243,15 +234,13 @@ class Shard(Placement):
is replicated on the previously sharded mesh dimension
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
# check if it's uneven; if so, we need to pad the input tensor before all_gather
local_shape = list(local_tensor.size())
logical_dim_size = current_logical_shape[self.dim]
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
pad_size = full_chunk_size - local_shape[self.dim]
pad_size = full_chunk_size - local_tensor.size(self.dim)
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
if not local_tensor.is_contiguous():
@ -427,8 +416,6 @@ class _StridedShard(Shard):
dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the
`split_factor` of the _StridedShard placement on "dp" dim is 2.
TODO: strided sharding needs to work fine with uneven sharding. Now it forbids
resharding if the tensor is unevenly sharded.
TODO: we should remove _StridedShard placement once we can unify it with Shard
"""
@ -457,6 +444,8 @@ class _StridedShard(Shard):
"""human readable representation of the _StridedShard placement"""
return f"_S({self.dim}, {self.split_factor})"
# TODO(whc): we should update this to match the uneven shard behavior in `_to_replicate_tensor`,
# but this only matters for unit tests
def _split_tensor(
self,
tensor: torch.Tensor,
@ -465,37 +454,35 @@ class _StridedShard(Shard):
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[list[torch.Tensor], list[int]]:
"""
TODO: currently _StridedShard does not support padding
"""
assert self.dim <= tensor.ndim, (
f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
)
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
total_split = num_chunks * self.split_factor
assert tensor.size(self.dim) % total_split == 0, (
"_StridedShard currently only allows even sharding but got tensor size"
f" {tensor.size(self.dim)} on dim {self.dim} and total split"
f" {total_split}={num_chunks} * {self.split_factor}"
tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, total_split - len(tensor_list)
)
group_size = self.split_factor
total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
tensor_list = [
torch.cat(
[
total_split_tensor_list[i + j * num_chunks] # stride is num_chunks
for j in range(group_size)
],
# compute the chunk size inline with ``torch.chunk`` to calculate padding
full_chunk_size = (tensor.size(self.dim) + total_split - 1) // total_split
shard_list: list[torch.Tensor] = []
pad_sizes: list[int] = []
for i in range(num_chunks):
shard = torch.cat(
[tensor_list[i + j * num_chunks] for j in range(self.split_factor)],
dim=self.dim,
)
for i in range(num_chunks)
]
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return tensor_list, []
if with_padding:
pad_size = full_chunk_size * self.split_factor - shard.size(self.dim)
shard = pad_tensor(shard, self.dim, pad_size)
pad_sizes.append(pad_size)
if contiguous:
shard = shard.contiguous()
shard_list.append(shard)
return shard_list, pad_sizes
def _to_replicate_tensor(
self,
@ -505,17 +492,76 @@ class _StridedShard(Shard):
current_logical_shape: list[int],
) -> torch.Tensor:
"""
Note: currently _StridedShard does not support padding
Given a tensor with strided sharding (e.g. [StridedShard(d), Shard(d)]),
this function is called during the process of converting to [Replicate(), Replicate()],
and `local_tensor` represents the portion of the tensor on this rank after the intermediate step of
converting to [StridedShard(d), Replicate()] in right-to-left unsharding order.
note: this conversion logic is quite specialized to this 2D case. It could be generalized further. This
is a common enough case to be worth fixing (since it occurs when applying TP and then FSDP to a model).
note: this does not support 'reduce_scatter' for StridedShard.
Example
-------
mesh = (DP=2, TP=2)
# single-gpu "weight" of size 5, will be 'uneven' for sharding
original = torch.arange(5)
tp sharded tensor
-----------------
`tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])`
local_tensors:
rank0: [0,1,2] rank1: [3,4]
rank2: [0,1,2] rank3: [3,4]
fsdp+tp sharded tensor
----------------------
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated #TODO put an example somewhere and reference it)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1] rank1: [3]
rank2: [2] rank3: [4]
Now, say someone wants to reconstruct dp_tp's full tensor. This will invoke 'redistribute' to replicate.
redistribute will first replicate the "Shard(0)" placement on the rightmost mesh dim, then replicate the
StridedShard placement second, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.
Note the discrepancy with the 'tp sharded tensor' line above! We'll fix it by locally shuffling data.
local_tensors:
rank0: [0,1,3] rank1: [0,1,3]
rank2: [2,4] rank3: [2,4]
Step 1: replicate over the DP dimension. Afterwards, each rank can locally sort the values.
note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4] rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4] rank3: [0,1,3,2,4]
Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back
01324# <- our allgather includes padding, if padding was applied in step 1
01324 <- Remove the padding
013, 24 <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4 <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34 <- interleave with stride=TP mesh dim size
01234 <- concatenate
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
total_split = num_chunks * self.split_factor
# NOTE: we require Strided Sharding to be even for now
assert current_logical_shape[self.dim] % total_split == 0, (
"_StridedShard requires even sharding but got tensor size "
f"{current_logical_shape[self.dim]} on dim {self.dim} and "
f"total split {total_split}=num_chunks {num_chunks} "
f"* split_factor {self.split_factor}"
)
logical_dim_size = current_logical_shape[self.dim]
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
local_pad_size = full_chunk_size - local_tensor.size(self.dim)
if local_pad_size > 0:
local_tensor = pad_tensor(local_tensor, self.dim, local_pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
result = funcol.all_gather_tensor(
local_tensor,
@ -525,19 +571,22 @@ class _StridedShard(Shard):
if isinstance(result, funcol.AsyncCollectiveTensor):
result = result.wait()
tensor_shard_list = torch.chunk(result, total_split, dim=self.dim)
# rearrange the order
new_tensor_shard_list = []
for idx in range(len(tensor_shard_list)):
# the shard split of index `idx` is assigned a new index within
# _StridedShard._split_tensor:
# the original tensor was split into `total_split` chunks,
# all chunks with the same `idx % num_chunks` are merged into one
# new shard and placed on mesh's local rank `idx % num_chunks`
idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks
new_tensor_shard_list.append(tensor_shard_list[idx_after_split])
if result.shape[self.dim] > logical_dim_size:
result = unpad_tensor(
result, self.dim, result.shape[self.dim] - logical_dim_size
)
return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous()
# this reverses our 'all_gather' but gives every rank a copy
dp_shards = torch.chunk(result, num_chunks, dim=self.dim)
# this undoes the 'Shard(0)' -> Replicate() that happened over the wrong mesh dim in the first place
tp_shards = []
for p in dp_shards:
tp_shards.extend(torch.chunk(p, self.split_factor, dim=self.dim))
# now we just have to correctly stride the shards
reordered_shards = []
for i in range(self.split_factor):
reordered_shards.extend(tp_shards[i :: self.split_factor])
return torch.cat(reordered_shards, dim=self.dim).contiguous()
@dataclass(frozen=True)