mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "Remove split functional wrapper (#74727)"
This reverts commit cc3126083ecc4ac5d3952ee59b5fd47e53d45718. Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/mehtanirav due to Breaking multiple internals builds and tests
This commit is contained in:
@ -1702,7 +1702,7 @@ on inplace modification of the outputs.
|
||||
|
||||
add_docstr(torch.unsafe_split,
|
||||
r"""
|
||||
unsafe_split(tensor, split_size, dim=0) -> List of Tensors
|
||||
unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors
|
||||
|
||||
Works like :func:`torch.split` but without enforcing the autograd restrictions
|
||||
on inplace modification of the outputs.
|
||||
@ -10166,50 +10166,6 @@ Example::
|
||||
[4, 6]]])
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.split,
|
||||
r"""
|
||||
split(input, split_size, dim=0) -> List[Tensor]
|
||||
|
||||
Splits the tensor into chunks. Each chunk is a view of the original tensor.
|
||||
|
||||
If :attr:`split_size` is an integer type, then :attr:`tensor` will
|
||||
be split into equally sized chunks (if possible). Last chunk will be smaller if
|
||||
the tensor size along the given dimension :attr:`dim` is not divisible by
|
||||
:attr:`split_size`.
|
||||
|
||||
If :attr:`split_size` is a list, then :attr:`tensor` will be split
|
||||
into ``len(split_size)`` chunks with sizes in :attr:`dim` according
|
||||
to :attr:`split_size`.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): tensor to split.
|
||||
split_size (int) or (list(int)): size of a single chunk or
|
||||
list of sizes for each chunk
|
||||
dim (int): dimension along which to split the tensor.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(10).reshape(5,2)
|
||||
>>> a
|
||||
tensor([[0, 1],
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]])
|
||||
>>> torch.split(a, 2)
|
||||
(tensor([[0, 1],
|
||||
[2, 3]]),
|
||||
tensor([[4, 5],
|
||||
[6, 7]]),
|
||||
tensor([[8, 9]]))
|
||||
>>> torch.split(a, [1,4])
|
||||
(tensor([[0, 1]]),
|
||||
tensor([[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]]))
|
||||
""")
|
||||
|
||||
add_docstr(torch.take,
|
||||
r"""
|
||||
take(input, index) -> Tensor
|
||||
|
Reference in New Issue
Block a user