mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In jit tests: (1) add and use a common `raise_on_run_directly` method for when a user directly runs a test file that should not be run this way; it prints the file the user should have run instead. (2) Raise a RuntimeError for tests that have been disabled (not run). Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725. Approved by: https://github.com/Skylion007
207 lines
5.8 KiB
Python
207 lines
5.8 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
import torch
|
|
from torch import nn
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
|
|
r"""
|
|
Test TorchScript exception handling.
|
|
"""
|
|
|
|
|
|
class TestException(TestCase):
    """TorchScript tests for raising, asserting, and propagating exceptions."""

    def test_pyop_exception_message(self):
        """Calling a scripted module with a badly-shaped input surfaces conv2d's own error message."""

        class Foo(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 10, kernel_size=5)

            @torch.jit.script_method
            def forward(self, x):
                return self.conv(x)

        foo = Foo()
        # testing that the correct error message propagates
        with self.assertRaisesRegex(
            RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
        ):
            foo(torch.ones([123]))  # wrong size

    def test_builtin_error_messsage(self):
        """Scripting invalid builtin/op calls fails at compile time with a descriptive message."""
        # The method name (including its misspelling) is left unchanged so the
        # test's ID stays stable.

        # masked_fill is given arguments that do not match a valid call.
        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):

            @torch.jit.script
            def close_match(x):
                return x.masked_fill(True)

        # torch.set_anomaly_enabled cannot be resolved by the TorchScript compiler.
        with self.assertRaisesRegex(
            RuntimeError,
            "This op may not exist or may not be currently supported in TorchScript",
        ):

            @torch.jit.script
            def unknown_op(x):
                torch.set_anomaly_enabled(True)
                return x

    def test_exceptions(self):
        """Exercise ``raise`` in TorchScript.

        Covers runtime errors from taken ``raise`` branches, compile-time
        rejection of invalid ``raise`` usage, and the output types inferred
        for functions that always throw.
        """
        # A raise in a CompilationUnit function only fires when its branch runs.
        cu = torch.jit.CompilationUnit(
            """
        def foo(cond):
            if bool(cond):
                raise ValueError(3)
            return 1
        """
        )

        cu.foo(torch.tensor(0))
        with self.assertRaisesRegex(torch.jit.Error, "3"):
            cu.foo(torch.tensor(1))

        # An undefined exception class is rejected when the function is scripted.
        def foo(cond):
            a = 3
            if bool(cond):
                raise ArbitraryError(a, "hi")  # noqa: F821
            if 1 == 2:
                raise ArbitraryError  # noqa: F821
            return a

        with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
            torch.jit.script(foo)

        # An exception object may not be used as an ordinary value.
        def exception_as_value():
            a = Exception()
            print(a)

        with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
            torch.jit.script(exception_as_value)

        @torch.jit.script
        def foo_no_decl_always_throws():
            raise RuntimeError("Hi")

        # function that has no declared type but always throws set to None
        output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
        # assertEqual (not assertTrue(a == b)) so a failure shows the actual type.
        self.assertEqual(str(output_type), "NoneType")

        @torch.jit.script
        def foo_decl_always_throws():
            # type: () -> Tensor
            raise Exception("Hi")  # noqa: TRY002

        # The declared return annotation is kept even though the body always throws.
        output_type = next(foo_decl_always_throws.graph.outputs()).type()
        self.assertEqual(str(output_type), "Tensor")

        # Only BaseException subclasses may be raised.
        def foo():
            raise 3 + 4

        with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
            torch.jit.script(foo)

        # `a` escapes the if/else scope because every branch of the else raises.
        @torch.jit.script
        def foo():
            if 1 == 1:
                a = 1
            else:
                if 1 == 1:
                    raise Exception("Hi")  # noqa: TRY002
                else:
                    raise Exception("Hi")  # noqa: TRY002
            return a

        self.assertEqual(foo(), 1)

        # Multiple exception arguments appear in the message separated by ", ".
        @torch.jit.script
        def tuple_fn():
            raise RuntimeError("hello", "goodbye")

        with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
            tuple_fn()

        # Raising a bare exception class still reports the class name.
        @torch.jit.script
        def no_message():
            raise RuntimeError

        with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
            no_message()

    def test_assertions(self):
        """A failing ``assert`` raises ``torch.jit.Error`` carrying ``AssertionError: <msg>``."""
        cu = torch.jit.CompilationUnit(
            """
        def foo(cond):
            assert bool(cond), "hi"
            return 0
        """
        )

        cu.foo(torch.tensor(1))
        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
            cu.foo(torch.tensor(0))

        @torch.jit.script
        def foo(cond):
            assert bool(cond), "hi"

        foo(torch.tensor(1))
        # we don't currently validate the name of the exception
        with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
            foo(torch.tensor(0))

    def test_python_op_exception(self):
        """An exception raised in an ``@torch.jit.ignore`` Python function surfaces as a RuntimeError."""

        @torch.jit.ignore
        def python_op(x):
            raise Exception("bad!")  # noqa: TRY002

        @torch.jit.script
        def fn(x):
            return python_op(x)

        with self.assertRaisesRegex(
            RuntimeError, "operation failed in the TorchScript interpreter"
        ):
            fn(torch.tensor(4))

    def test_dict_expansion_raises_error(self):
        """Dict unpacking (``{**d}``) is rejected by the TorchScript frontend."""

        def fn(self):
            d = {"foo": 1, "bar": 2, "baz": 3}
            return {**d}

        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "Dict expansion "
        ):
            torch.jit.script(fn)

    def test_custom_python_exception(self):
        """A locally defined exception class is reported by its qualified name and message."""

        class MyValueError(ValueError):
            pass

        @torch.jit.script
        def fn():
            raise MyValueError("test custom exception")

        # NOTE(review): the expected prefix assumes this module is imported as
        # jit.test_exception.
        with self.assertRaisesRegex(
            torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
        ):
            fn()

    def test_custom_python_exception_defined_elsewhere(self):
        """An exception class imported from another module is reported by its qualified name."""
        from jit.myexception import MyKeyError

        @torch.jit.script
        def fn():
            raise MyKeyError("This is a user defined key error")

        with self.assertRaisesRegex(
            torch.jit.Error,
            "jit.myexception.MyKeyError: This is a user defined key error",
        ):
            fn()
|
|
|
|
|
|
if __name__ == "__main__":
    # This suite is currently disabled, so running the file directly is treated
    # as a mistake; discover_tests.py is where it would be re-enabled.
    _disabled_reason = (
        "This test is not currently used and should be enabled in "
        "discover_tests.py if required."
    )
    raise RuntimeError(_disabled_reason)
|