mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In jit tests:
	- Add and use a common `raise_on_run_directly` helper for when a user directly runs a test file that should not be run that way; it prints the file the user should have run instead.
	- Raise a RuntimeError in tests that have been disabled (not run).

	Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
	Approved by: https://github.com/clee2000
		
			
				
	
	
		
			504 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			504 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["oncall: jit"]
 | |
| # ruff: noqa: F841
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| import unittest
 | |
| from typing import Any, Dict, List, Optional, Tuple
 | |
| 
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.testing._internal.jit_utils
 | |
| from jit.test_module_interface import TestModuleInterface  # noqa: F401
 | |
| from torch import jit
 | |
| from torch.testing import FileCheck
 | |
| from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
 | |
| from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
 | |
| 
 | |
| 
 | |
# Put the test/ directory (the parent of the directory holding this file)
# on sys.path so the helper modules that live there can be imported.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
 | |
| 
 | |
| 
 | |
class TestMisc(JitTestCase):
    """Miscellaneous TorchScript regression tests.

    Covers f-strings, keyword-only arguments, ``Future``/``Any``/``Optional``
    type annotations, ``hacked_twin`` operator overloads, op-name export
    through module interfaces, legacy tensor constructors, IR parsing, and
    operator overload ordering.
    """

    @staticmethod
    def _gen_index_put_data():
        """Return ``(src, index, value)`` tensors for the index_put tests.

        ``src`` has 10 elements and ``index`` holds 20 values drawn from
        ``[0, 10)``, so every index is in range for ``src``. Generation runs
        under ``freeze_rng_state``. The tests below rely on consecutive calls
        returning equal tensors, so that two independent copies of the same
        data can be compared.
        """
        with freeze_rng_state():
            return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

    def test_joined_str(self):
        """f-strings (``ast.JoinedStr``) print the same text eagerly and scripted."""

        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = "hi"
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4.0, requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            # Run the *scripted* function; calling `func` here would only
            # compare eager output with itself.
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        """Keyword-only arguments: defaults are rejected; values must be passed by name."""
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
        ):

            class M(torch.nn.Module):
                # The only difference from the class below is the default on
                # a keyword-only argument, which the frontend rejects.
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass

            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(
            RuntimeError, "missing value for argument 'n_tokens'"
        ):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, "hello")

        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))

    def test_tuple_subscripted_assign(self):
        # Tuples are immutable, so both plain and augmented subscripted
        # assignment must be rejected when the function is scripted.
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):

            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):

            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]

    def test_subexpression_List_Future(self):
        # Future[int] nested inside a List annotation shows up in the graph.
        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)

    def test_subexpression_Future_annotate(self):
        # A variable annotation with a nested Future type is honored.
        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        # isinstance() against a subscripted Future, referenced through the
        # module-level `jit` alias, refines an Any argument.
        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        # isinstance(x, str) refines an Any argument to str.
        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        # Future[int] inside tuple annotations for both argument and return.
        @torch.jit.script
        def fn(
            x: Tuple[int, int, torch.jit.Future[int]],
        ) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
            fn.graph
        )

    def test_subexpression_Dict_int_Future(self):
        # Future[int] as a Dict value type.
        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)

    def test_subexpression_Optional(self):
        # Future[int] inside an Optional[Dict[...]] annotation.
        @torch.jit.script
        def fn(
            x: Optional[Dict[int, torch.jit.Future[int]]],
        ) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different
        types early from each branch when the return
        type of the function is Any.
        """

        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))

    def test_hacked_twin(self):
        """``index_put(_).hacked_twin`` matches ``torch.index_put(_)``."""
        src, index, value = self._gen_index_put_data()
        src1, index1, value1 = self._gen_index_put_data()

        out1 = torch.ops.aten.index_put.hacked_twin(
            src, [index], value, accumulate=False
        )
        out2 = torch.index_put(src1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # The in-place variants must mutate both copies identically.
        torch.ops.aten.index_put_.hacked_twin(src, [index], value, accumulate=False)
        torch.index_put_(src1, [index1], value1, accumulate=False)
        self.assertEqual(src, src1)

    def test_unsafe_hacked_twin(self):
        """``_unsafe_index_put`` / ``_unsafe_index`` agree with their safe counterparts, eager and scripted."""
        src, index, value = self._gen_index_put_data()
        src1, index1, value1 = self._gen_index_put_data()

        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
            src, [index], value, accumulate=False
        )
        out2 = torch.index_put(src1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # Both calls above are out-of-place, so src/src1 still hold the
        # original (equal) data. Every index is in range, so _unsafe_index
        # must agree with ordinary advanced indexing.
        out3 = torch.ops.aten._unsafe_index.Tensor_hacked_twin(src, [index])
        out4 = src1[index1]
        self.assertEqual(out3, out4)

        # Scripted calls without an explicit overload; `[index]` is a
        # List[Tensor], presumably resolved to the hacked_twin overload.
        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(
                input, [index], value, accumulate=False
            )

        src2, index2, value2 = self._gen_index_put_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(src2.clone(), index2, value2)
        actual = script_index_put_fn(src2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        def index_fn(input, index):
            return torch.ops.aten._unsafe_index(input, [index])

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(src2.clone(), index2)
        actual = script_index_fn(src2.clone(), index2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        """export_opnames follows calls through an interface-typed submodule, including after swapping it."""

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self) -> None:
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        # NOTE(review): this appears to toggle process-wide state and is not
        # reset here; later tests in the same process may observe it.
        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())

        expected = {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}
        opnames = set(torch.jit.export_opnames(scripted_M_mod))
        self.assertTrue(expected.issubset(opnames), f"missing ops: {expected - opnames}")

        # Swapping the interface implementation changes the exported ops.
        scripted_M_mod.sub = torch.jit.script(FooMod())
        expected = {"aten::add.Tensor", "aten::mul.Scalar"}
        opnames = set(torch.jit.export_opnames(scripted_M_mod))
        self.assertTrue(expected.issubset(opnames), f"missing ops: {expected - opnames}")

    def test_math_inf(self):
        from math import inf

        # A closed-over math.inf constant is captured when scripting.
        def foo():
            return inf

        self.checkScript(foo, ())

    def test_list_literal_infer(self):
        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        # An empty list literal passed straight to a List[int] parameter is
        # typed from that parameter.
        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        # An explicitly annotated List[Tensor] is not re-typed to List[int].
        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        # Nor is a list that was first bound to a variable.
        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        # With no typing context, `[]` defaults to Tensor[].
        @torch.jit.script
        def test_return():
            return []

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        # The contents are uninitialized, so fill both before comparing;
        # only the shape/dtype of the result is really under test.
        eager_out = empty_overload()
        script_out = torch.jit.script(empty_overload)()
        eager_out[:] = 1
        script_out[:] = 1
        self.assertEqual(eager_out, script_out)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(
            RuntimeError, "multiple positional arguments that were not all integers"
        ):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        # A scalar argument is broadcast to a 2-element list.
        self.assertEqual(torch.jit.script(sum_i)(4), 8)
        self.assertEqual(torch.jit.script(sum_f)(4.5), 9.0)

    def test_parse_ir_annotate(self):
        # An annotate(List[int], []) constant in textual IR evaluates to [].
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret, [])

    def test_parse_ir_single_element_tensor_positive(self):
        # A `{0}` single-element tensor constant parses to a 1-D, 1-element tensor.
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_parse_ir_single_element_tensor_negative(self):
        # Same as above, with a negative literal `{-17}`.
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_script_many_decorators(self):
        # Stacked no-op decorators must not prevent scripting the function.
        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p**gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")

        ref = fn(x)

        script_fn = torch.jit.script(fn)
        # Call repeatedly so that, presumably, the profiling executor gets
        # past its warm-up runs and the optimized graph is what is checked.
        for _ in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertNotIn("complex", str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        op, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [
            i for i, name in enumerate(override_names) if name == "complex"
        ]
        scalar_indices = [
            i for i, name in enumerate(override_names) if name == "Scalar"
        ]

        self.assertGreater(len(complex_indices), 0)
        self.assertGreater(len(scalar_indices), 0)
        self.assertGreater(complex_indices[0], scalar_indices[0])
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # This module is not meant to be run on its own; its tests are collected
    # via test/test_jit.py. raise_on_run_directly is the shared helper for
    # this case and tells the user which file to run instead.
    raise_on_run_directly("test/test_jit.py")
 |