[numpy] add decimals argument to round (#66195)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/65908

Added a new overload instead of updating the current signature. (Had issues with JIT and **maybe** it would have been FC breaking)

TODO:

* [x] Don't compute `std::pow(10, decimals)` for each element.
* [x] Update docs (https://docs-preview.pytorch.org/66195/generated/torch.round.html?highlight=round#torch.round)
* [x] Add tests
* ~~Should we try to make it composite?~~
* ~~Should we add specialized test with more values of `decimals` outside of OpInfo with larger range of values in input tensor?~~

cc mruberry rgommers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66195

Reviewed By: anjali411

Differential Revision: D31821385

Pulled By: mruberry

fbshipit-source-id: 9a03fcb809440f0c83530108284e69c345e1850f
(cherry picked from commit 50b67c696880b8dcfc42796956b4780b83bf7a7e)
This commit is contained in:
kshitij12345
2022-01-26 09:22:56 -08:00
committed by PyTorch MergeBot
parent 7e6312a5df
commit d3bbb281f3
10 changed files with 194 additions and 20 deletions

View File

@ -8450,26 +8450,53 @@ row_stack(tensors, *, out=None) -> Tensor
Alias of :func:`torch.vstack`.
""")
add_docstr(torch.round,
r"""
round(input, *, out=None) -> Tensor
add_docstr(torch.round, r"""
round(input, *, decimals=0, out=None) -> Tensor
Returns a new tensor with each of the elements of :attr:`input` rounded
to the closest integer.
Rounds elements of :attr:`input` to the nearest integer.
.. note::
This function implements the "round half to even" to
break ties when a number is equidistant from two
integers (e.g. `round(2.5)` is 2).
When the :attr:\`decimals\` argument is specified the
algorithm used is similar to NumPy's `around`. This
algorithm is fast but inexact and it can easily
overflow for low precision dtypes.
Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`.
.. seealso::
:func:`torch.ceil`, which rounds up.
:func:`torch.floor`, which rounds down.
:func:`torch.trunc`, which rounds towards zero.
Args:
{input}
decimals (int): Number of decimal places to round to (default: 0).
If decimals is negative, it specifies the number of positions
to the left of the decimal point.
Keyword args:
{out}
Example::
>>> a = torch.randn(4)
>>> a
tensor([ 0.9920, 0.6077, 0.9734, -1.0362])
>>> torch.round(a)
tensor([ 1., 1., 1., -1.])
>>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7)))
tensor([ 5., -2., 9., -8.])
>>> # Values equidistant from two integers are rounded towards the
>>> # the nearest even value (zero is treated as even)
>>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5]))
tensor([-0., 0., 2., 2.])
>>> # A positive decimals argument rounds to the to that decimal place
>>> torch.round(torch.tensor([0.1234567]), decimals=3)
tensor([0.1230])
>>> # A negative decimals argument rounds to the left of the decimal
>>> torch.round(torch.tensor([1200.1234567]), decimals=-3)
tensor([1000.])
""".format(**common_args))
add_docstr(torch.rsqrt,