diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index b4cd35b733d1..d56a42b8252f 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import os +import sys import unittest import sympy @@ -262,14 +263,19 @@ class ExprPrinterTests(InductorTestCase): cpu_cases = common_cases + [ ( sympy.Pow(s1 + s2, 2), - lambda c, L: "static_cast((bar + foo)*(bar + foo))", + lambda c, L: "static_cast((bar + foo)*(bar + foo))", ) ] for expr, result in gpu_cases: self.assertEqual(texpr(expr), result(1, "")) self.assertEqual(pexpr(expr), result(1, "")) for expr, result in cpu_cases: - self.assertEqual(cexpr(expr), result(1.0, "L")) # 1.0 for FP div + self.assertEqual( + cexpr(expr), + result(1.0, "LL") + if sys.platform in ["darwin", "win32"] + else result(1.0, "L"), + ) # 1.0 for FP div def test_print_floor(self): for integer in [True, False]: @@ -278,7 +284,7 @@ class ExprPrinterTests(InductorTestCase): if integer: self.assertEqual(pexpr(expr), "math.floor((1/2)*s1)") self.assertEqual( - cexpr(expr), "static_cast(std::floor((1.0/2.0)*s1))" + cexpr(expr), "static_cast(std::floor((1.0/2.0)*s1))" ) else: self.assertExpectedInline(pexpr(expr), """math.floor((1/2)*s1)""") @@ -295,7 +301,7 @@ class ExprPrinterTests(InductorTestCase): if integer: self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""") self.assertExpectedInline( - cexpr(expr), """static_cast(std::ceil((1.0/2.0)*s1))""" + cexpr(expr), """static_cast(std::ceil((1.0/2.0)*s1))""" ) else: self.assertExpectedInline(pexpr(expr), """math.ceil((1/2)*s1)""") @@ -325,13 +331,19 @@ class ExprPrinterTests(InductorTestCase): s2 = sympy.Symbol("s2", integer=True) expr = FloorDiv(s1, s2) self.assertEqual(pexpr(expr), "(s1 // s2)") - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + self.assertEqual( + cexpr(expr), + "c10::div_floor_integer(static_cast(s1), static_cast(s2))", + ) s1 = sympy.Symbol("s1", 
integer=True) s2 = sympy.S(-1) expr = FloorDiv(s1, s2) self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") + self.assertEqual( + cexpr(expr), + "(-1LL)*s1" if sys.platform in ["darwin", "win32"] else "(-1L)*s1", + ) def test_print_Min_Max(self): cases = ( @@ -344,14 +356,24 @@ class ExprPrinterTests(InductorTestCase): self.assertEqual( texpr(expr), f"((-2) * ((-2) {cmp}= (x)) + (x) * ((x) {cmp} (-2)))" ) - self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)") + self.assertEqual( + cexpr(expr), + f"std::{s}(static_cast(-2LL), static_cast(x))" + if sys.platform in ["darwin", "win32"] + else f"std::{s}(static_cast(-2L), static_cast(x))", + ) expr = f(x, 2 * x, 3 * x) self.assertEqual( texpr(expr), f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", # noqa: B950 line too long ) - self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})") + self.assertEqual( + cexpr(expr), + f"std::{s}({{x, 2LL*x, 3LL*x}})" + if sys.platform in ["darwin", "win32"] + else f"std::{s}({{x, 2L*x, 3L*x}})", + ) instantiate_parametrized_tests(ExprPrinterTests) diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index fb899d924887..df125324e897 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -69,11 +69,11 @@ class TestMemoryPlanning(TestCase): result, code = run_and_get_cpp_code(compiled, *args) FileCheck().check( - "pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast(s0*s0)))), }, {1L, }" + "pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast(s0*s0)))), }, {1L, }" ).check_next( "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});" ).check( - "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast(s0*s0)))," + "auto buf1 = 
alloc_from_pool(pool1, align(4L*(static_cast(s0*s0)))," ).run( code ) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 158a6272c8e4..26d5ad6afaaa 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2227,7 +2227,7 @@ class CppPythonBindingsCodeCache(CppCodeCache): static_assert(std::is_pointer::value, "arg type must be pointer or long"); return static_cast(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); } - template <> inline long parse_arg(PyObject* args, size_t n) { + template <> inline int64_t parse_arg(PyObject* args, size_t n) { auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); if(unlikely(result == -1 && PyErr_Occurred())) throw std::runtime_error("expected int arg"); diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index e5e9e4cd6b08..2a8e291565be 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -83,7 +83,7 @@ LAYOUT_TO_ATEN = { _IS_WINDOWS = sys.platform == "win32" -INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long" +INDEX_TYPE = "int64_t" GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) @@ -222,7 +222,9 @@ class CppCSEVariable(CSEVariable): class CppPrinter(ExprPrinter): def _print_Integer(self, expr): - return f"{int(expr)}LL" if _IS_WINDOWS else f"{int(expr)}L" + return ( + f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" + ) def _print_Where(self, expr): c = self.paren(self.doprint(expr.args[0])) @@ -236,7 +238,7 @@ class CppPrinter(ExprPrinter): if div != 1: div = self.paren(self.doprint(div)) if expr.is_integer: - x = f"c10::div_floor_integer({x}, {div})" + x = f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" else: x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" mod = self.paren(self.doprint(mod)) @@ -247,7 +249,7 @@ class CppPrinter(ExprPrinter): x = self.paren(self.doprint(x)) div = 
self.paren(self.doprint(div)) if expr.is_integer: - return f"c10::div_floor_integer({x}, {div})" + return f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" def _print_floor(self, expr): @@ -345,7 +347,7 @@ class CppPrinter(ExprPrinter): def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: - return f"std::min({args[0]}, {args[1]})" + return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" else: # Initializer list overload il = "{" + ", ".join(args) + "}" @@ -354,7 +356,7 @@ class CppPrinter(ExprPrinter): def _print_Max(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: - return f"std::max({args[0]}, {args[1]})" + return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" else: # Initializer list overload il = "{" + ", ".join(args) + "}"