# Owner(s): ["module: inductor"] import unittest from sympy import Symbol, sympify import torch from torch._inductor.fx_utils import count_flops_fx, countable_fx from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs from torch._inductor.virtualized import V from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, ) from torch.testing._internal.common_utils import run_tests, TestCase class TestUtils(TestCase): def test_zip_schema(self): def foo(x: torch.Tensor) -> None: pass result = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"}) schema = result._opoverload._schema g = torch.tensor([11, 2]) found = False for arg, val in torch._library.utils.zip_schema(schema, [], {"x": g}): if arg.name == "x": found = True self.assertTrue(found) found = False for arg, val in torch._library.utils.zip_schema(schema, [g], {}): if arg.name == "x": found = True self.assertTrue(found) def testSympySubs(self): # integer and nonnegetaive attributes are preserved. expr = Symbol("x") result = sympy_subs(expr, {expr: "y"}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, None) self.assertEqual(result.is_nonnegative, None) expr = Symbol("x", integer=True, nonnegative=False) result = sympy_subs(expr, {expr: "y"}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, True) self.assertEqual(result.is_nonnegative, False) # invalid replacement. expr = Symbol("x", integer=True) result = sympy_subs(expr, {Symbol("x"): Symbol("y")}) self.assertEqual(result.name, "x") # valid replacement since properties match. expr = Symbol("x", integer=True) result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")}) self.assertEqual(result.name, "y") # invalid replacement. expr = Symbol("x", integer=None) result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")}) self.assertEqual(result.name, "x") # replaced cant be string self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"}) # replaced can be an expression expr = Symbol("x") expr = abs(expr) self.assertEqual(expr.is_integer, None) self.assertEqual(expr.is_nonnegative, None) # replace abs(x) with y # propagte abs(x) sympy properties. result = sympy_subs(expr, {expr: Symbol("y")}) self.assertEqual(result.name, "y") self.assertEqual(result.is_integer, None) self.assertEqual(result.is_nonnegative, None) def test_sympy_str(self): self.assertEqual(sympy_str(sympify("a+b+c")), "a + b + c") self.assertEqual(sympy_str(sympify("a*b+c")), "c + a * b") self.assertEqual(sympy_str(sympify("a+b*(c+d)")), "a + b * (c + d)") self.assertEqual(sympy_str(sympify("(a+b)*(c+d)")), "(a + b) * (c + d)") self.assertEqual(sympy_str(sympify("-a")), "-a") self.assertEqual(sympy_str(sympify("a-b")), "a - b") self.assertEqual(sympy_str(sympify("a+-b")), "a - b") def test_flops_fx(self): def create_fx_node( aten: torch._ops.OpOverloadPacket, args, kwargs ) -> tuple[torch.fx.Node, torch.fx.Node]: node1 = torch.fx.Node( graph=torch.fx.Graph(), name="", op="call_function", target=aten, args=args, kwargs=kwargs, ) name: str = aten.overloads()[0] op_overload: torch._ops.OpOverload = getattr(aten, name) node2 = torch.fx.Node( graph=torch.fx.Graph(), name="", op="call_function", target=op_overload, args=args, kwargs=kwargs, ) return node1, node2 with V.set_fake_mode( torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) ): trues = [ ( torch.ops.aten.addmm, (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)), {}, ), ( torch.ops.aten.bmm, (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)), {}, ), (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}), ( torch.ops.aten.convolution, ( torch.Tensor(2, 3, 3), torch.Tensor(2, 2, 2), torch.Tensor(2), (1, 1), (0, 0), (1, 1), True, (0, 0), 1, ), {}, ), ( torch.ops.aten._convolution, ( torch.Tensor(2, 2, 2), torch.Tensor(2, 2, 2), torch.Tensor(2), (1,), (0,), (1,), True, (0,), 1, False, True, False, ), {}, ), ] # we don't support pointwise ops falses = [ ( torch.ops.aten.add, (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)), {}, ), ( torch.ops.aten.mul, (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)), {}, ), ] for t, args, kwargs in trues: fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs) self.assertTrue( countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}" ) self.assertTrue( countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}" ) self.assertNotEqual(count_flops_fx(fx_node_1), None) self.assertNotEqual(count_flops_fx(fx_node_2), None) for f, args, kwargs in falses: fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs) self.assertFalse( countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}" ) self.assertFalse( countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}" ) @unittest.skipIf(not torch.cuda.is_available(), "skip if no device") @dtypes(torch.float16, torch.bfloat16, torch.float32) def test_get_device_tflops(self, dtype): ret = get_device_tflops(dtype) self.assertTrue(type(ret) == float) instantiate_device_type_tests(TestUtils, globals()) if __name__ == "__main__": run_tests()