[TorchScript, PT2] Add torch._check compatibility support (#159988)

Summary:
Add support for torch._check() in TorchScript jit.script frontend.

* It will be special cased to behave like torch._assert, turned into an if + raise exception.

Test Plan:
Unit tests

Rollback Plan:

Differential Revision: D79744604

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159988
Approved by: https://github.com/davidberard98
This commit is contained in:
Yanan Cao (PyTorch)
2025-08-08 23:14:13 +00:00
committed by PyTorch MergeBot
parent 566c6d52ef
commit 731ee31f7b
5 changed files with 260 additions and 7 deletions

View File

@ -131,6 +131,164 @@ class TestBuiltins(JitTestCase):
jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
self.assertEqual(py_out, jit_out)
def test_torch_check(self):
"""Test torch._check functionality with flexible argument handling"""
def test_check_basic(x):
torch._check(x.sum().item() > -1000)
return x
def test_check_with_message(x):
torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable")
return x
def test_check_with_kwarg_message(x):
torch._check(
x.sum().item() > -1000, message="Tensor sum must be reasonable"
)
return x
def test_check_cond_kwarg(x):
torch._check(cond=x.sum().item() > -1000)
return x
def test_check_both_kwargs(x):
torch._check(cond=x.sum().item() > -1000, message="Both as kwargs")
return x
def test_check_kwargs_reversed(x):
torch._check(message="Reversed order", cond=x.sum().item() > -1000)
return x
def test_check_in_loop(x):
sizes = torch.jit.annotate(List[int], x.tolist())
for s in sizes:
torch._check(s > -100)
return x
test_tensor = torch.tensor([1, 2, 3])
# Test all variations
self.checkScript(test_check_basic, (test_tensor,))
self.checkScript(test_check_with_message, (test_tensor,))
self.checkScript(test_check_with_kwarg_message, (test_tensor,))
self.checkScript(test_check_cond_kwarg, (test_tensor,))
self.checkScript(test_check_both_kwargs, (test_tensor,))
self.checkScript(test_check_kwargs_reversed, (test_tensor,))
self.checkScript(test_check_in_loop, (test_tensor,))
# Test that the compiled functions work correctly
scripted_basic = torch.jit.script(test_check_basic)
scripted_with_message = torch.jit.script(test_check_with_message)
scripted_with_kwarg = torch.jit.script(test_check_with_kwarg_message)
scripted_cond_kwarg = torch.jit.script(test_check_cond_kwarg)
scripted_both_kwargs = torch.jit.script(test_check_both_kwargs)
scripted_kwargs_reversed = torch.jit.script(test_check_kwargs_reversed)
scripted_in_loop = torch.jit.script(test_check_in_loop)
# These should all succeed without throwing
result1 = scripted_basic(test_tensor)
result2 = scripted_with_message(test_tensor)
result3 = scripted_with_kwarg(test_tensor)
result4 = scripted_cond_kwarg(test_tensor)
result5 = scripted_both_kwargs(test_tensor)
result6 = scripted_kwargs_reversed(test_tensor)
result7 = scripted_in_loop(test_tensor)
# Results should be the same as input
for result in [result1, result2, result3, result4, result5, result6, result7]:
self.assertEqual(result, test_tensor)
# Check that the message constants are present in the graphs
FileCheck().check("Tensor sum must be reasonable").run(
scripted_with_message.graph
)
FileCheck().check("Tensor sum must be reasonable").run(
scripted_with_kwarg.graph
)
FileCheck().check("Both as kwargs").run(scripted_both_kwargs.graph)
FileCheck().check("Reversed order").run(scripted_kwargs_reversed.graph)
# Verify the graphs contain some computation (not just empty)
basic_graph_str = str(scripted_basic.graph)
self.assertTrue(
len(basic_graph_str) > 100, "Basic graph should contain some computation"
)
# Verify the loop case contains a loop
FileCheck().check("prim::Loop").run(scripted_in_loop.graph)
for scripted_func in [
scripted_basic,
scripted_with_message,
scripted_with_kwarg,
scripted_cond_kwarg,
scripted_both_kwargs,
scripted_kwargs_reversed,
]:
FileCheck().check("prim::If").check("prim::RaiseException").run(
scripted_func.graph
)
def test_torch_check_invalid_args(self):
"""Test torch._check with invalid arguments"""
# Test too many arguments
with self.assertRaisesRegex(
RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
):
@torch.jit.script
def too_many_args(x):
torch._check(True, "msg", "extra")
return x
# Test invalid keyword argument
with self.assertRaisesRegex(RuntimeError, "unexpected keyword argument"):
@torch.jit.script
def invalid_kwarg(x):
torch._check(True, invalid_arg="msg")
return x
# Test duplicate cond argument (positional + keyword)
with self.assertRaisesRegex(
RuntimeError, "multiple values for argument 'cond'"
):
@torch.jit.script
def duplicate_cond(x):
torch._check(True, cond=False)
return x
# Test missing required cond argument
with self.assertRaisesRegex(RuntimeError, "missing required argument 'cond'"):
@torch.jit.script
def missing_cond(x):
torch._check(message="msg only")
return x
# Test no arguments at all
with self.assertRaisesRegex(
RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
):
@torch.jit.script
def no_args(x):
torch._check()
return x
# Test too many total arguments (positional + keyword)
with self.assertRaisesRegex(
RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments"
):
@torch.jit.script
def too_many_total_args(x):
torch._check(True, "msg", cond=False)
return x
class TestTensorBuiltins(JitTestCase):
def test_tensor_properties(self):