Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[dtensor] move pad/unpad_tensor to separate utils (#124871)"
This reverts commit 0b0eea222978e6b377e2c67f89902d5eb1aa7da3. Reverted https://github.com/pytorch/pytorch/pull/124871 on behalf of https://github.com/jeanschmidt due to Broke internal tests, see D56587991 for more details ([comment](https://github.com/pytorch/pytorch/pull/124871#issuecomment-2079001103))
@@ -809,10 +809,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
                 ]
                 assert_array_equal(expected_pad_sizes, pad_sizes)
 
-                from torch.distributed._tensor._collective_utils import unpad_tensor
-
                 unpadded_list = [
-                    unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
+                    shard_placement._unpad_tensor(tensor, pad_sizes[i])
                     if pad_sizes[i] > 0
                     else tensor
                     for i, tensor in enumerate(splitted_tensor_list)
@@ -164,24 +164,6 @@ def mesh_all_to_all(
     return work
 
 
-def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
-    if pad_size == 0:
-        return tensor
-    pad = [0, 0] * (tensor.ndim - pad_dim)
-    pad[-1] = pad_size
-    return torch.nn.functional.pad(tensor, pad)
-
-
-def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
-    if pad_size == 0:
-        return tensor
-    return tensor.narrow(
-        pad_dim,
-        start=0,
-        length=tensor.size(pad_dim) - pad_size,
-    )
-
-
 def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
     assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
     return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
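A note on the pad-list construction in the functions removed above (which, judging by the import paths elsewhere in this diff, live in torch.distributed._tensor._collective_utils): torch.nn.functional.pad takes its pad argument ordered from the last dimension backward in (left, right) pairs, which is why pad_tensor builds [0, 0] * (ndim - pad_dim) and then sets only the final entry. A minimal standalone sketch, using only public torch APIs and illustrative shapes:

import torch
import torch.nn.functional as F

t = torch.arange(6.0).reshape(2, 3)              # shape (2, 3)

# pad_dim = 0, pad_size = 1: [0, 0] * (2 - 0) -> [0, 0, 0, 0], then pad[-1] = 1
# gives [0, 0, 0, 1], i.e. "append one zero row at the end of dim 0".
pad = [0, 0] * (t.ndim - 0)
pad[-1] = 1
padded = F.pad(t, pad)                           # shape (3, 3), last row is zeros

# unpad_tensor undoes this with narrow: keep the first size - pad_size rows.
unpadded = padded.narrow(0, start=0, length=padded.size(0) - 1)
assert torch.equal(unpadded, t)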
@@ -7,12 +7,7 @@ import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 
-from torch.distributed._tensor._collective_utils import (
-    mesh_broadcast,
-    mesh_scatter,
-    pad_tensor,
-    unpad_tensor,
-)
+from torch.distributed._tensor._collective_utils import mesh_broadcast, mesh_scatter
 from torch.distributed.device_mesh import DeviceMesh
 
 
@@ -88,13 +83,37 @@ class Shard(Placement):
             for shard, pad_size in zip(tensor_list, pad_sizes):
                 # Fill the empty tensor with zeroes with padding.
                 if with_padding and pad_size > 0:
-                    shard = pad_tensor(shard, self.dim, pad_size)
+                    shard = self._pad_tensor(shard, pad_size)
                 shard = shard.contiguous() if contiguous else shard
                 shard_list.append(shard)
             return shard_list, pad_sizes
         else:
             return tensor_list, pad_sizes
 
+    def _pad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        pad = [0, 0] * (tensor.ndim - self.dim)
+        pad[-1] = pad_size
+        return torch.nn.functional.pad(tensor, pad)
+
+    def _unpad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        return tensor.narrow(
+            self.dim,
+            start=0,
+            length=tensor.size(self.dim) - pad_size,
+        )
+
     @staticmethod
     def _local_shard_size_on_dim(
         size_on_dim: int,
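For orientation, a minimal usage sketch of the methods restored above; the module path and the concrete shapes are my assumptions, not part of the diff. After the revert, callers pad and unpad through the Shard placement instance instead of the free functions:

import torch
from torch.distributed._tensor.placement_types import Shard  # assumed import path

shard = Shard(0)                                     # shard along tensor dim 0
local = torch.randn(3, 4)                            # an uneven local shard
padded = shard._pad_tensor(local, pad_size=1)        # zero-pad dim 0 -> shape (4, 4)
restored = shard._unpad_tensor(padded, pad_size=1)   # narrow dim 0 back -> shape (3, 4)
assert torch.equal(restored, local)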
@@ -147,7 +166,7 @@ class Shard(Placement):
         # Only unpad if the local_tensor was padded on the dimension.
         pad_size = pad_sizes[my_coordinate[mesh_dim]]
         if pad_size > 0:
-            output = unpad_tensor(output, self.dim, pad_size)
+            output = self._unpad_tensor(output, pad_size)
         return output
 
     def _reduce_shard_tensor(
@@ -182,7 +201,7 @@ class Shard(Placement):
         )
 
         if is_padded:
-            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+            output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
         return output
 
     def _to_replicate_tensor(
@@ -206,7 +225,7 @@ class Shard(Placement):
         if is_padded:
             full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
             pad_size = full_chunk_size - local_shape[self.dim]
-            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
+            local_tensor = self._pad_tensor(local_tensor, pad_size)
 
         if not local_tensor.is_contiguous():
             local_tensor = local_tensor.contiguous()
@@ -218,7 +237,7 @@ class Shard(Placement):
         )
         if is_padded:
             unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
-            result = unpad_tensor(result, self.dim, unpad_size)
+            result = self._unpad_tensor(result, unpad_size)
         return result
 
     def _replicate_to_shard(
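Worked numbers for the ceil-division padding that _to_replicate_tensor relies on in the two hunks above (the concrete sizes are illustrative, not taken from the PR):

logical_dim_size = 10   # size of the sharded dim before chunking
num_chunks = 4

full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks  # ceil(10 / 4) = 3
# Shard sizes come out as [3, 3, 3, 1], so the last local shard is padded by 3 - 1 = 2.
# After the all-gather each rank holds full_chunk_size * num_chunks = 12 elements, and
# the replicated tensor is narrowed by 12 - 10 = 2 to recover the logical size.
unpad_size = full_chunk_size * num_chunks - logical_dim_size
assert (full_chunk_size, unpad_size) == (3, 2)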