Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dtensor] move pad/unpad_tensor to separate utils (#124871)
As titled: (1) pad/unpad is a general utility, not something specific to the Shard placement; (2) in preparation for the next PR, this moves the two functions out of the Shard placement itself and gives them an additional pad_dim argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124871
Approved by: https://github.com/awgu, https://github.com/wz337, https://github.com/XilunWu
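For context, here is a minimal sketch of what the relocated helpers could look like in torch/distributed/_tensor/_collective_utils.py. The unpad_tensor(tensor, pad_dim, pad_size) signature matches the new call site in the test diff below; the pad_tensor counterpart and both function bodies are assumptions inferred from the commit title, not the PR's exact code.

```python
# Sketch only: plausible shapes for the relocated pad/unpad utilities,
# assuming they pad or trim a single dimension by pad_size elements.
import torch
import torch.nn.functional as F


def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    if pad_size == 0:
        return tensor
    # F.pad takes (before, after) pairs starting from the last dim, so build
    # enough pairs to reach pad_dim and set the "after" amount on that dim.
    pad = [0, 0] * (tensor.ndim - pad_dim)
    pad[-1] = pad_size
    return F.pad(tensor, pad)


def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    if pad_size == 0:
        return tensor
    # Drop the trailing pad_size elements along pad_dim.
    return tensor.narrow(pad_dim, 0, tensor.size(pad_dim) - pad_size)
```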
commit 8d46ab4104 (parent 935a946241)
committed by PyTorch MergeBot
@@ -809,8 +809,10 @@ class TestDTensorPlacementTypes(DTensorTestBase):
             ]
             assert_array_equal(expected_pad_sizes, pad_sizes)
 
+            from torch.distributed._tensor._collective_utils import unpad_tensor
+
             unpadded_list = [
-                shard_placement._unpad_tensor(tensor, pad_sizes[i])
+                unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
                 if pad_sizes[i] > 0
                 else tensor
                 for i, tensor in enumerate(splitted_tensor_list)
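A quick usage sketch of the new call shape, assuming the module path shown in the diff (it requires a PyTorch build with distributed support); the 2x3 tensor and pad amount here are illustrative, and shard_placement.dim in the test is simply the dimension the shard was padded along:

```python
import torch
from torch.distributed._tensor._collective_utils import unpad_tensor

# A 2x3 shard that was padded with one extra row along dim 0 to make the
# shards evenly sized; unpad_tensor trims that row back off.
padded = torch.arange(9).reshape(3, 3)
restored = unpad_tensor(padded, 0, 1)  # pad_dim=0, pad_size=1 -> shape (2, 3)
assert restored.shape == (2, 3)
```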