mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
update the tensor.scatter_ doc (#120169)
Fixes #119543 - doc fixed with the `reduce` being a kwarg (see below for details) - doc added another interface `(int dim, Tensor index, Number value, *, str reduce)` where the full signature in the pyi file after build is ``` def scatter_(self, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str) -> Tensor: ``` . This can be further verified in02fb043522/aten/src/ATen/native/native_functions.yaml (L8014)
Therefore, the value can be int, bool, float, or complex type. Besides the issue mentioned in 119543, the `reduce should be a kwarg` as shown below ``` * (int dim, Tensor index, Tensor src) * (int dim, Tensor index, Tensor src, *, str reduce) * (int dim, Tensor index, Number value) * (int dim, Tensor index, Number value, *, str reduce) ``` The test case for scala value is already implemented in70bc3b3be4/test/test_scatter_gather_ops.py (L86)
so no additional test case required. @mikaylagawarecki @janeyx99 Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/120169 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb6f50929b
commit
3426c6f559
@ -4380,7 +4380,7 @@ In-place version of :meth:`~Tensor.rsqrt`
|
||||
add_docstr_all(
|
||||
"scatter_",
|
||||
r"""
|
||||
scatter_(dim, index, src, reduce=None) -> Tensor
|
||||
scatter_(dim, index, src, *, reduce=None) -> Tensor
|
||||
|
||||
Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
|
||||
specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
|
||||
@ -4443,7 +4443,9 @@ Args:
|
||||
index (LongTensor): the indices of elements to scatter, can be either empty
|
||||
or of the same dimensionality as ``src``. When empty, the operation
|
||||
returns ``self`` unchanged.
|
||||
src (Tensor or float): the source element(s) to scatter.
|
||||
src (Tensor): the source element(s) to scatter.
|
||||
|
||||
Keyword args:
|
||||
reduce (str, optional): reduction operation to apply, can be either
|
||||
``'add'`` or ``'multiply'``.
|
||||
|
||||
@ -4473,6 +4475,32 @@ Example::
|
||||
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
|
||||
[2.0000, 2.0000, 2.0000, 3.2300]])
|
||||
|
||||
.. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor:
|
||||
:noindex:
|
||||
|
||||
Writes the value from :attr:`value` into :attr:`self` at the indices
|
||||
specified in the :attr:`index` tensor. This operation is equivalent to the previous version,
|
||||
with the :attr:`src` tensor filled entirely with :attr:`value`.
|
||||
|
||||
Args:
|
||||
dim (int): the axis along which to index
|
||||
index (LongTensor): the indices of elements to scatter, can be either empty
|
||||
or of the same dimensionality as ``src``. When empty, the operation
|
||||
returns ``self`` unchanged.
|
||||
value (Scalar): the value to scatter.
|
||||
|
||||
Keyword args:
|
||||
reduce (str, optional): reduction operation to apply, can be either
|
||||
``'add'`` or ``'multiply'``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> index = torch.tensor([[0, 1]])
|
||||
>>> value = 2
|
||||
>>> torch.zeros(3, 5).scatter_(0, index, value)
|
||||
tensor([[2., 0., 0., 0., 0.],
|
||||
[0., 2., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]])
|
||||
""",
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user