Add torch.unflatten and improve its docs (#81399)

unflatten now has a free function version in torch.flatten in addition to
    the method in torch.Tensor.flatten.

    Updated docs to reflect this and polished them a little.
    For consistency, changed the signature of the int version of unflatten in
    native_functions.yaml.

    Some override tests were failing because unflatten has unusual
    characteristics in terms of the .int and .Dimname versions having
    different number of arguments so this required some changes
    to test/test_override.py

    Removed support for using mix of integer and string arguments
    when specifying dimensions in unflatten.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81399
Approved by: https://github.com/Lezcano, https://github.com/ngimel
This commit is contained in:
Fabio Rocha
2022-07-29 09:48:38 +00:00
committed by PyTorch MergeBot
parent 5257d1d64b
commit fd84c458f4
13 changed files with 66 additions and 56 deletions

View File

@ -2899,7 +2899,7 @@ static inline void handle_unflatten_exception(const std::runtime_error &e,
}
}
Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional<DimnameList> names) {
dim = maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
@ -2938,8 +2938,12 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
return result;
}
Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes) {
return native::unflatten_impl(self, dim, sizes, c10::nullopt);
}
Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) {
return native::unflatten(self, dimname_to_position(self, dim), sizes, names);
return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names);
}
Tensor view_as(const Tensor& self, const Tensor& other) {

View File

@ -2247,11 +2247,11 @@
- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
variants: function, method
- func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a)
variants: method
- func: unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
variants: function, method
- func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a)
variants: method
variants: function, method
- func: fill.Scalar(Tensor self, Scalar value) -> Tensor
variants: function

View File

@ -300,7 +300,6 @@ operators, see :ref:`name_inference_reference-doc`.
.. automethod:: align_as
.. automethod:: align_to
.. automethod:: unflatten
.. py:method:: flatten(dims, out_dim) -> Tensor
:noindex:

View File

@ -685,6 +685,7 @@ Tensor class reference
Tensor.type
Tensor.type_as
Tensor.unbind
Tensor.unflatten
Tensor.unfold
Tensor.uniform_
Tensor.unique

View File

@ -533,6 +533,7 @@ Other Operations
tril_indices
triu
triu_indices
unflatten
vander
view_as_real
view_as_complex

View File

@ -114,6 +114,7 @@ ALLOW_LIST = [
("c10d::broadcast", datetime.date(2022, 6, 25)),
("aten::.*functional", datetime.date(2022, 8, 1)),
("aten::_foreach.*", datetime.date(2022, 8, 1)),
("aten::unflatten", datetime.date(2022, 8, 10)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
]

View File

@ -1083,22 +1083,13 @@ class TestNamedTensor(TestCase):
def test_unflatten(self):
# test args: tensor, int, namedshape
self.assertTrue(torch.equal(
torch.ones(4).unflatten(0, (('A', 2), ('B', 2))),
torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(4).unflatten(0, [('A', 2), ('B', 2)]),
torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(4).unflatten(0, (['A', 2], ['B', 2])),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(4).unflatten(-1, (['A', 2], ['B', 2])),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(4).unflatten(-1, (['A', -1], ['B', 2])),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(4).unflatten(-1, (['A', 2], ['B', -1])),
torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
torch.ones(2, 2, names=('A', 'B'))))
self.assertTrue(torch.equal(
torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
@ -1112,18 +1103,13 @@ class TestNamedTensor(TestCase):
.unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))
# test args: namedtensor, int, namedshape
self.assertTrue(torch.equal(
torch.ones(2, 4, names=('A', 'B')).unflatten(1, (('B1', 2), ('B2', 2))),
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
# test args: namedtensor, str, namedshape
self.assertTrue(torch.equal(
torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
# test invalid args: namedtensor, str, sizes
with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
torch.tensor([1], names=('A',)).unflatten('A', (1, 1))
# test invalid args: namedtensor, int, sizes

View File

@ -337,7 +337,7 @@ def generate_tensor_like_torch_implementations():
msg = (
"The following functions are not tested for __torch_function__ "
"support, please ensure there is an entry in the dict returned by "
"torch._overrides.get_testing_overrides for this function or if a "
"torch.overrides.get_testing_overrides for this function or if a "
"__torch_function__ override does not make sense, add an entry to "
"the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
)
@ -648,7 +648,11 @@ def generate_tensor_like_override_tests(cls):
func_args.append(3.5)
elif t == 'bool':
func_args.append(False)
elif t.startswith('int') or t in {'Dimname', 'DimnameList'}:
elif t == 'Dimname':
func_args.append("")
elif t == 'DimnameList':
func_args.append([""])
elif t.startswith('int'):
func_args.append(0)
elif t in {'Stream'}:
func_args.append(torch.Stream())

View File

@ -5857,7 +5857,7 @@ class TestTorch(TestCase):
torch.ones(2, 3, 0, 4, 5, 2))
# test invalid args: tensor, str, sizes
with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"):
with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
torch.tensor([1]).unflatten('A', (1, 1))
# test invalid args: tensor, str, namedshape

View File

@ -1135,34 +1135,10 @@ class Tensor(torch._C._TensorBase):
)
def unflatten(self, dim, sizes):
r"""Expands the dimension :attr:`dim` of the :attr:`self` tensor over multiple dimensions
of sizes given by :attr:`sizes`.
r"""
unflatten(dim, sizes) -> Tensor
* :attr:`sizes` is the new shape of the unflattened dimension and it can be a `Tuple[int]` as well
as `torch.Size` if :attr:`self` is a `Tensor`, or `namedshape` (Tuple[(name: str, size: int)])
if :attr:`self` is a `NamedTensor`. The total number of elements in sizes must match the number
of elements in the original dim being unflattened.
Args:
dim (Union[int, str]): Dimension to unflatten
sizes (Union[Tuple[int] or torch.Size, Tuple[Tuple[str, int]]]): New shape of the unflattened dimension
Examples:
>>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1
torch.Size([3, 2, 2, 1])
>>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2)))
tensor([[[-1.1772, 0.0180],
[ 0.2412, 0.1431]],
[[-1.1819, -0.8899],
[ 1.5813, 0.2274]]], names=('A', 'B1', 'B2'))
>>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1)))
tensor([[-0.8591],
[ 0.3100]], names=('B1', 'B2'))
.. warning::
The named tensor API is experimental and subject to change.
See :func:`torch.unflatten`.
"""
if has_torch_function_unary(self):
@ -1176,7 +1152,9 @@ class Tensor(torch._C._TensorBase):
isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
):
names, sizes = unzip_namedshape(sizes)
return super(Tensor, self).unflatten(dim, sizes, names)
return super(Tensor, self).unflatten(dim, sizes, names)
else:
return super(Tensor, self).unflatten(dim, sizes)
def rename_(self, *names, **rename_map):
"""In-place version of :meth:`~Tensor.rename`."""

View File

@ -4552,6 +4552,41 @@ Example::
),
)
add_docstr(
torch.unflatten,
r"""
unflatten(input, dim, sizes) -> Tensor
Expands a dimension of the input tensor over multiple dimensions.
.. seealso::
:func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one.
Args:
{input}
dim (int): Dimension to be unflattened, specified as an index into
``input.shape``.
sizes (Tuple[int]): New shape of the unflattened dimension.
One of its elements can be `-1` in which case the corresponding output
dimension is inferred. Otherwise, the product of ``sizes`` *must*
equal ``input.shape[dim]``.
Returns:
A View of input with the specified dimension unflattened.
Examples::
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -1, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
""".format(
**common_args
),
)
add_docstr(
torch.gather,
r"""

View File

@ -126,7 +126,7 @@ Tensor UnflattenImpl::forward(const Tensor& input) {
}
return input.unflatten(dimname, sizes, names);
}
return input.unflatten(options.dim(), options.sizes(), torch::nullopt);
return input.unflatten(options.dim(), options.sizes());
}
// ============================================================================

View File

@ -1073,6 +1073,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
torch.unbind: lambda input, dim=0: -1,
torch.unflatten: lambda input, dim, sizes, names: -1,
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
torch.unsafe_chunk: lambda input, chunks, dim=0: -1,