mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch.unflatten and improve its docs (#81399)
unflatten now has a free function version in torch.flatten in addition to the method in torch.Tensor.flatten. Updated docs to reflect this and polished them a little. For consistency, changed the signature of the int version of unflatten in native_functions.yaml. Some override tests were failing because unflatten has unusual characteristics in terms of the .int and .Dimname versions having different number of arguments so this required some changes to test/test_override.py Removed support for using mix of integer and string arguments when specifying dimensions in unflatten. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81399 Approved by: https://github.com/Lezcano, https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
5257d1d64b
commit
fd84c458f4
@ -2899,7 +2899,7 @@ static inline void handle_unflatten_exception(const std::runtime_error &e,
|
||||
}
|
||||
}
|
||||
|
||||
Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
|
||||
Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
|
||||
dim = maybe_wrap_dim(dim, self.dim());
|
||||
|
||||
TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
|
||||
@ -2938,8 +2938,12 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes) {
|
||||
return native::unflatten_impl(self, dim, sizes, c10::nullopt);
|
||||
}
|
||||
|
||||
Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) {
|
||||
return native::unflatten(self, dimname_to_position(self, dim), sizes, names);
|
||||
return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names);
|
||||
}
|
||||
|
||||
Tensor view_as(const Tensor& self, const Tensor& other) {
|
||||
|
@ -2247,11 +2247,11 @@
|
||||
- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
|
||||
variants: function, method
|
||||
|
||||
- func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a)
|
||||
variants: method
|
||||
- func: unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
|
||||
variants: function, method
|
||||
|
||||
- func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a)
|
||||
variants: method
|
||||
variants: function, method
|
||||
|
||||
- func: fill.Scalar(Tensor self, Scalar value) -> Tensor
|
||||
variants: function
|
||||
|
@ -300,7 +300,6 @@ operators, see :ref:`name_inference_reference-doc`.
|
||||
.. automethod:: align_as
|
||||
.. automethod:: align_to
|
||||
|
||||
.. automethod:: unflatten
|
||||
.. py:method:: flatten(dims, out_dim) -> Tensor
|
||||
:noindex:
|
||||
|
||||
|
@ -685,6 +685,7 @@ Tensor class reference
|
||||
Tensor.type
|
||||
Tensor.type_as
|
||||
Tensor.unbind
|
||||
Tensor.unflatten
|
||||
Tensor.unfold
|
||||
Tensor.uniform_
|
||||
Tensor.unique
|
||||
|
@ -533,6 +533,7 @@ Other Operations
|
||||
tril_indices
|
||||
triu
|
||||
triu_indices
|
||||
unflatten
|
||||
vander
|
||||
view_as_real
|
||||
view_as_complex
|
||||
|
@ -114,6 +114,7 @@ ALLOW_LIST = [
|
||||
("c10d::broadcast", datetime.date(2022, 6, 25)),
|
||||
("aten::.*functional", datetime.date(2022, 8, 1)),
|
||||
("aten::_foreach.*", datetime.date(2022, 8, 1)),
|
||||
("aten::unflatten", datetime.date(2022, 8, 10)),
|
||||
# TODO: FIXME: prims shouldn't be checked
|
||||
("prims::.*", datetime.date(9999, 1, 1)),
|
||||
]
|
||||
|
@ -1083,22 +1083,13 @@ class TestNamedTensor(TestCase):
|
||||
def test_unflatten(self):
|
||||
# test args: tensor, int, namedshape
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(0, (('A', 2), ('B', 2))),
|
||||
torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(0, [('A', 2), ('B', 2)]),
|
||||
torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(0, (['A', 2], ['B', 2])),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(-1, (['A', 2], ['B', 2])),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(-1, (['A', -1], ['B', 2])),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(4).unflatten(-1, (['A', 2], ['B', -1])),
|
||||
torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
|
||||
torch.ones(2, 2, names=('A', 'B'))))
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
|
||||
@ -1112,18 +1103,13 @@ class TestNamedTensor(TestCase):
|
||||
.unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
|
||||
torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))
|
||||
|
||||
# test args: namedtensor, int, namedshape
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 4, names=('A', 'B')).unflatten(1, (('B1', 2), ('B2', 2))),
|
||||
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
|
||||
|
||||
# test args: namedtensor, str, namedshape
|
||||
self.assertTrue(torch.equal(
|
||||
torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
|
||||
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
|
||||
|
||||
# test invalid args: namedtensor, str, sizes
|
||||
with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
|
||||
with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
|
||||
torch.tensor([1], names=('A',)).unflatten('A', (1, 1))
|
||||
|
||||
# test invalid args: namedtensor, int, sizes
|
||||
|
@ -337,7 +337,7 @@ def generate_tensor_like_torch_implementations():
|
||||
msg = (
|
||||
"The following functions are not tested for __torch_function__ "
|
||||
"support, please ensure there is an entry in the dict returned by "
|
||||
"torch._overrides.get_testing_overrides for this function or if a "
|
||||
"torch.overrides.get_testing_overrides for this function or if a "
|
||||
"__torch_function__ override does not make sense, add an entry to "
|
||||
"the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
|
||||
)
|
||||
@ -648,7 +648,11 @@ def generate_tensor_like_override_tests(cls):
|
||||
func_args.append(3.5)
|
||||
elif t == 'bool':
|
||||
func_args.append(False)
|
||||
elif t.startswith('int') or t in {'Dimname', 'DimnameList'}:
|
||||
elif t == 'Dimname':
|
||||
func_args.append("")
|
||||
elif t == 'DimnameList':
|
||||
func_args.append([""])
|
||||
elif t.startswith('int'):
|
||||
func_args.append(0)
|
||||
elif t in {'Stream'}:
|
||||
func_args.append(torch.Stream())
|
||||
|
@ -5857,7 +5857,7 @@ class TestTorch(TestCase):
|
||||
torch.ones(2, 3, 0, 4, 5, 2))
|
||||
|
||||
# test invalid args: tensor, str, sizes
|
||||
with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
|
||||
with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
|
||||
torch.tensor([1]).unflatten('A', (1, 1))
|
||||
|
||||
# test invalid args: tensor, str, namedshape
|
||||
|
@ -1135,34 +1135,10 @@ class Tensor(torch._C._TensorBase):
|
||||
)
|
||||
|
||||
def unflatten(self, dim, sizes):
|
||||
r"""Expands the dimension :attr:`dim` of the :attr:`self` tensor over multiple dimensions
|
||||
of sizes given by :attr:`sizes`.
|
||||
r"""
|
||||
unflatten(dim, sizes) -> Tensor
|
||||
|
||||
* :attr:`sizes` is the new shape of the unflattened dimension and it can be a `Tuple[int]` as well
|
||||
as `torch.Size` if :attr:`self` is a `Tensor`, or `namedshape` (Tuple[(name: str, size: int)])
|
||||
if :attr:`self` is a `NamedTensor`. The total number of elements in sizes must match the number
|
||||
of elements in the original dim being unflattened.
|
||||
|
||||
Args:
|
||||
dim (Union[int, str]): Dimension to unflatten
|
||||
sizes (Union[Tuple[int] or torch.Size, Tuple[Tuple[str, int]]]): New shape of the unflattened dimension
|
||||
|
||||
Examples:
|
||||
>>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape
|
||||
torch.Size([3, 2, 2, 1])
|
||||
>>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1
|
||||
torch.Size([3, 2, 2, 1])
|
||||
>>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2)))
|
||||
tensor([[[-1.1772, 0.0180],
|
||||
[ 0.2412, 0.1431]],
|
||||
[[-1.1819, -0.8899],
|
||||
[ 1.5813, 0.2274]]], names=('A', 'B1', 'B2'))
|
||||
>>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1)))
|
||||
tensor([[-0.8591],
|
||||
[ 0.3100]], names=('B1', 'B2'))
|
||||
|
||||
.. warning::
|
||||
The named tensor API is experimental and subject to change.
|
||||
See :func:`torch.unflatten`.
|
||||
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
@ -1176,7 +1152,9 @@ class Tensor(torch._C._TensorBase):
|
||||
isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
|
||||
):
|
||||
names, sizes = unzip_namedshape(sizes)
|
||||
return super(Tensor, self).unflatten(dim, sizes, names)
|
||||
return super(Tensor, self).unflatten(dim, sizes, names)
|
||||
else:
|
||||
return super(Tensor, self).unflatten(dim, sizes)
|
||||
|
||||
def rename_(self, *names, **rename_map):
|
||||
"""In-place version of :meth:`~Tensor.rename`."""
|
||||
|
@ -4552,6 +4552,41 @@ Example::
|
||||
),
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.unflatten,
|
||||
r"""
|
||||
unflatten(input, dim, sizes) -> Tensor
|
||||
|
||||
Expands a dimension of the input tensor over multiple dimensions.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
dim (int): Dimension to be unflattened, specified as an index into
|
||||
``input.shape``.
|
||||
sizes (Tuple[int]): New shape of the unflattened dimension.
|
||||
One of its elements can be `-1` in which case the corresponding output
|
||||
dimension is inferred. Otherwise, the product of ``sizes`` *must*
|
||||
equal ``input.shape[dim]``.
|
||||
|
||||
Returns:
|
||||
A View of input with the specified dimension unflattened.
|
||||
|
||||
Examples::
|
||||
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
|
||||
torch.Size([3, 2, 2, 1])
|
||||
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
|
||||
torch.Size([3, 2, 2, 1])
|
||||
>>> torch.unflatten(torch.randn(5, 12, 3), -1, (2, 2, 3, 1, 1)).shape
|
||||
torch.Size([5, 2, 2, 3, 1, 1, 3])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.gather,
|
||||
r"""
|
||||
|
@ -126,7 +126,7 @@ Tensor UnflattenImpl::forward(const Tensor& input) {
|
||||
}
|
||||
return input.unflatten(dimname, sizes, names);
|
||||
}
|
||||
return input.unflatten(options.dim(), options.sizes(), torch::nullopt);
|
||||
return input.unflatten(options.dim(), options.sizes());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
@ -1073,6 +1073,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.true_divide: lambda input, other: -1,
|
||||
torch.trunc: lambda input, out=None: -1,
|
||||
torch.unbind: lambda input, dim=0: -1,
|
||||
torch.unflatten: lambda input, dim, sizes, names: -1,
|
||||
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
|
||||
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
|
||||
torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
|
||||
|
Reference in New Issue
Block a user