mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Apply the ruff rule about implicit string concatenation; this autofixes strings that are all the same type and on the same line. These lines were likely broken up by autoformatters in the past. All fixes are automated using the ISC001 autofixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408 Approved by: https://github.com/justinchuby, https://github.com/janeyx99
200 lines
5.7 KiB
Python
200 lines
5.7 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
import torch
|
|
from torch import nn
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
|
|
r"""
|
|
Test TorchScript exception handling.
|
|
"""
|
|
|
|
|
|
class TestException(TestCase):
|
|
def test_pyop_exception_message(self):
|
|
class Foo(torch.jit.ScriptModule):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
foo = Foo()
|
|
# testing that the correct error message propagates
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
|
|
):
|
|
foo(torch.ones([123])) # wrong size
|
|
|
|
def test_builtin_error_messsage(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
|
|
|
@torch.jit.script
|
|
def close_match(x):
|
|
return x.masked_fill(True)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"This op may not exist or may not be currently supported in TorchScript",
|
|
):
|
|
|
|
@torch.jit.script
|
|
def unknown_op(x):
|
|
torch.set_anomaly_enabled(True)
|
|
return x
|
|
|
|
def test_exceptions(self):
|
|
cu = torch.jit.CompilationUnit(
|
|
"""
|
|
def foo(cond):
|
|
if bool(cond):
|
|
raise ValueError(3)
|
|
return 1
|
|
"""
|
|
)
|
|
|
|
cu.foo(torch.tensor(0))
|
|
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
|
cu.foo(torch.tensor(1))
|
|
|
|
def foo(cond):
|
|
a = 3
|
|
if bool(cond):
|
|
raise ArbitraryError(a, "hi") # noqa: F821
|
|
if 1 == 2:
|
|
raise ArbitraryError # noqa: F821
|
|
return a
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
|
|
torch.jit.script(foo)
|
|
|
|
def exception_as_value():
|
|
a = Exception()
|
|
print(a)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
|
|
torch.jit.script(exception_as_value)
|
|
|
|
@torch.jit.script
|
|
def foo_no_decl_always_throws():
|
|
raise RuntimeError("Hi")
|
|
|
|
# function that has no declared type but always throws set to None
|
|
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
|
|
self.assertTrue(str(output_type) == "NoneType")
|
|
|
|
@torch.jit.script
|
|
def foo_decl_always_throws():
|
|
# type: () -> Tensor
|
|
raise Exception("Hi") # noqa: TRY002
|
|
|
|
output_type = next(foo_decl_always_throws.graph.outputs()).type()
|
|
self.assertTrue(str(output_type) == "Tensor")
|
|
|
|
def foo():
|
|
raise 3 + 4
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
|
|
torch.jit.script(foo)
|
|
|
|
# a escapes scope
|
|
@torch.jit.script
|
|
def foo():
|
|
if 1 == 1:
|
|
a = 1
|
|
else:
|
|
if 1 == 1:
|
|
raise Exception("Hi") # noqa: TRY002
|
|
else:
|
|
raise Exception("Hi") # noqa: TRY002
|
|
return a
|
|
|
|
self.assertEqual(foo(), 1)
|
|
|
|
@torch.jit.script
|
|
def tuple_fn():
|
|
raise RuntimeError("hello", "goodbye")
|
|
|
|
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
|
|
tuple_fn()
|
|
|
|
@torch.jit.script
|
|
def no_message():
|
|
raise RuntimeError
|
|
|
|
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
|
|
no_message()
|
|
|
|
def test_assertions(self):
|
|
cu = torch.jit.CompilationUnit(
|
|
"""
|
|
def foo(cond):
|
|
assert bool(cond), "hi"
|
|
return 0
|
|
"""
|
|
)
|
|
|
|
cu.foo(torch.tensor(1))
|
|
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
|
cu.foo(torch.tensor(0))
|
|
|
|
@torch.jit.script
|
|
def foo(cond):
|
|
assert bool(cond), "hi"
|
|
|
|
foo(torch.tensor(1))
|
|
# we don't currently validate the name of the exception
|
|
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
|
foo(torch.tensor(0))
|
|
|
|
def test_python_op_exception(self):
|
|
@torch.jit.ignore
|
|
def python_op(x):
|
|
raise Exception("bad!") # noqa: TRY002
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
return python_op(x)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "operation failed in the TorchScript interpreter"
|
|
):
|
|
fn(torch.tensor(4))
|
|
|
|
def test_dict_expansion_raises_error(self):
|
|
def fn(self):
|
|
d = {"foo": 1, "bar": 2, "baz": 3}
|
|
return {**d}
|
|
|
|
with self.assertRaisesRegex(
|
|
torch.jit.frontend.NotSupportedError, "Dict expansion "
|
|
):
|
|
torch.jit.script(fn)
|
|
|
|
def test_custom_python_exception(self):
|
|
class MyValueError(ValueError):
|
|
pass
|
|
|
|
@torch.jit.script
|
|
def fn():
|
|
raise MyValueError("test custom exception")
|
|
|
|
with self.assertRaisesRegex(
|
|
torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
|
|
):
|
|
fn()
|
|
|
|
def test_custom_python_exception_defined_elsewhere(self):
|
|
from jit.myexception import MyKeyError
|
|
|
|
@torch.jit.script
|
|
def fn():
|
|
raise MyKeyError("This is a user defined key error")
|
|
|
|
with self.assertRaisesRegex(
|
|
torch.jit.Error,
|
|
"jit.myexception.MyKeyError: This is a user defined key error",
|
|
):
|
|
fn()
|