Add itertools.{prod, combinations, combinations_with_replacement} like op to pytorch (#9393)

Summary:
closes https://github.com/pytorch/pytorch/issues/7580
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9393

Differential Revision: D13659628

Pulled By: zou3519

fbshipit-source-id: 3a233befa785709395a793ba8833413be394a6fd
This commit is contained in:
Xiang Gao
2019-01-15 08:24:27 -08:00
committed by Facebook Github Bot
parent 964732fa8d
commit 1065e7cd24
6 changed files with 202 additions and 1 deletions

View File

@ -6267,3 +6267,47 @@ Example::
>>> [7, 8, 9]]))
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))
""")
add_docstr(torch.combinations,
r"""
combinations(tensor, r=2, with_replacement=False) -> seq
Compute combinations of length :math:`r` of the given tensor. The behavior is similar to
python's `itertools.combinations` when `with_replacement` is set to `False`, and
`itertools.combinations_with_replacement` when `with_replacement` is set to `True`.
Arguments:
tensor (Tensor): 1D vector.
r (int, optional): number of elements to combine
with_replacement (boolean, optional): whether to allow duplication in combination
Returns:
Tensor: A tensor equivalent to converting all the input tensors into lists, do
`itertools.combinations` or `itertools.combinations_with_replacement` on these
lists, and finally convert the resulting list into tensor.
Example::
>>> a = [1, 2, 3]
>>> list(itertools.combinations(a, r=2))
[(1, 2), (1, 3), (2, 3)]
>>> list(itertools.combinations(a, r=3))
[(1, 2, 3)]
>>> list(itertools.combinations_with_replacement(a, r=2))
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
>>> tensor_a = torch.tensor(a)
>>> torch.combinations(tensor_a)
tensor([[1, 2],
[1, 3],
[2, 3]])
>>> torch.combinations(tensor_a, r=3)
tensor([[1, 2, 3]])
>>> torch.combinations(tensor_a, with_replacement=True)
tensor([[1, 1],
[1, 2],
[1, 3],
[2, 2],
[2, 3],
[3, 3]])
""")