Revert "Remove split functional wrapper (#74727)"

This reverts commit cc3126083ecc4ac5d3952ee59b5fd47e53d45718.

Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/mehtanirav due to Breaking multiple internals builds and tests
This commit is contained in:
PyTorch MergeBot
2022-07-11 18:29:45 +00:00
parent b946e7a7f2
commit 7f3677d723
9 changed files with 82 additions and 60 deletions

View File

@ -1702,7 +1702,7 @@ on inplace modification of the outputs.
add_docstr(torch.unsafe_split,
r"""
unsafe_split(tensor, split_size, dim=0) -> List of Tensors
unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors
Works like :func:`torch.split` but without enforcing the autograd restrictions
on inplace modification of the outputs.
@ -10166,50 +10166,6 @@ Example::
[4, 6]]])
""".format(**common_args))
add_docstr(torch.split,
r"""
split(input, split_size, dim=0) -> List[Tensor]
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If :attr:`split_size` is an integer type, then :attr:`tensor` will
be split into equally sized chunks (if possible). Last chunk will be smaller if
the tensor size along the given dimension :attr:`dim` is not divisible by
:attr:`split_size`.
If :attr:`split_size` is a list, then :attr:`tensor` will be split
into ``len(split_size)`` chunks with sizes in :attr:`dim` according
to :attr:`split_size`.
Args:
tensor (Tensor): tensor to split.
split_size (int) or (list(int)): size of a single chunk or
list of sizes for each chunk
dim (int): dimension along which to split the tensor.
Example::
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
""")
add_docstr(torch.take,
r"""
take(input, index) -> Tensor