mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	[numpy] add decimals argument to round (#66195)
Summary: Fixes https://github.com/pytorch/pytorch/issues/65908 Added a new overload instead of updating the current signature. (Had issues with JIT and **maybe** it would have been FC breaking) TODO: * [x] Don't compute `std::pow(10, decimals)` for each element. * [x] Update docs (https://docs-preview.pytorch.org/66195/generated/torch.round.html?highlight=round#torch.round) * [x] Add tests * ~~Should we try to make it composite?~~ * ~~Should we add specialized test with more values of `decimals` outside of OpInfo with larger range of values in input tensor?~~ cc mruberry rgommers Pull Request resolved: https://github.com/pytorch/pytorch/pull/66195 Reviewed By: anjali411 Differential Revision: D31821385 Pulled By: mruberry fbshipit-source-id: 9a03fcb809440f0c83530108284e69c345e1850f (cherry picked from commit 50b67c696880b8dcfc42796956b4780b83bf7a7e)
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							7e6312a5df
						
					
				
				
					commit
					d3bbb281f3
				
			| @ -8450,26 +8450,53 @@ row_stack(tensors, *, out=None) -> Tensor | ||||
| Alias of :func:`torch.vstack`. | ||||
| """) | ||||
|  | ||||
| add_docstr(torch.round, | ||||
|            r""" | ||||
| round(input, *, out=None) -> Tensor | ||||
| add_docstr(torch.round, r""" | ||||
| round(input, *, decimals=0, out=None) -> Tensor | ||||
|  | ||||
| Returns a new tensor with each of the elements of :attr:`input` rounded | ||||
| to the closest integer. | ||||
| Rounds elements of :attr:`input` to the nearest integer. | ||||
|  | ||||
| .. note:: | ||||
|     This function implements the "round half to even" to | ||||
|     break ties when a number is equidistant from two | ||||
|     integers (e.g. `round(2.5)` is 2). | ||||
|  | ||||
|     When the :attr:\`decimals\` argument is specified the | ||||
|     algorithm used is similar to NumPy's `around`. This | ||||
|     algorithm is fast but inexact and it can easily | ||||
|     overflow for low precision dtypes. | ||||
|     Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. | ||||
|  | ||||
| .. seealso:: | ||||
|     :func:`torch.ceil`, which rounds up. | ||||
|     :func:`torch.floor`, which rounds down. | ||||
|     :func:`torch.trunc`, which rounds towards zero. | ||||
|  | ||||
| Args: | ||||
|     {input} | ||||
|     decimals (int): Number of decimal places to round to (default: 0). | ||||
|         If decimals is negative, it specifies the number of positions | ||||
|         to the left of the decimal point. | ||||
|  | ||||
| Keyword args: | ||||
|     {out} | ||||
|  | ||||
| Example:: | ||||
|  | ||||
|     >>> a = torch.randn(4) | ||||
|     >>> a | ||||
|     tensor([ 0.9920,  0.6077,  0.9734, -1.0362]) | ||||
|     >>> torch.round(a) | ||||
|     tensor([ 1.,  1.,  1., -1.]) | ||||
|     >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) | ||||
|     tensor([ 5.,  -2.,  9., -8.]) | ||||
|  | ||||
|     >>> # Values equidistant from two integers are rounded towards the | ||||
|     >>> #   the nearest even value (zero is treated as even) | ||||
|     >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) | ||||
|     tensor([-0., 0., 2., 2.]) | ||||
|  | ||||
|     >>> # A positive decimals argument rounds to the to that decimal place | ||||
|     >>> torch.round(torch.tensor([0.1234567]), decimals=3) | ||||
|     tensor([0.1230]) | ||||
|  | ||||
|     >>> # A negative decimals argument rounds to the left of the decimal | ||||
|     >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) | ||||
|     tensor([1000.]) | ||||
| """.format(**common_args)) | ||||
|  | ||||
| add_docstr(torch.rsqrt, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user