[dtensor] move pad/unpad_tensor to separate utils (#124871)

As titled: 1. pad/unpad is a general utility, not specific to the Shard
placement; 2. in preparation for the next PR, move these two out of the Shard
placement itself and give them an additional pad_dim argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124871
Approved by: https://github.com/awgu, https://github.com/wz337, https://github.com/XilunWu
Author: Wanchao Liang
Date: 2024-04-28 16:38:55 -07:00
Committed by: PyTorch MergeBot
parent 935a946241
commit 8d46ab4104
4 changed files with 38 additions and 36 deletions
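
In short, pad_tensor/unpad_tensor keep the same padding/narrowing behavior as the old Shard methods but take the pad dimension as an explicit argument instead of reading it from the placement's self.dim. A minimal before/after sketch of a call site (the variable values here are illustrative, not taken from the diff below):

    import torch
    from torch.distributed._tensor._collective_utils import pad_tensor, unpad_tensor
    from torch.distributed._tensor.placement_types import Shard

    shard_placement = Shard(1)           # shard along tensor dim 1
    local_tensor = torch.randn(4, 3)
    pad_size = 2                         # this rank's shard is 2 elements short on dim 1

    # before: padded = shard_placement._pad_tensor(local_tensor, pad_size)
    # after: the dim is explicit, so any caller can pad/unpad along any dimension
    padded = pad_tensor(local_tensor, shard_placement.dim, pad_size)    # shape (4, 5)
    restored = unpad_tensor(padded, shard_placement.dim, pad_size)      # shape (4, 3)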

View File

@@ -809,8 +809,10 @@ class TestDTensorPlacementTypes(DTensorTestBase):
             ]
             assert_array_equal(expected_pad_sizes, pad_sizes)
+            from torch.distributed._tensor._collective_utils import unpad_tensor
+
             unpadded_list = [
-                shard_placement._unpad_tensor(tensor, pad_sizes[i])
+                unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
                 if pad_sizes[i] > 0
                 else tensor
                 for i, tensor in enumerate(splitted_tensor_list)

View File

@@ -9,6 +9,7 @@ from torch.distributed._tensor._collective_utils import (
     mesh_all_to_all,
     mesh_broadcast,
     mesh_scatter,
+    unpad_tensor,
 )
 from torch.distributed._tensor.placement_types import _Partial, Shard
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
@@ -490,8 +491,8 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
         mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)
         if pad_sizes[my_rank] != 0:
-            scattered_tensor = shard_placement._unpad_tensor(
-                scattered_tensor, pad_sizes[my_rank]
+            scattered_tensor = unpad_tensor(
+                scattered_tensor, shard_dim, pad_sizes[my_rank]
             )
         if scattered_tensor.numel() == 0:
@@ -533,7 +534,7 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
         )
         unpadded_list = [
             (
-                shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i])
+                unpad_tensor(big_tensor_chunks[i], shard_dim, pad_sizes[i])
                 if pad_sizes[i] > 0
                 else big_tensor_chunks[i]
             )
@@ -629,8 +630,8 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
         # unpad scattered_tensor
         if pad_sizes[my_rank] > 0:
-            scattered_tensor = shard_placement._unpad_tensor(
-                scattered_tensor, pad_sizes[my_rank]
+            scattered_tensor = unpad_tensor(
+                scattered_tensor, shard_dim, pad_sizes[my_rank]
             )
         if scattered_tensor.numel() == 0:
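
For intuition on the pad_sizes indexing above (illustrative numbers, not from the test): with 4 ranks and a dim-0 length of 10 the chunks are uneven, so only the trailing rank has a nonzero pad size and needs to unpad after the scatter/all-gather.

    import torch

    tensor_dim_size, world_size = 10, 4
    chunks = torch.arange(tensor_dim_size).chunk(world_size)              # sizes [3, 3, 3, 1]
    full_chunk_size = (tensor_dim_size + world_size - 1) // world_size    # 3
    pad_sizes = [full_chunk_size - len(c) for c in chunks]                # [0, 0, 0, 2]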

View File

@@ -164,6 +164,24 @@ def mesh_all_to_all(
     return work


+def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
+    if pad_size == 0:
+        return tensor
+    pad = [0, 0] * (tensor.ndim - pad_dim)
+    pad[-1] = pad_size
+    return torch.nn.functional.pad(tensor, pad)
+
+
+def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
+    if pad_size == 0:
+        return tensor
+    return tensor.narrow(
+        pad_dim,
+        start=0,
+        length=tensor.size(pad_dim) - pad_size,
+    )
+
+
 def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
     assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
     return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
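
As a sanity check of the two helpers added above (shapes chosen arbitrarily for illustration): torch.nn.functional.pad consumes its size list starting from the last dimension, which is why the list has (ndim - pad_dim) pairs and only its final entry is set to pad_size.

    import torch
    from torch.distributed._tensor._collective_utils import pad_tensor, unpad_tensor

    t = torch.randn(4, 3, 5)
    # pad_dim=1 -> pad list [0, 0, 0, 2], i.e. pad the end of dim 1 by 2
    padded = pad_tensor(t, pad_dim=1, pad_size=2)
    assert padded.shape == (4, 5, 5)
    # narrowing dim 1 back down recovers the original values
    restored = unpad_tensor(padded, pad_dim=1, pad_size=2)
    assert torch.equal(restored, t)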

View File

@@ -7,7 +7,12 @@ import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
-from torch.distributed._tensor._collective_utils import mesh_broadcast, mesh_scatter
+from torch.distributed._tensor._collective_utils import (
+    mesh_broadcast,
+    mesh_scatter,
+    pad_tensor,
+    unpad_tensor,
+)
 from torch.distributed.device_mesh import DeviceMesh
@@ -83,37 +88,13 @@ class Shard(Placement):
             for shard, pad_size in zip(tensor_list, pad_sizes):
                 # Fill the empty tensor with zeroes with padding.
                 if with_padding and pad_size > 0:
-                    shard = self._pad_tensor(shard, pad_size)
+                    shard = pad_tensor(shard, self.dim, pad_size)
                 shard = shard.contiguous() if contiguous else shard
                 shard_list.append(shard)
             return shard_list, pad_sizes
         else:
             return tensor_list, pad_sizes

-    def _pad_tensor(
-        self,
-        tensor: torch.Tensor,
-        pad_size: int,
-    ) -> torch.Tensor:
-        if pad_size == 0:
-            return tensor
-        pad = [0, 0] * (tensor.ndim - self.dim)
-        pad[-1] = pad_size
-        return torch.nn.functional.pad(tensor, pad)
-
-    def _unpad_tensor(
-        self,
-        tensor: torch.Tensor,
-        pad_size: int,
-    ) -> torch.Tensor:
-        if pad_size == 0:
-            return tensor
-        return tensor.narrow(
-            self.dim,
-            start=0,
-            length=tensor.size(self.dim) - pad_size,
-        )
-
     @staticmethod
     def _local_shard_size_on_dim(
         size_on_dim: int,
@@ -166,7 +147,7 @@ class Shard(Placement):
         # Only unpad if the local_tensor was padded on the dimension.
         pad_size = pad_sizes[my_coordinate[mesh_dim]]
         if pad_size > 0:
-            output = self._unpad_tensor(output, pad_size)
+            output = unpad_tensor(output, self.dim, pad_size)
         return output

     def _reduce_shard_tensor(
@@ -201,7 +182,7 @@ class Shard(Placement):
         )
         if is_padded:
-            output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
         return output

     def _to_replicate_tensor(
@@ -225,7 +206,7 @@ class Shard(Placement):
         if is_padded:
             full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
             pad_size = full_chunk_size - local_shape[self.dim]
-            local_tensor = self._pad_tensor(local_tensor, pad_size)
+            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
         if not local_tensor.is_contiguous():
             local_tensor = local_tensor.contiguous()
@@ -237,7 +218,7 @@ class Shard(Placement):
         )
         if is_padded:
             unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
-            result = self._unpad_tensor(result, unpad_size)
+            result = unpad_tensor(result, self.dim, unpad_size)
         return result

     def _replicate_to_shard(
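
The last two hunks are the pad / all-gather / unpad path in _to_replicate_tensor; the size bookkeeping they rely on works out as follows (numbers are illustrative):

    # a sharded dim of logical size 10 gathered from 4 ranks
    logical_dim_size, num_chunks = 10, 4
    full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks   # 3
    # each local shard is padded up to full_chunk_size before the all-gather,
    # so the gathered dim has size full_chunk_size * num_chunks = 12
    unpad_size = full_chunk_size * num_chunks - logical_dim_size          # 2
    # unpad_tensor(result, self.dim, unpad_size) then trims it back to 10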