Revert "[dtensor] move pad/unpad_tensor to separate utils (#124871)"

This reverts commit 0b0eea222978e6b377e2c67f89902d5eb1aa7da3.

Reverted https://github.com/pytorch/pytorch/pull/124871 on behalf of https://github.com/jeanschmidt due to Broke internal tests, see D56587991 for more details ([comment](https://github.com/pytorch/pytorch/pull/124871#issuecomment-2079001103))
PyTorch MergeBot
2024-04-26 09:30:34 +00:00
parent 35a82d4a4a
commit 359ff49bf4
3 changed files with 31 additions and 32 deletions
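
In practice, the revert moves the padding helpers back from module-level functions in torch.distributed._tensor._collective_utils (which took the shard dimension as an explicit argument) to private methods on the Shard placement (which read the dimension from self.dim). A minimal sketch of the two call patterns, assuming a build where these private members exist; the tensor shape and the pad size of 1 are made up for illustration:

import torch
from torch.distributed._tensor.placement_types import Shard

shard = Shard(0)                   # shard placement along dim 0
local_tensor = torch.randn(3, 4)

# After this revert: private methods on Shard, the dim comes from self.dim.
padded = shard._pad_tensor(local_tensor, 1)    # (3, 4) -> (4, 4), one zero row appended
restored = shard._unpad_tensor(padded, 1)      # back to (3, 4)

# Before the revert (PR #124871): free functions taking the dim explicitly, e.g.
#   from torch.distributed._tensor._collective_utils import pad_tensor, unpad_tensor
#   padded = pad_tensor(local_tensor, shard.dim, 1)
#   restored = unpad_tensor(padded, shard.dim, 1)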

@@ -809,10 +809,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
             ]
             assert_array_equal(expected_pad_sizes, pad_sizes)

-            from torch.distributed._tensor._collective_utils import unpad_tensor
-
             unpadded_list = [
-                unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
+                shard_placement._unpad_tensor(tensor, pad_sizes[i])
                 if pad_sizes[i] > 0
                 else tensor
                 for i, tensor in enumerate(splitted_tensor_list)

@@ -164,24 +164,6 @@ def mesh_all_to_all(
     return work


-def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
-    if pad_size == 0:
-        return tensor
-    pad = [0, 0] * (tensor.ndim - pad_dim)
-    pad[-1] = pad_size
-    return torch.nn.functional.pad(tensor, pad)
-
-
-def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
-    if pad_size == 0:
-        return tensor
-    return tensor.narrow(
-        pad_dim,
-        start=0,
-        length=tensor.size(pad_dim) - pad_size,
-    )
-
-
 def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
     assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
     return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
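
The removed helpers build the argument list for torch.nn.functional.pad back to front, so pad[-1] = pad_size appends pad_size zero entries at the end of pad_dim, and narrow trims exactly those entries off again. A small standalone round-trip sketch of that logic (sizes chosen arbitrarily, independent of DTensor):

import torch
import torch.nn.functional as F

t = torch.arange(6.0).reshape(2, 3)
pad_dim, pad_size = 0, 2

# F.pad consumes (left, right) pairs starting from the last dim, so the pair
# controlling pad_dim is the last one when the list holds (ndim - pad_dim) pairs.
pad = [0, 0] * (t.ndim - pad_dim)
pad[-1] = pad_size
padded = F.pad(t, pad)               # (2, 3) -> (4, 3), two zero rows appended

# Undo it by keeping everything except the trailing pad_size entries of pad_dim.
unpadded = padded.narrow(pad_dim, 0, padded.size(pad_dim) - pad_size)
assert torch.equal(unpadded, t)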

@@ -7,12 +7,7 @@ import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
-from torch.distributed._tensor._collective_utils import (
-    mesh_broadcast,
-    mesh_scatter,
-    pad_tensor,
-    unpad_tensor,
-)
+from torch.distributed._tensor._collective_utils import mesh_broadcast, mesh_scatter
 from torch.distributed.device_mesh import DeviceMesh
@@ -88,13 +83,37 @@ class Shard(Placement):
             for shard, pad_size in zip(tensor_list, pad_sizes):
                 # Fill the empty tensor with zeroes with padding.
                 if with_padding and pad_size > 0:
-                    shard = pad_tensor(shard, self.dim, pad_size)
+                    shard = self._pad_tensor(shard, pad_size)
                 shard = shard.contiguous() if contiguous else shard
                 shard_list.append(shard)
             return shard_list, pad_sizes
         else:
             return tensor_list, pad_sizes

+    def _pad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        pad = [0, 0] * (tensor.ndim - self.dim)
+        pad[-1] = pad_size
+        return torch.nn.functional.pad(tensor, pad)
+
+    def _unpad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        return tensor.narrow(
+            self.dim,
+            start=0,
+            length=tensor.size(self.dim) - pad_size,
+        )
+
     @staticmethod
     def _local_shard_size_on_dim(
         size_on_dim: int,
@@ -147,7 +166,7 @@
         # Only unpad if the local_tensor was padded on the dimension.
         pad_size = pad_sizes[my_coordinate[mesh_dim]]
         if pad_size > 0:
-            output = unpad_tensor(output, self.dim, pad_size)
+            output = self._unpad_tensor(output, pad_size)
         return output

     def _reduce_shard_tensor(
@@ -182,7 +201,7 @@
         )

         if is_padded:
-            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+            output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
         return output

     def _to_replicate_tensor(
@@ -206,7 +225,7 @@
         if is_padded:
             full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
             pad_size = full_chunk_size - local_shape[self.dim]
-            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
+            local_tensor = self._pad_tensor(local_tensor, pad_size)

         if not local_tensor.is_contiguous():
             local_tensor = local_tensor.contiguous()
@@ -218,7 +237,7 @@
         )
         if is_padded:
             unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
-            result = unpad_tensor(result, self.dim, unpad_size)
+            result = self._unpad_tensor(result, unpad_size)
         return result

     def _replicate_to_shard(
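
For reference, the padding arithmetic the last two hunks restore in _to_replicate_tensor: each local shard is padded up to full_chunk_size before the shards are gathered, and the concatenated result is trimmed back to the logical size afterwards. A worked example with made-up sizes:

logical_dim_size, num_chunks = 10, 4
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks  # ceil(10 / 4) = 3
gathered_size = full_chunk_size * num_chunks                         # 12 rows after the gather
unpad_size = gathered_size - logical_dim_size                        # 2 rows removed by _unpad_tensor
assert (full_chunk_size, gathered_size, unpad_size) == (3, 12, 2)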