mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Update documentation for scatter_reduce
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74608 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
11606f5c8f
commit
11f1fef981
@ -3374,11 +3374,68 @@ Example::
|
||||
|
||||
""".format(**reproducibility_notes))
|
||||
|
||||
add_docstr_all('scatter_reduce', r"""
|
||||
scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor
|
||||
add_docstr_all('scatter_reduce_', r"""
|
||||
scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor
|
||||
|
||||
See :func:`torch.scatter_reduce`
|
||||
""")
|
||||
Reduces all values from the :attr:`src` tensor to the indices specified in
|
||||
the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction
|
||||
defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`,
|
||||
:obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an
|
||||
index in :attr:`self` which is specified by its index in :attr:`src` for
|
||||
``dimension != dim`` and by the corresponding value in :attr:`index` for
|
||||
``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self`
|
||||
tensor are included in the reduction.
|
||||
|
||||
:attr:`self`, :attr:`index` and :attr:`src` should all have
|
||||
the same number of dimensions. It is also required that
|
||||
``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that
|
||||
``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``.
|
||||
Note that ``index`` and ``src`` do not broadcast.
|
||||
|
||||
For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the
|
||||
output is given as::
|
||||
|
||||
self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
|
||||
self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
|
||||
self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
|
||||
|
||||
Note:
|
||||
{forward_reproducibility_note}
|
||||
|
||||
.. note::
|
||||
|
||||
The backward pass is implemented only for ``src.shape == index.shape``.
|
||||
|
||||
.. warning::
|
||||
|
||||
This function is in beta and may change in the near future.
|
||||
|
||||
Args:
|
||||
dim (int): the axis along which to index
|
||||
index (LongTensor): the indices of elements to scatter and reduce.
|
||||
src (Tensor): the source elements to scatter and reduce
|
||||
reduce (str): the reduction operation to apply for non-unique indices
|
||||
(:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)
|
||||
include_self (bool): whether elements from the :attr:`self` tensor are
|
||||
included in the reduction
|
||||
|
||||
Example::
|
||||
|
||||
>>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
|
||||
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
|
||||
>>> input = torch.tensor([1., 2., 3., 4.])
|
||||
>>> input.scatter_reduce(0, index, src, reduce="sum")
|
||||
tensor([5., 14., 8., 4.])
|
||||
>>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
|
||||
tensor([4., 12., 5., 4.])
|
||||
>>> input2 = torch.tensor([5., 4., 3., 2.])
|
||||
>>> input2.scatter_reduce(0, index, src, reduce="amax")
|
||||
tensor([5., 6., 5., 2.])
|
||||
>>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
|
||||
tensor([3., 6., 5., 2.])
|
||||
|
||||
|
||||
""".format(**reproducibility_notes))
|
||||
|
||||
add_docstr_all('select',
|
||||
r"""
|
||||
@ -4766,6 +4823,13 @@ scatter_add(dim, index, src) -> Tensor
|
||||
Out-of-place version of :meth:`torch.Tensor.scatter_add_`
|
||||
""")
|
||||
|
||||
add_docstr_all('scatter_reduce',
|
||||
r"""
|
||||
scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor
|
||||
|
||||
Out-of-place version of :meth:`torch.Tensor.scatter_reduce_`
|
||||
""")
|
||||
|
||||
add_docstr_all('masked_scatter',
|
||||
r"""
|
||||
masked_scatter(mask, tensor) -> Tensor
|
||||
|
Reference in New Issue
Block a user