Files
pytorch/test/test_native_functions.py
Aaron Gokaslan c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime errors, value errors, type errors, or some other more specific error. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00

243 lines
8.9 KiB
Python

# Owner(s): ["module: unknown"]
from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
# End-to-end tests of features in native_functions.yaml
class FloatListWrapperModule(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around ``torch._C._nn._test_optional_floatlist``.

    Wrapping the op in a module lets the tests in this file call it both
    eagerly and through ``torch.jit.script``.
    """

    def forward(self, values, incr: Optional[List[float]]):
        # ``incr`` is forwarded untouched; it may be ``None``.
        result = torch._C._nn._test_optional_floatlist(values, incr)
        return result
class IntListWrapperModule(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around ``torch._C._nn._test_optional_intlist``.

    Wrapping the op in a module lets the tests in this file call it both
    eagerly and through ``torch.jit.script``.
    """

    def forward(self, values, incr: Optional[List[int]]):
        # ``incr`` is forwarded untouched; it may be ``None``.
        result = torch._C._nn._test_optional_intlist(values, incr)
        return result
class TestNativeFunctions(TestCase):
    """End-to-end tests of features declared in native_functions.yaml.

    Covers two areas:
      * int-list style arguments must reject ``str`` elements with a
        ``TypeError`` that mentions "str";
      * the ``_test_optional_*`` ops must handle ``None`` and concrete list
        (or scalar) arguments in eager, scripted and traced form.
    """

    def _lists_with_str(self):
        """Return invalid size-like arguments: tuples/lists containing a str, and a bare str."""
        return [
            ("foo",),
            (2, "foo"),
            ("foo", 3),
            ["foo"],
            [2, "foo"],
            ["foo", 3],
            "foo",
        ]

    def _test_raises_str_typeerror(self, fn):
        """Assert that ``fn(arg)`` raises ``TypeError`` matching "str" for every bad argument.

        Args:
            fn: one-argument callable that forwards ``arg`` to the op under test.
        """
        for arg in self._lists_with_str():
            # ``fn`` is invoked exactly once per argument; ``msg`` identifies
            # the offending argument if the expected error is not raised.
            with self.assertRaisesRegex(TypeError, "str", msg=f"arg={arg!r}"):
                fn(arg)

    def test_symintlist_error(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    def test_vararg_symintlist_error(self):
        # Exercise both the single-sequence and the unpacked (vararg) call forms.
        self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
        self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))

    def test_symintlist_error_with_overload_but_is_unique(self):
        x = torch.randn(1)
        y = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))

    def test_symintlist_error_with_overload(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.view(arg))

    def test_intlist_error_with_overload(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    #
    # optional float list
    #

    def do_test_optional_floatlist_with_module(self, module):
        """Check ``module(values, None)`` aliases ``values`` and ``module(values, list)`` adds the list.

        Args:
            module: callable taking ``(values, incr)``; must accept ``None``
                and ``[5.1, 4.1]`` as ``incr``.
        """
        values = torch.tensor([1.5, 2.5], dtype=torch.float)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3.5
        self.assertEqual(values, returned)

        returned = module(values, [5.1, 4.1])
        # The input must be left untouched; only the result carries the sum.
        self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float))
        self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float))

    def trace_optional_floatlist(self, const):
        """Return a traced function calling ``_test_optional_floatlist`` with ``const`` baked in."""
        def wrapper(values):
            return torch._C._nn._test_optional_floatlist(values, const)

        return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_floatlist(self):
        self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
        self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))

        traced_none = self.trace_optional_floatlist(None)
        traced_list = self.trace_optional_floatlist([5.1, 4.1])

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and [5.1, 4.1].
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == [5.1, 4.1]:
                return traced_list(values)
            # Any other constant has no traced function to dispatch to.
            raise ValueError("Invalid argument")

        self.do_test_optional_floatlist_with_module(fake_module)

    def test_optional_floatlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
            FloatListWrapperModule()(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional int list
    #

    def do_test_optional_intlist_with_module(self, module):
        """Check ``module(values, None)`` aliases ``values`` and ``module(values, list)`` adds the list.

        Args:
            module: callable taking ``(values, incr)``; must accept ``None``
                and ``[5, 4]`` as ``incr``.
        """
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, [5, 4])
        # The input must be left untouched; only the result carries the sum.
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int))

    def trace_optional_intlist(self, const):
        """Return a traced function calling ``_test_optional_intlist`` with ``const`` baked in."""
        def wrapper(values):
            return torch._C._nn._test_optional_intlist(values, const)

        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_intlist(self):
        self.do_test_optional_intlist_with_module(IntListWrapperModule())
        self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))

        traced_none = self.trace_optional_intlist(None)
        traced_list = self.trace_optional_intlist([5, 4])

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and [5, 4].
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == [5, 4]:
                return traced_list(values)
            # Any other constant has no traced function to dispatch to.
            raise ValueError("Invalid argument")

        self.do_test_optional_intlist_with_module(fake_module)

    def test_optional_intlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be .* but found"):
            IntListWrapperModule()(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            IntListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional filled int list
    #

    def do_test_optional_filled_intlist_with_module(self, module):
        """Check ``module(values, None)`` aliases ``values`` and ``module(values, 10)`` adds 10 to each entry.

        Args:
            module: callable taking ``(values, incr)``; must accept ``None``
                and the scalar ``10`` as ``incr``.
        """
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, 10)
        # The input must be left untouched; only the result carries the sum.
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int))

    def trace_optional_filled_intlist(self, const):
        """Return a traced function calling ``_test_optional_filled_intlist`` with ``const`` baked in."""
        def wrapper(values):
            return torch._C._nn._test_optional_filled_intlist(values, const)

        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_filled_intlist(self):
        # Passing the scalar ``n`` must give the same result as passing ``(n, n)``.
        def f(n: int):
            x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
            y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
            return x, y

        # eager
        returned = f(10)
        self.assertEqual(returned[0], returned[1])

        # scripted
        s = torch.jit.script(f)
        returned = s(10)
        self.assertEqual(returned[0], returned[1])

        # traced
        traced_none = self.trace_optional_filled_intlist(None)
        traced_int = self.trace_optional_filled_intlist(10)

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and 10.
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == 10:
                return traced_int(values)
            # Any other constant has no traced function to dispatch to.
            raise ValueError("Invalid argument")

        self.do_test_optional_filled_intlist_with_module(fake_module)

    def test_string_defaults(self):
        dummy = torch.rand(1)
        fn = torch._C._nn._test_string_default
        # Calling with the declared defaults must succeed.
        fn(dummy)

        # Overriding either string argument with "" must raise an error naming it.
        with self.assertRaisesRegex(RuntimeError, "A"):
            fn(dummy, a="")

        with self.assertRaisesRegex(RuntimeError, "B"):
            fn(dummy, b="")

        def f(x):
            torch._C._nn._test_string_default(x)

        # The defaults must also work when the call is compiled by TorchScript.
        scripted_fn = torch.jit.script(f)
        scripted_fn(dummy)
# Script entry point: run this file's tests via the imported ``run_tests``.
if __name__ == '__main__':
    run_tests()