Fix _StridedShard incorrect split (#165533)

https://github.com/pytorch/pytorch/pull/164820 introduced a bug where `_StridedShard` calls the parent class `Shard`'s `_split_tensor` method, which results in incorrect data locality. (I think @ezyang spotted this issue, but we had no test to catch it.)
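
To make the data-locality difference concrete, here is a toy comparison (plain `torch.chunk`, outside the DTensor code path) of the contiguous chunks that `Shard`'s split produces versus the interleaved layout a strided shard is supposed to produce. The `strided_split` helper below is my own simplified reading of the `_StridedShard(dim, split_factor=...)` semantics, not the actual implementation:

```
import torch

def contiguous_split(tensor: torch.Tensor, num_chunks: int, dim: int = 0):
    # What Shard's split effectively does: plain contiguous chunks along `dim`.
    return list(torch.chunk(tensor, num_chunks, dim=dim))

def strided_split(tensor: torch.Tensor, num_chunks: int, split_factor: int, dim: int = 0):
    # Simplified sketch of the strided layout: cut into num_chunks * split_factor
    # fine-grained pieces, then give shard i every num_chunks-th piece starting at i.
    pieces = list(torch.chunk(tensor, num_chunks * split_factor, dim=dim))
    return [torch.cat(pieces[i::num_chunks], dim=dim) for i in range(num_chunks)]

x = torch.arange(8)
print(contiguous_split(x, num_chunks=2))               # [0,1,2,3] and [4,5,6,7]
print(strided_split(x, num_chunks=2, split_factor=2))  # [0,1,4,5] and [2,3,6,7]
```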

Meanwhile, I noticed another bug: when we normalize a `_StridedShard` placement with a negative dim, we also end up in the parent class `Shard`'s `_split_tensor` method, because the normalization creates a plain `Shard` instance [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I don't think we ever tested `distribute_tensor` with `_StridedShard` before, so I added a test that compares the result against ordered shard.
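
As a small illustration of why that normalization path loses the strided information (just the class relationship, not the fix itself):

```
from torch.distributed.tensor.placement_types import Shard, _StridedShard

p = _StridedShard(-1, split_factor=2)
print(isinstance(p, Shard))  # True: _StridedShard subclasses Shard, so the old
                             # `placement.is_shard()` branch catches it too
# The old normalization rebuilt the placement as `Shard(p.dim + tensor.ndim)`,
# which silently drops split_factor (assume a 2-d tensor here):
print(Shard(p.dim + 2))      # Shard(dim=1) -- downgraded to a plain contiguous Shard
```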

I use a classmethod because the `_split_tensor` logic differs between `Shard` and `_StridedShard`. Basically, I want to produce the sharded local tensors without the caller having to construct the placement object first:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```
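
With this, `distribute_tensor` goes through `_StridedShard._make_shard_tensor(...)` (passing `split_factor`) for strided placements and `Shard._make_shard_tensor(...)` otherwise; the classmethod constructs the right placement internally and then calls its own `_shard_tensor`/`_split_tensor`, so `_StridedShard`'s split logic is what actually runs (see the `placement_types.py` changes below).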

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
Author: zpcore
Date: 2025-10-17 20:54:46 +00:00
Committed by: PyTorch MergeBot
Parent: 06d324365c
Commit: ab65498d71

3 changed files with 82 additions and 52 deletions

torch/distributed/tensor/_api.py

@@ -25,6 +25,7 @@ from torch.distributed.tensor._utils import (
     normalize_to_torch_size,
 )
 from torch.distributed.tensor.placement_types import (
+    _StridedShard,
     Partial,
     Placement,
     Replicate,
@@ -776,18 +777,29 @@ def distribute_tensor(
     # distribute the tensor according to the placements.
     placements = list(placements)
     for idx, placement in enumerate(placements):
-        if placement.is_shard():
-            placement = cast(Shard, placement)
-            if placement.dim < 0:
-                # normalize shard placement dim
-                placement = Shard(placement.dim + tensor.ndim)
-                placements[idx] = placement
-            local_tensor = placement._shard_tensor(
-                local_tensor, device_mesh, idx, src_data_rank
+        if isinstance(placement, Shard):
+            placement_dim = (
+                placement.dim + tensor.ndim if placement.dim < 0 else placement.dim
             )
-        elif placement.is_replicate():
-            placement = cast(Replicate, placement)
-            local_tensor = placement._replicate_tensor(
+            if isinstance(placement, _StridedShard):
+                local_tensor = _StridedShard._make_shard_tensor(
+                    placement_dim,
+                    local_tensor,
+                    device_mesh,
+                    idx,
+                    src_data_rank,
+                    split_factor=placement.split_factor,
+                )
+                placements[idx] = _StridedShard(
+                    placement_dim, split_factor=placement.split_factor
+                )
+            else:
+                local_tensor = Shard._make_shard_tensor(
+                    placement_dim, local_tensor, device_mesh, idx, src_data_rank
+                )
+                placements[idx] = Shard(placement_dim)
+        elif isinstance(placement, Replicate):
+            local_tensor = Replicate._make_replicate_tensor(
                 local_tensor, device_mesh, idx, src_data_rank
             )
         else:

torch/distributed/tensor/placement_types.py

@@ -69,9 +69,8 @@ class Shard(Placement):
         else:
             return True
 
-    @staticmethod
-    def _make_split_tensor(
-        dim: int,
+    def _split_tensor(
+        self,
         tensor: torch.Tensor,
         num_chunks: int,
         *,
@@ -87,47 +86,31 @@ class Shard(Placement):
         few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
         This is because collectives usually require equal size tensor inputs
         """
-        assert dim <= tensor.ndim, (
-            f"Sharding dim {dim} greater than tensor ndim {tensor.ndim}"
+        assert self.dim <= tensor.ndim, (
+            f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
         )
 
         # chunk tensor over dimension `dim` into n slices
-        tensor_list = list(torch.chunk(tensor, num_chunks, dim=dim))
+        tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
         tensor_list = fill_empty_tensor_to_shards(
-            tensor_list, dim, num_chunks - len(tensor_list)
+            tensor_list, self.dim, num_chunks - len(tensor_list)
         )
 
         # compute the chunk size inline with ``torch.chunk`` to calculate padding
-        full_chunk_size = (tensor.size(dim) + num_chunks - 1) // num_chunks
+        full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
 
         shard_list: list[torch.Tensor] = []
         pad_sizes: list[int] = []
         for shard in tensor_list:
             if with_padding:
-                pad_size = full_chunk_size - shard.size(dim)
-                shard = pad_tensor(shard, dim, pad_size)
+                pad_size = full_chunk_size - shard.size(self.dim)
+                shard = pad_tensor(shard, self.dim, pad_size)
                 pad_sizes.append(pad_size)
             if contiguous:
                 shard = shard.contiguous()
             shard_list.append(shard)
         return shard_list, pad_sizes
 
-    def _split_tensor(
-        self,
-        tensor: torch.Tensor,
-        num_chunks: int,
-        *,
-        with_padding: bool = True,
-        contiguous: bool = True,
-    ) -> tuple[list[torch.Tensor], list[int]]:
-        return Shard._make_split_tensor(
-            self.dim,
-            tensor,
-            num_chunks,
-            with_padding=with_padding,
-            contiguous=contiguous,
-        )
-
     @staticmethod
     @maybe_run_for_local_tensor
     def local_shard_size_and_offset(
@@ -186,9 +169,8 @@ class Shard(Placement):
             local_tensor = local_tensor.contiguous()
         return local_tensor
 
-    @staticmethod
-    def _make_shard_tensor(
-        dim: int,
+    def _shard_tensor(
+        self,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
         mesh_dim: int,
@@ -210,14 +192,14 @@ class Shard(Placement):
         if src_data_rank is None:
            # src_data_rank specified as None explicitly means to skip the
            # communications, simply split
-            scatter_list, _ = Shard._make_split_tensor(
-                dim, tensor, num_chunks, with_padding=False, contiguous=True
+            scatter_list, _ = self._split_tensor(
+                tensor, num_chunks, with_padding=False, contiguous=True
             )
-            return Shard._select_shard(scatter_list, mesh_dim_local_rank)
+            return self._select_shard(scatter_list, mesh_dim_local_rank)
 
-        scatter_list, pad_sizes = Shard._make_split_tensor(
-            dim, tensor, num_chunks, with_padding=True, contiguous=True
+        scatter_list, pad_sizes = self._split_tensor(
+            tensor, num_chunks, with_padding=True, contiguous=True
         )
 
         it = iter(scatter_list)
@@ -234,17 +216,20 @@ class Shard(Placement):
         )
 
         return Shard._maybe_unpad_tensor_with_sizes(
-            dim, output, pad_sizes, mesh_dim_local_rank, True
+            self.dim, output, pad_sizes, mesh_dim_local_rank, True
         )
 
-    def _shard_tensor(
-        self,
+    @classmethod
+    def _make_shard_tensor(
+        cls,
+        dim: int,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
         mesh_dim: int,
         src_data_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        return Shard._make_shard_tensor(self.dim, tensor, mesh, mesh_dim, src_data_rank)
+        shard_placement = cls(dim)
+        return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank)
 
     def _reduce_shard_tensor(
         self,
@@ -267,8 +252,8 @@ class Shard(Placement):
         is_padded = tensor.size(self.dim) % num_chunks != 0
         pad_sizes = None
         if is_padded:
-            scattered_list, pad_sizes = Shard._make_split_tensor(
-                self.dim, tensor, num_chunks, with_padding=True, contiguous=True
+            scattered_list, pad_sizes = self._split_tensor(
+                tensor, num_chunks, with_padding=True, contiguous=True
             )
             tensor = torch.cat(scattered_list, dim=self.dim)
         elif not tensor.is_contiguous():
@@ -538,6 +523,21 @@ class _StridedShard(Shard):
         """human readable representation of the _StridedShard placement"""
         return f"_S({self.dim}, {self.split_factor})"
 
+    @classmethod
+    def _make_shard_tensor(
+        cls,
+        dim: int,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        src_data_rank: Optional[int] = 0,
+        split_factor: int = 1,
+    ) -> torch.Tensor:
+        strided_shard_placement = cls(dim=dim, split_factor=split_factor)
+        return strided_shard_placement._shard_tensor(
+            tensor, mesh, mesh_dim, src_data_rank
+        )
+
     def _split_tensor(
         self,
         tensor: torch.Tensor,
@@ -704,8 +704,9 @@ class Replicate(Placement):
         """
         return "R"
 
-    @staticmethod
+    @classmethod
     def _make_replicate_tensor(
+        cls,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
         mesh_dim: int,