Raise TypeError when the argument to isinf and isfinite is not a tensor (#20817)

Summary:
Currently when the argument to isinf and isfinite is not tensor, a ValueError is raised. This, however, should be a TypeError, because the error is a type mismatch.

In the error message, "str(tensor)" is replaced by "repr(tensor)" because, when an error occurs, a printable representation of the object is likely more useful than the "informal" string version of the object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20817

Differential Revision: D15495624

Pulled By: ezyang

fbshipit-source-id: 514198dcd723a7031818e50a87e187b22d51af73
This commit is contained in:
Hong Xu
2019-05-24 09:09:16 -07:00
committed by Facebook Github Bot
parent 87040af498
commit ef1fdc27a3

View File

@ -226,7 +226,7 @@ def isfinite(tensor):
tensor([ 1, 0, 1, 0, 0], dtype=torch.uint8)
"""
if not isinstance(tensor, torch.Tensor):
raise ValueError("The argument is not a tensor", str(tensor))
raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
# Support int input, nan and inf are concepts in floating point numbers.
# Numpy uses type 'Object' when the int overflows long, but we don't
@ -252,7 +252,7 @@ def isinf(tensor):
tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8)
"""
if not isinstance(tensor, torch.Tensor):
raise ValueError("The argument is not a tensor", str(tensor))
raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
if tensor.dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.zeros_like(tensor, dtype=torch.uint8)
return tensor.abs() == inf