mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Raise TypeError when the argument to isinf and isfinite is not a tensor (#20817)
Summary: Currently when the argument to isinf and isfinite is not tensor, a ValueError is raised. This, however, should be a TypeError, because the error is a type mismatch. In the error message, "str(tensor)" is replaced by "repr(tensor)" because, when an error occurs, a printable representation of the object is likely more useful than the "informal" string version of the object. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20817 Differential Revision: D15495624 Pulled By: ezyang fbshipit-source-id: 514198dcd723a7031818e50a87e187b22d51af73
This commit is contained in:
committed by
Facebook Github Bot
parent
87040af498
commit
ef1fdc27a3
@ -226,7 +226,7 @@ def isfinite(tensor):
|
||||
tensor([ 1, 0, 1, 0, 0], dtype=torch.uint8)
|
||||
"""
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise ValueError("The argument is not a tensor", str(tensor))
|
||||
raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
|
||||
|
||||
# Support int input, nan and inf are concepts in floating point numbers.
|
||||
# Numpy uses type 'Object' when the int overflows long, but we don't
|
||||
@ -252,7 +252,7 @@ def isinf(tensor):
|
||||
tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8)
|
||||
"""
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise ValueError("The argument is not a tensor", str(tensor))
|
||||
raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
|
||||
if tensor.dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
return torch.zeros_like(tensor, dtype=torch.uint8)
|
||||
return tensor.abs() == inf
|
||||
|
Reference in New Issue
Block a user