mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve argument printing (#87601)
No more "expected tuple but got tuple". We appropriately grovel in the list/tuple for the element that mismatched and report what exactly twinged the failure. invalid_arguments.cpp is a shitshow so I did something slapdash to get it not completely horrible. See https://github.com/pytorch/pytorch/issues/87514 for more context. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/87601 Approved by: https://github.com/Chillee
This commit is contained in:
@ -19,6 +19,46 @@ class IntListWrapperModule(torch.nn.Module):
|
||||
|
||||
class TestNativeFunctions(TestCase):
|
||||
|
||||
def _lists_with_str(self):
|
||||
return [
|
||||
("foo",),
|
||||
(2, "foo"),
|
||||
("foo", 3),
|
||||
["foo"],
|
||||
[2, "foo"],
|
||||
["foo", 3],
|
||||
"foo",
|
||||
]
|
||||
|
||||
def _test_raises_str_typeerror(self, fn):
|
||||
for arg in self._lists_with_str():
|
||||
self.assertRaisesRegex(TypeError, "str", lambda: fn(arg))
|
||||
try:
|
||||
fn(arg)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
|
||||
def test_symintlist_error(self):
|
||||
x = torch.randn(1)
|
||||
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
|
||||
|
||||
def test_vararg_symintlist_error(self):
|
||||
self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
|
||||
self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))
|
||||
|
||||
def test_symintlist_error_with_overload_but_is_unique(self):
|
||||
x = torch.randn(1)
|
||||
y = torch.randn(1)
|
||||
self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))
|
||||
|
||||
def test_symintlist_error_with_overload(self):
|
||||
x = torch.randn(1)
|
||||
self._test_raises_str_typeerror(lambda arg: x.view(arg))
|
||||
|
||||
def test_intlist_error_with_overload(self):
|
||||
x = torch.randn(1)
|
||||
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
|
||||
|
||||
#
|
||||
# optional float list
|
||||
#
|
||||
@ -113,7 +153,7 @@ class TestNativeFunctions(TestCase):
|
||||
self.do_test_optional_intlist_with_module(fake_module)
|
||||
|
||||
def test_optional_intlist_invalid(self):
|
||||
with self.assertRaisesRegex(TypeError, "must be .* not"):
|
||||
with self.assertRaisesRegex(TypeError, "must be .* but found"):
|
||||
IntListWrapperModule()(torch.zeros(1), [0.5])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
|
||||
|
Reference in New Issue
Block a user