Fix _StridedShard incorrect split (#165533)
https://github.com/pytorch/pytorch/pull/164820 introduced a bug where `_StridedShard` calls the parent class `Shard`'s `split_tensor` method, which results in incorrect data locality. (I think @ezyang spotted this issue, but we had no test to catch it.)
Meanwhile, I noticed another bug: when we normalize a `_StridedShard`'s placement, it also triggers the parent class `Shard`'s `split_tensor` method, because a `Shard` object is created [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I don't think we ever tested `distribute_tensor` with `_StridedShard` before, so I added a test that compares it against the ordered shard.
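For intuition, the ordering that test checks can be sketched with plain `torch.chunk` (a minimal standalone sketch, not DTensor itself; `local_chunk_right_to_left` is a hypothetical helper used only for this illustration): sharding tensor dim 0 right-to-left on a `(4, 2)` mesh means chunking across mesh dim 1 first and then mesh dim 0, so the rank at mesh coordinate `(i, j)` ends up with row `j * 4 + i` of an 8-row tensor.
```
import torch

# Minimal sketch (not DTensor): emulate right-to-left sharding of tensor dim 0
# on a (4, 2) mesh with plain torch.chunk.
def local_chunk_right_to_left(t, coord, mesh_shape=(4, 2)):
    i, j = coord  # i indexes mesh dim 0 (size 4), j indexes mesh dim 1 (size 2)
    outer = torch.chunk(t, mesh_shape[1], dim=0)[j]  # shard across mesh dim 1 first
    return torch.chunk(outer, mesh_shape[0], dim=0)[i]  # then across mesh dim 0

x = torch.arange(8).unsqueeze(1).expand(8, 4)  # row r holds the value r
# The rank at mesh coordinate (2, 1) should end up with row 1 * 4 + 2 = 6.
assert torch.equal(local_chunk_right_to_left(x, (2, 1)), x[6:7])
```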
I use a classmethod because the `_split_tensor` logic differs between `Shard` and `_StridedShard`. Basically, I want to shard local tensors without constructing a `Shard` object first:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```
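To make that pattern concrete, here is a standalone toy sketch of the classmethod-factory approach as I read the diff below (the `ToyShard`/`ToyStridedShard` classes and their `_split`/`_make_split` methods are made-up names for illustration, not DTensor APIs): the classmethod builds the right placement first and then calls its instance method, so the subclass's own splitting logic runs instead of the parent's.
```
# Toy sketch of the classmethod-factory pattern (hypothetical classes, not DTensor).
class ToyShard:
    def __init__(self, dim: int):
        self.dim = dim

    def _split(self, rows: list) -> list:
        # stand-in for Shard's own splitting logic
        return [f"Shard(dim={self.dim}): {r}" for r in rows]

    @classmethod
    def _make_split(cls, dim: int, rows: list) -> list:
        # build the placement, then defer to its instance method
        return cls(dim)._split(rows)


class ToyStridedShard(ToyShard):
    def __init__(self, dim: int, split_factor: int = 1):
        super().__init__(dim)
        self.split_factor = split_factor

    def _split(self, rows: list) -> list:
        # stand-in for _StridedShard's different splitting logic
        return [f"_StridedShard(sf={self.split_factor}): {r}" for r in rows]

    @classmethod
    def _make_split(cls, dim: int, rows: list, split_factor: int = 1) -> list:
        # the subclass factory carries the extra split_factor argument
        return cls(dim, split_factor=split_factor)._split(rows)


print(ToyShard._make_split(0, ["row0"]))                         # ToyShard._split runs
print(ToyStridedShard._make_split(0, ["row0"], split_factor=2))  # ToyStridedShard._split runs
```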
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
```
@@ -20,6 +20,7 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.placement_types import _StridedShard
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -1145,6 +1146,22 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
             sharded_dt, mesh, tgt_placement, shard_order=None
         )
 
+    @with_comms
+    def test_shard_order_same_data_as_strided_shard(self):
+        device_mesh = init_device_mesh(self.device_type, (4, 2))
+        x = torch.randn(8, 4, device=self.device_type)
+        # specify right-to-left order use _StridedShard
+        strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
+        x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
+        # specify right-to-left order use ordered shard
+        x_ordered_dt = self.distribute_tensor(
+            x,
+            device_mesh,
+            placements=[Shard(0), Shard(0)],
+            shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
+        )
+        self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
+
 
 if __name__ == "__main__":
     run_tests()
```
```
@@ -25,6 +25,7 @@ from torch.distributed.tensor._utils import (
     normalize_to_torch_size,
 )
 from torch.distributed.tensor.placement_types import (
+    _StridedShard,
     Partial,
     Placement,
     Replicate,
@@ -776,18 +777,29 @@ def distribute_tensor(
     # distribute the tensor according to the placements.
     placements = list(placements)
     for idx, placement in enumerate(placements):
-        if placement.is_shard():
-            placement = cast(Shard, placement)
-            if placement.dim < 0:
-                # normalize shard placement dim
-                placement = Shard(placement.dim + tensor.ndim)
-                placements[idx] = placement
-            local_tensor = placement._shard_tensor(
-                local_tensor, device_mesh, idx, src_data_rank
+        if isinstance(placement, Shard):
+            placement_dim = (
+                placement.dim + tensor.ndim if placement.dim < 0 else placement.dim
             )
-        elif placement.is_replicate():
-            placement = cast(Replicate, placement)
-            local_tensor = placement._replicate_tensor(
+            if isinstance(placement, _StridedShard):
+                local_tensor = _StridedShard._make_shard_tensor(
+                    placement_dim,
+                    local_tensor,
+                    device_mesh,
+                    idx,
+                    src_data_rank,
+                    split_factor=placement.split_factor,
+                )
+                placements[idx] = _StridedShard(
+                    placement_dim, split_factor=placement.split_factor
+                )
+            else:
+                local_tensor = Shard._make_shard_tensor(
+                    placement_dim, local_tensor, device_mesh, idx, src_data_rank
+                )
+                placements[idx] = Shard(placement_dim)
+        elif isinstance(placement, Replicate):
+            local_tensor = Replicate._make_replicate_tensor(
                 local_tensor, device_mesh, idx, src_data_rank
             )
         else:
```
```
@@ -69,9 +69,8 @@ class Shard(Placement):
         else:
             return True
 
-    @staticmethod
-    def _make_split_tensor(
-        dim: int,
+    def _split_tensor(
+        self,
         tensor: torch.Tensor,
         num_chunks: int,
         *,
@@ -87,47 +86,31 @@ class Shard(Placement):
        few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
        This is because collectives usually require equal size tensor inputs
        """
-        assert dim <= tensor.ndim, (
-            f"Sharding dim {dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
         )
 
         # chunk tensor over dimension `dim` into n slices
-        tensor_list = list(torch.chunk(tensor, num_chunks, dim=dim))
+        tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
         tensor_list = fill_empty_tensor_to_shards(
-            tensor_list, dim, num_chunks - len(tensor_list)
+            tensor_list, self.dim, num_chunks - len(tensor_list)
         )
 
         # compute the chunk size inline with ``torch.chunk`` to calculate padding
-        full_chunk_size = (tensor.size(dim) + num_chunks - 1) // num_chunks
+        full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
 
         shard_list: list[torch.Tensor] = []
         pad_sizes: list[int] = []
         for shard in tensor_list:
             if with_padding:
-                pad_size = full_chunk_size - shard.size(dim)
-                shard = pad_tensor(shard, dim, pad_size)
+                pad_size = full_chunk_size - shard.size(self.dim)
+                shard = pad_tensor(shard, self.dim, pad_size)
                 pad_sizes.append(pad_size)
             if contiguous:
                 shard = shard.contiguous()
             shard_list.append(shard)
         return shard_list, pad_sizes
 
-    def _split_tensor(
-        self,
-        tensor: torch.Tensor,
-        num_chunks: int,
-        *,
-        with_padding: bool = True,
-        contiguous: bool = True,
-    ) -> tuple[list[torch.Tensor], list[int]]:
-        return Shard._make_split_tensor(
-            self.dim,
-            tensor,
-            num_chunks,
-            with_padding=with_padding,
-            contiguous=contiguous,
-        )
-
     @staticmethod
     @maybe_run_for_local_tensor
     def local_shard_size_and_offset(
@@ -186,9 +169,8 @@ class Shard(Placement):
             local_tensor = local_tensor.contiguous()
         return local_tensor
 
-    @staticmethod
-    def _make_shard_tensor(
-        dim: int,
+    def _shard_tensor(
+        self,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
         mesh_dim: int,
@@ -210,14 +192,14 @@ class Shard(Placement):
         if src_data_rank is None:
             # src_data_rank specified as None explicitly means to skip the
             # communications, simply split
-            scatter_list, _ = Shard._make_split_tensor(
-                dim, tensor, num_chunks, with_padding=False, contiguous=True
+            scatter_list, _ = self._split_tensor(
+                tensor, num_chunks, with_padding=False, contiguous=True
             )
 
-            return Shard._select_shard(scatter_list, mesh_dim_local_rank)
+            return self._select_shard(scatter_list, mesh_dim_local_rank)
 
-        scatter_list, pad_sizes = Shard._make_split_tensor(
-            dim, tensor, num_chunks, with_padding=True, contiguous=True
+        scatter_list, pad_sizes = self._split_tensor(
+            tensor, num_chunks, with_padding=True, contiguous=True
         )
 
         it = iter(scatter_list)
@@ -234,17 +216,20 @@ class Shard(Placement):
         )
 
         return Shard._maybe_unpad_tensor_with_sizes(
-            dim, output, pad_sizes, mesh_dim_local_rank, True
+            self.dim, output, pad_sizes, mesh_dim_local_rank, True
         )
 
-    def _shard_tensor(
-        self,
+    @classmethod
+    def _make_shard_tensor(
+        cls,
+        dim: int,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        src_data_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        return Shard._make_shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)
+        shard_placement = cls(dim)
+        return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank)
 
     def _reduce_shard_tensor(
         self,
@@ -267,8 +252,8 @@ class Shard(Placement):
         is_padded = tensor.size(self.dim) % num_chunks != 0
         pad_sizes = None
         if is_padded:
-            scattered_list, pad_sizes = Shard._make_split_tensor(
-                self.dim, tensor, num_chunks, with_padding=True, contiguous=True
+            scattered_list, pad_sizes = self._split_tensor(
+                tensor, num_chunks, with_padding=True, contiguous=True
             )
             tensor = torch.cat(scattered_list, dim=self.dim)
         elif not tensor.is_contiguous():
@@ -538,6 +523,21 @@ class _StridedShard(Shard):
         """human readable representation of the _StridedShard placement"""
         return f"_S({self.dim}, {self.split_factor})"
 
+    @classmethod
+    def _make_shard_tensor(
+        cls,
+        dim: int,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        src_data_rank: Optional[int] = 0,
+        split_factor: int = 1,
+    ) -> torch.Tensor:
+        strided_shard_placement = cls(dim=dim, split_factor=split_factor)
+        return strided_shard_placement._shard_tensor(
+            tensor, mesh, mesh_dim, src_data_rank
+        )
+
     def _split_tensor(
         self,
         tensor: torch.Tensor,
@@ -704,8 +704,9 @@ class Replicate(Placement):
         """
         return "R"
 
-    @staticmethod
+    @classmethod
     def _make_replicate_tensor(
+        cls,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
         mesh_dim: int,
```