Improve argument printing (#87601)

No more "expected tuple but got tuple".  We appropriately
grovel in the list/tuple for the element that mismatched
and report what exactly twinged the failure.

invalid_arguments.cpp is a shitshow so I did something
slapdash to get it not completely horrible.  See
https://github.com/pytorch/pytorch/issues/87514 for more context.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87601
Approved by: https://github.com/Chillee
This commit is contained in:
albanD
2022-10-24 15:37:20 -04:00
committed by PyTorch MergeBot
parent 72ec1b5fc1
commit 3263bd24be
4 changed files with 130 additions and 21 deletions

View File

@ -19,6 +19,46 @@ class IntListWrapperModule(torch.nn.Module):
class TestNativeFunctions(TestCase):
def _lists_with_str(self):
return [
("foo",),
(2, "foo"),
("foo", 3),
["foo"],
[2, "foo"],
["foo", 3],
"foo",
]
def _test_raises_str_typeerror(self, fn):
for arg in self._lists_with_str():
self.assertRaisesRegex(TypeError, "str", lambda: fn(arg))
try:
fn(arg)
except TypeError as e:
print(e)
def test_symintlist_error(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
def test_vararg_symintlist_error(self):
self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))
def test_symintlist_error_with_overload_but_is_unique(self):
x = torch.randn(1)
y = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))
def test_symintlist_error_with_overload(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: x.view(arg))
def test_intlist_error_with_overload(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
#
# optional float list
#
@ -113,7 +153,7 @@ class TestNativeFunctions(TestCase):
self.do_test_optional_intlist_with_module(fake_module)
def test_optional_intlist_invalid(self):
with self.assertRaisesRegex(TypeError, "must be .* not"):
with self.assertRaisesRegex(TypeError, "must be .* but found"):
IntListWrapperModule()(torch.zeros(1), [0.5])
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):