Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[TorchScript, PT2] Add torch._check compatibility support (#159988)
Summary: Add support for torch._check() in the TorchScript jit.script frontend. It is special-cased to behave like torch._assert, i.e. turned into an if + raise exception.

Test Plan: Unit tests

Rollback Plan:

Differential Revision: D79744604

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159988
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent: 566c6d52ef
Commit: 731ee31f7b
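As a quick illustration of the behavior described in the summary, here is a minimal sketch (not part of the diff below; the function name and inputs are illustrative, and it assumes a PyTorch build that includes this change). A scripted torch._check is lowered to an if + raise, so the resulting graph contains prim::If and prim::RaiseException, as the tests in the diff verify.

    # Minimal sketch of the new behavior (illustrative name; assumes a build
    # of PyTorch that includes this change).
    import torch

    @torch.jit.script
    def guarded_sum(x):
        # Under jit.script, torch._check is special-cased like torch._assert:
        # it compiles to an if + raise in the TorchScript graph.
        torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable")
        return x

    print(guarded_sum(torch.tensor([1, 2, 3])))  # condition holds, returns the input
    print(guarded_sum.graph)  # graph contains prim::If and prim::RaiseException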
@@ -131,6 +131,164 @@ class TestBuiltins(JitTestCase):
        jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
        self.assertEqual(py_out, jit_out)

    def test_torch_check(self):
        """Test torch._check functionality with flexible argument handling"""

        def test_check_basic(x):
            torch._check(x.sum().item() > -1000)
            return x

        def test_check_with_message(x):
            torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable")
            return x

        def test_check_with_kwarg_message(x):
            torch._check(
                x.sum().item() > -1000, message="Tensor sum must be reasonable"
            )
            return x

        def test_check_cond_kwarg(x):
            torch._check(cond=x.sum().item() > -1000)
            return x

        def test_check_both_kwargs(x):
            torch._check(cond=x.sum().item() > -1000, message="Both as kwargs")
            return x

        def test_check_kwargs_reversed(x):
            torch._check(message="Reversed order", cond=x.sum().item() > -1000)
            return x

        def test_check_in_loop(x):
            sizes = torch.jit.annotate(List[int], x.tolist())
            for s in sizes:
                torch._check(s > -100)
            return x

        test_tensor = torch.tensor([1, 2, 3])

        # Test all variations
        self.checkScript(test_check_basic, (test_tensor,))
        self.checkScript(test_check_with_message, (test_tensor,))
        self.checkScript(test_check_with_kwarg_message, (test_tensor,))
        self.checkScript(test_check_cond_kwarg, (test_tensor,))
        self.checkScript(test_check_both_kwargs, (test_tensor,))
        self.checkScript(test_check_kwargs_reversed, (test_tensor,))
        self.checkScript(test_check_in_loop, (test_tensor,))

        # Test that the compiled functions work correctly
        scripted_basic = torch.jit.script(test_check_basic)
        scripted_with_message = torch.jit.script(test_check_with_message)
        scripted_with_kwarg = torch.jit.script(test_check_with_kwarg_message)
        scripted_cond_kwarg = torch.jit.script(test_check_cond_kwarg)
        scripted_both_kwargs = torch.jit.script(test_check_both_kwargs)
        scripted_kwargs_reversed = torch.jit.script(test_check_kwargs_reversed)
        scripted_in_loop = torch.jit.script(test_check_in_loop)

        # These should all succeed without throwing
        result1 = scripted_basic(test_tensor)
        result2 = scripted_with_message(test_tensor)
        result3 = scripted_with_kwarg(test_tensor)
        result4 = scripted_cond_kwarg(test_tensor)
        result5 = scripted_both_kwargs(test_tensor)
        result6 = scripted_kwargs_reversed(test_tensor)
        result7 = scripted_in_loop(test_tensor)

        # Results should be the same as input
        for result in [result1, result2, result3, result4, result5, result6, result7]:
            self.assertEqual(result, test_tensor)

        # Check that the message constants are present in the graphs
        FileCheck().check("Tensor sum must be reasonable").run(
            scripted_with_message.graph
        )
        FileCheck().check("Tensor sum must be reasonable").run(
            scripted_with_kwarg.graph
        )
        FileCheck().check("Both as kwargs").run(scripted_both_kwargs.graph)
        FileCheck().check("Reversed order").run(scripted_kwargs_reversed.graph)

        # Verify the graphs contain some computation (not just empty)
        basic_graph_str = str(scripted_basic.graph)
        self.assertTrue(
            len(basic_graph_str) > 100, "Basic graph should contain some computation"
        )

        # Verify the loop case contains a loop
        FileCheck().check("prim::Loop").run(scripted_in_loop.graph)

        for scripted_func in [
            scripted_basic,
            scripted_with_message,
            scripted_with_kwarg,
            scripted_cond_kwarg,
            scripted_both_kwargs,
            scripted_kwargs_reversed,
        ]:
            FileCheck().check("prim::If").check("prim::RaiseException").run(
                scripted_func.graph
            )

    def test_torch_check_invalid_args(self):
        """Test torch._check with invalid arguments"""

        # Test too many arguments
        with self.assertRaisesRegex(
            RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
        ):

            @torch.jit.script
            def too_many_args(x):
                torch._check(True, "msg", "extra")
                return x

        # Test invalid keyword argument
        with self.assertRaisesRegex(RuntimeError, "unexpected keyword argument"):

            @torch.jit.script
            def invalid_kwarg(x):
                torch._check(True, invalid_arg="msg")
                return x

        # Test duplicate cond argument (positional + keyword)
        with self.assertRaisesRegex(
            RuntimeError, "multiple values for argument 'cond'"
        ):

            @torch.jit.script
            def duplicate_cond(x):
                torch._check(True, cond=False)
                return x

        # Test missing required cond argument
        with self.assertRaisesRegex(RuntimeError, "missing required argument 'cond'"):

            @torch.jit.script
            def missing_cond(x):
                torch._check(message="msg only")
                return x

        # Test no arguments at all
        with self.assertRaisesRegex(
            RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
        ):

            @torch.jit.script
            def no_args(x):
                torch._check()
                return x

        # Test too many total arguments (positional + keyword)
        with self.assertRaisesRegex(
            RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
        ):

            @torch.jit.script
            def too_many_total_args(x):
                torch._check(True, "msg", cond=False)
                return x


class TestTensorBuiltins(JitTestCase):
    def test_tensor_properties(self):