# Owner(s): ["oncall: jit"]

"""TorchScript tests for Python builtins (hasattr, del, torch._check) and Tensor builtins."""

import inspect
import os
import sys
import unittest
from typing import Dict, List

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


class TestBuiltins(JitTestCase):
    """
    Tests for TorchScript support of Python builtin functions.
    """

    def test_has_attr(self):
        """Check hasattr() on ModuleList elements and attribute access refined by it."""

        class HasA(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

        class HasB(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 1

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = torch.nn.ModuleList([HasA(), HasB()])

            def forward(self):
                # use a list to encode hasattr results
                results = torch.jit.annotate(List[int], [])
                for mod in self.mods:
                    results.append(int(hasattr(mod, "a")))
                    results.append(int(hasattr(mod, "b")))
                    # actually retrieve the attr to test static refinement
                    if hasattr(mod, "a"):
                        results.append(mod.a)
                    if hasattr(mod, "b"):
                        results.append(mod.b)
                return results

        self.checkModule(Mod(), ())

    def test_has_attr_invalid_args(self):
        """Check that hasattr() fails to compile for a non-static name or a non-class receiver."""

        class ModWithDynamicName(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = torch.nn.Linear(1, 1)

            def forward(self, name):
                # not allowed, `name` must be static.
                return hasattr(self.mod, name)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
            torch.jit.script(ModWithDynamicName())

        class ModWithTensorReceiver(torch.nn.Module):
            def forward(self, name):
                # not allowed, `torch.rand` is not a class type
                return hasattr(torch.rand(2, 3), name)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
            torch.jit.script(ModWithTensorReceiver())

    def test_del(self):
        """Check `del` of locals and that a deleted (or never-defined) name is an undefined value."""

        def fn(x: List[int]) -> List[int]:
            a = x * 2
            del a
            return x

        self.checkScript(fn, ([1, 2, 3],))

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):

            @torch.jit.script
            def fn(x):
                a = x**2
                del a
                return a  # noqa: F821

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):

            @torch.jit.script
            def fn(x):
                a = x**2
                if a:
                    del a
                return a

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):

            @torch.jit.script
            def fn(x):
                a = x**2
                del b  # noqa: F821
                return a

    def test_del_multiple_operands(self):
        """Check `del` with several targets on locals, list items and dict items against eager Python."""

        def fn(x: List[int]) -> List[int]:
            a, b, c = x[0], x[1], x[2]
            del a, b, c
            return x

        self.checkScript(fn, ([1, 2, 3],))

        def del_list_multiple_operands(x: List[int]) -> List[int]:
            del x[0], x[1]
            return x

        py_out = del_list_multiple_operands([0, 1, 2])
        jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
        self.assertEqual(py_out, jit_out)

        def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
            del x["hi"], x["there"]
            return x

        py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
        jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
        self.assertEqual(py_out, jit_out)

    def test_torch_check(self):
        """Test torch._check functionality with flexible argument handling"""

        def test_check_basic(x):
            torch._check(x.sum().item() > -1000)
            return x

        def test_check_with_message(x):
            torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable")
            return x

        def test_check_with_kwarg_message(x):
            torch._check(
                x.sum().item() > -1000, message="Tensor sum must be reasonable"
            )
            return x

        def test_check_cond_kwarg(x):
            torch._check(cond=x.sum().item() > -1000)
            return x

        def test_check_both_kwargs(x):
            torch._check(cond=x.sum().item() > -1000, message="Both as kwargs")
            return x

        def test_check_kwargs_reversed(x):
            torch._check(message="Reversed order", cond=x.sum().item() > -1000)
            return x

        def test_check_in_loop(x):
            sizes = torch.jit.annotate(List[int], x.tolist())
            for s in sizes:
                torch._check(s > -100)
            return x

        test_tensor = torch.tensor([1, 2, 3])

        check_fns = (
            test_check_basic,
            test_check_with_message,
            test_check_with_kwarg_message,
            test_check_cond_kwarg,
            test_check_both_kwargs,
            test_check_kwargs_reversed,
            test_check_in_loop,
        )

        # Test all variations (scripted output must match eager output)
        for check_fn in check_fns:
            self.checkScript(check_fn, (test_tensor,))

        # Test that the compiled functions work correctly
        scripted_basic = torch.jit.script(test_check_basic)
        scripted_with_message = torch.jit.script(test_check_with_message)
        scripted_with_kwarg = torch.jit.script(test_check_with_kwarg_message)
        scripted_cond_kwarg = torch.jit.script(test_check_cond_kwarg)
        scripted_both_kwargs = torch.jit.script(test_check_both_kwargs)
        scripted_kwargs_reversed = torch.jit.script(test_check_kwargs_reversed)
        scripted_in_loop = torch.jit.script(test_check_in_loop)

        all_scripted = (
            scripted_basic,
            scripted_with_message,
            scripted_with_kwarg,
            scripted_cond_kwarg,
            scripted_both_kwargs,
            scripted_kwargs_reversed,
            scripted_in_loop,
        )

        # Every check passes for this input, so each function must return its
        # input unchanged without throwing.
        for scripted_fn in all_scripted:
            self.assertEqual(scripted_fn(test_tensor), test_tensor)

        # Check that the message constants are present in the graphs
        FileCheck().check("Tensor sum must be reasonable").run(
            scripted_with_message.graph
        )
        FileCheck().check("Tensor sum must be reasonable").run(
            scripted_with_kwarg.graph
        )
        FileCheck().check("Both as kwargs").run(scripted_both_kwargs.graph)
        FileCheck().check("Reversed order").run(scripted_kwargs_reversed.graph)

        # Verify the graphs contain some computation (not just empty)
        basic_graph_str = str(scripted_basic.graph)
        self.assertGreater(
            len(basic_graph_str), 100, "Basic graph should contain some computation"
        )

        # Verify the loop case contains a loop
        FileCheck().check("prim::Loop").run(scripted_in_loop.graph)

        # Each non-loop variant must contain a conditional that can raise.
        for scripted_func in [
            scripted_basic,
            scripted_with_message,
            scripted_with_kwarg,
            scripted_cond_kwarg,
            scripted_both_kwargs,
            scripted_kwargs_reversed,
        ]:
            FileCheck().check("prim::If").check("prim::RaiseException").run(
                scripted_func.graph
            )

    def test_torch_check_invalid_args(self):
        """Test torch._check with invalid arguments"""
        # Regex for the arity error; the unescaped "." is kept as in the
        # original pattern.
        arity_msg = r"torch._check\(\) expects 1 or 2 arguments"

        # Test too many arguments
        with self.assertRaisesRegex(RuntimeError, arity_msg):

            @torch.jit.script
            def too_many_args(x):
                torch._check(True, "msg", "extra")
                return x

        # Test invalid keyword argument
        with self.assertRaisesRegex(RuntimeError, "unexpected keyword argument"):

            @torch.jit.script
            def invalid_kwarg(x):
                torch._check(True, invalid_arg="msg")
                return x

        # Test duplicate cond argument (positional + keyword)
        with self.assertRaisesRegex(
            RuntimeError, "multiple values for argument 'cond'"
        ):

            @torch.jit.script
            def duplicate_cond(x):
                torch._check(True, cond=False)
                return x

        # Test missing required cond argument
        with self.assertRaisesRegex(RuntimeError, "missing required argument 'cond'"):

            @torch.jit.script
            def missing_cond(x):
                torch._check(message="msg only")
                return x

        # Test no arguments at all
        with self.assertRaisesRegex(RuntimeError, arity_msg):

            @torch.jit.script
            def no_args(x):
                torch._check()
                return x

        # Test too many total arguments (positional + keyword)
        with self.assertRaisesRegex(RuntimeError, arity_msg):

            @torch.jit.script
            def too_many_total_args(x):
                torch._check(True, "msg", cond=False)
                return x


class TestTensorBuiltins(JitTestCase):
    """Tests for TorchScript support of Tensor attributes, methods and indexing."""

    def test_tensor_properties(self):
        """Check that public non-callable Tensor attributes compile in TorchScript and match eager values."""

        def should_keep(tensor, name):
            if inspect.isroutine(getattr(tensor, name)):
                return False
            if name.startswith("_"):
                return False
            return True

        tensor = torch.arange(4, dtype=torch.float).view(2, 2)
        keys = dir(tensor)

        # imag is only implemented for complex tensors, so reading it from
        # this float tensor raises.
        self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag"))
        keys.remove("imag")

        properties = [p for p in keys if should_keep(tensor, p)]

        code_template = """
        def fn(x):
            return x.{}
        """

        EQUALITY_MISMATCH = {
            # TorchScript doesn't have real enums so they return an int instead
            # of the actual value
            "dtype",
            "layout",
        }
        MISSING_PROPERTIES = {
            "grad_fn",
            # This is an undocumented property so it's not included
            "output_nr",
            # This has a longer implementation, maybe not worth copying to
            # TorchScript if named tensors don't work there anyways
            "names",
            # We don't plan to support grad_dtype in TorchScript
            "grad_dtype",
        }

        for p in properties:
            if p in MISSING_PROPERTIES:
                continue
            code = code_template.format(p)
            cu = torch.jit.CompilationUnit()
            # Compiling is checked for every property; values are compared
            # only when the scripted type matches the eager one.
            cu.define(code)
            if p in EQUALITY_MISMATCH:
                continue
            self.assertEqual(getattr(tensor, p), cu.fn(tensor))

    def test_tensor_subscript_assign(self):
        """Check indexed assignment with tensor/int indices and tensor/scalar values."""

        def fn1(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[torch.tensor(0)] = torch.tensor(2, dtype=torch.uint8)
            return a

        def fn2(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[0] = 2
            return a

        def fn3(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[torch.tensor(0)] = 2
            return a

        def fn4(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[0] = torch.tensor(2, dtype=torch.uint8)
            return a

        def fn5(x):
            a = torch.zeros_like(x, dtype=torch.float32)
            a[torch.tensor(0)] = 2
            return a

        for fn in (fn1, fn2, fn3, fn4, fn5):
            self.checkScript(fn, (torch.zeros(2, dtype=torch.uint8),))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_tensor_subscript_assign_device(self):
        """Check indexed assignment into a tensor created with an explicit CUDA device."""

        def fn6(x):
            a = torch.zeros_like(x, dtype=torch.float32, device="cuda")
            a[torch.tensor(0)] = 2
            return a

        self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),))

    def test_tensor_item(self):
        """Check that item() results cast to int and float, typed as (int, float) in the graph."""

        def test_scalar_cast(x):
            scalar = x.item()
            return int(scalar), float(scalar)

        graph = torch.jit.script(test_scalar_cast).graph
        FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
        self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
        self.checkScript(test_scalar_cast, (torch.tensor(1),))

    def test_method_on_number(self):
        """Check that calling a method on an int literal fails to compile."""

        def func():
            c = 1
            return c.add(1)

        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
            torch.jit.script(func)

    def test_scalar_to_num_conversions(self):
        """Test implicit conversion of tensors to scalars to match function arguments."""

        @torch.jit.script
        def multiple_defs(x):
            c = 1
            x = x + c
            return x

        self.assertNotIn("ImplicitTensorToNum", str(multiple_defs.graph))

        @torch.jit.script
        def tensor_to_int_script(x, tensor):
            return x.unsqueeze(tensor)

        # location present in error message
        with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
            tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))

        def tensor_to_int(x, tensor):
            return x.unsqueeze(tensor)

        @torch.jit.script
        def tensor_to_float_script(x, tensor):
            return x.addcmul(tensor, tensor, value=tensor)

        def tensor_to_float(x, tensor):
            return x.addcmul(tensor, tensor, value=tensor)

        x = torch.zeros(10)
        # float scalar tensor, float scalar tensor with grad, int scalar
        # tensor (can't set grad on int tensor), and a 1-element int tensor
        tensors = [
            torch.tensor(1.1),
            torch.tensor(1.1, requires_grad=True),
            torch.tensor(0),
            torch.tensor([2]),
        ]

        script_funs = [tensor_to_int_script, tensor_to_float_script]
        funs = [tensor_to_int, tensor_to_float]

        # return the result, or True if an exception was thrown, so scripted
        # and eager outcomes can be compared directly
        def run_or_flag_error(func, x, tensor):
            try:
                result = func(x, tensor)
            except (RuntimeError, TypeError):
                result = True
            return result

        # assert result or exception equal for each (function, inputs)
        for tensor in tensors:
            for script_fun, eager_fun in zip(script_funs, funs):
                self.assertEqual(
                    run_or_flag_error(script_fun, x, tensor),
                    run_or_flag_error(eager_fun, x, tensor),
                )


if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")