[dtensor] move pad/unpad_tensor to separate utils (#124871)

As titled: 1. pad/unpad is a general utility, not specific to the Shard
placement; 2. in preparation for the next PR, move these two out of the Shard
placement itself and add a pad_dim argument (see the sketch below).
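
For reference, a minimal sketch of what the relocated helpers could look like with the new pad_dim argument. Only the `unpad_tensor(tensor, pad_dim, pad_size)` call shape is visible in the diff below; the function bodies here are an assumption:

```python
import torch
import torch.nn.functional as F

def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # Sketch: append pad_size zeros on the trailing side of pad_dim.
    if pad_size == 0:
        return tensor
    # F.pad lists (left, right) pairs starting from the last dim;
    # pad[-1] is the right-pad of pad_dim.
    pad = [0, 0] * (tensor.ndim - pad_dim)
    pad[-1] = pad_size
    return F.pad(tensor, pad)

def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # Sketch: drop the last pad_size elements along pad_dim.
    if pad_size == 0:
        return tensor
    return tensor.narrow(pad_dim, 0, tensor.size(pad_dim) - pad_size)
```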

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124871
Approved by: https://github.com/awgu, https://github.com/wz337, https://github.com/XilunWu
Author: Wanchao Liang
Date: 2024-04-28 16:38:55 -07:00
Committed by: PyTorch MergeBot
Parent: 935a946241
Commit: 8d46ab4104
4 changed files with 38 additions and 36 deletions


@@ -809,8 +809,10 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         ]
         assert_array_equal(expected_pad_sizes, pad_sizes)
+        from torch.distributed._tensor._collective_utils import unpad_tensor
         unpadded_list = [
-            shard_placement._unpad_tensor(tensor, pad_sizes[i])
+            unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
             if pad_sizes[i] > 0
             else tensor
             for i, tensor in enumerate(splitted_tensor_list)
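
The hunk above is truncated. For illustration, a self-contained usage sketch of the relocated helper; the import path and argument order match the diff, while the tensor shape and pad size are hypothetical:

```python
import torch
from torch.distributed._tensor._collective_utils import unpad_tensor

# Hypothetical values: strip 2 trailing rows of padding along dim 0.
padded = torch.zeros(5, 4)
unpadded = unpad_tensor(padded, 0, 2)  # (tensor, pad_dim, pad_size)
print(unpadded.shape)  # torch.Size([3, 4])
```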