Update documentation for scatter_reduce

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74608

Approved by: https://github.com/cpuhrsch
This commit is contained in:
Mikayla Gawarecki
2022-04-04 22:07:32 +00:00
committed by PyTorch MergeBot
parent 11606f5c8f
commit 11f1fef981
4 changed files with 72 additions and 52 deletions

View File

@ -3374,11 +3374,68 @@ Example::
""".format(**reproducibility_notes))
add_docstr_all('scatter_reduce', r"""
scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor
add_docstr_all('scatter_reduce_', r"""
scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor
See :func:`torch.scatter_reduce`
""")
Reduces all values from the :attr:`src` tensor to the indices specified in
the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction
defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`,
:obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an
index in :attr:`self` which is specified by its index in :attr:`src` for
``dimension != dim`` and by the corresponding value in :attr:`index` for
``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self`
tensor are included in the reduction.
:attr:`self`, :attr:`index` and :attr:`src` should all have
the same number of dimensions. It is also required that
``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that
``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``.
Note that ``index`` and ``src`` do not broadcast.
For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the
output is given as::
self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
Note:
{forward_reproducibility_note}
.. note::
The backward pass is implemented only for ``src.shape == index.shape``.
.. warning::
This function is in beta and may change in the near future.
Args:
dim (int): the axis along which to index
index (LongTensor): the indices of elements to scatter and reduce.
src (Tensor): the source elements to scatter and reduce
reduce (str): the reduction operation to apply for non-unique indices
(:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)
include_self (bool): whether elements from the :attr:`self` tensor are
included in the reduction
Example::
>>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> input = torch.tensor([1., 2., 3., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum")
tensor([5., 14., 8., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
tensor([4., 12., 5., 4.])
>>> input2 = torch.tensor([5., 4., 3., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax")
tensor([5., 6., 5., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
tensor([3., 6., 5., 2.])
""".format(**reproducibility_notes))
add_docstr_all('select',
r"""
@ -4766,6 +4823,13 @@ scatter_add(dim, index, src) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_add_`
""")
add_docstr_all('scatter_reduce',
r"""
scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_reduce_`
""")
add_docstr_all('masked_scatter',
r"""
masked_scatter(mask, tensor) -> Tensor