update the tensor.scatter_ doc (#120169)

Fixes #119543

- doc fixed with the `reduce` being a kwarg (see below for details)
- doc added another interface `(int dim, Tensor index, Number value, *, str reduce)` where
the full signature in the pyi file after build is
```
def scatter_(self, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str) -> Tensor:
```
. This can be further verified in
02fb043522/aten/src/ATen/native/native_functions.yaml (L8014)

Therefore, the value can be int, bool, float, or complex type.

Besides the issue mentioned in 119543, the `reduce should be a kwarg` as shown below
```
 * (int dim, Tensor index, Tensor src)
 * (int dim, Tensor index, Tensor src, *, str reduce)
 * (int dim, Tensor index, Number value)
 * (int dim, Tensor index, Number value, *, str reduce)
 ```

The test case for scala value is already implemented in

70bc3b3be4/test/test_scatter_gather_ops.py (L86)

so no additional test case required.

@mikaylagawarecki  @janeyx99

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120169
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
lancerts
2024-02-23 02:51:53 +00:00
committed by PyTorch MergeBot
parent bb6f50929b
commit 3426c6f559

View File

@ -4380,7 +4380,7 @@ In-place version of :meth:`~Tensor.rsqrt`
add_docstr_all(
"scatter_",
r"""
scatter_(dim, index, src, reduce=None) -> Tensor
scatter_(dim, index, src, *, reduce=None) -> Tensor
Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
@ -4443,7 +4443,9 @@ Args:
index (LongTensor): the indices of elements to scatter, can be either empty
or of the same dimensionality as ``src``. When empty, the operation
returns ``self`` unchanged.
src (Tensor or float): the source element(s) to scatter.
src (Tensor): the source element(s) to scatter.
Keyword args:
reduce (str, optional): reduction operation to apply, can be either
``'add'`` or ``'multiply'``.
@ -4473,6 +4475,32 @@ Example::
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
.. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor:
:noindex:
Writes the value from :attr:`value` into :attr:`self` at the indices
specified in the :attr:`index` tensor. This operation is equivalent to the previous version,
with the :attr:`src` tensor filled entirely with :attr:`value`.
Args:
dim (int): the axis along which to index
index (LongTensor): the indices of elements to scatter, can be either empty
or of the same dimensionality as ``src``. When empty, the operation
returns ``self`` unchanged.
value (Scalar): the value to scatter.
Keyword args:
reduce (str, optional): reduction operation to apply, can be either
``'add'`` or ``'multiply'``.
Example::
>>> index = torch.tensor([[0, 1]])
>>> value = 2
>>> torch.zeros(3, 5).scatter_(0, index, value)
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
""",
)