Move Sympy printers to torch/utils/_sympy/printers.py (#140597)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
Isuru Fernando
2024-11-25 22:22:07 +00:00
committed by PyTorch MergeBot
parent 29ca44839e
commit 44186a0a4e
20 changed files with 589 additions and 579 deletions

View File

@ -580,8 +580,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
.check_regex(
"torch.ops._c10d_functional.all_to_all_single.default\\("
"arg\\d+_\\d+, "
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
"\\[s\\d+ // \\d, s\\d+ // \\d\\], "
"\\[s\\d+ // \\d, s\\d+ // \\d\\]"
)
.run(code)
)

View File

@ -3570,7 +3570,7 @@ class GraphModule(torch.nn.Module):
"cast_symbool_to_symint_guardless(L['pred']) == 1",
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"cast_symbool_to_symint_guardless(L['pred']) != 1",
]
test_symbool_guards(
f,

View File

@ -668,7 +668,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
"""\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)

View File

@ -10457,7 +10457,7 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s0, 3)}
@ -10576,7 +10576,7 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True}
> Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False}
> Right: {}
==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}

View File

@ -3259,9 +3259,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
(torch.tensor(20),),
fixes=[
# Could not guard on data-dependent expression Eq((u0//2), 0)
"torch._check(((i//2)) != 0)",
"torch._check((i // 2) != 0)",
# Could not guard on data-dependent expression Eq((u0//2), 1)
"torch._check(((i//2)) != 1)",
"torch._check((i // 2) != 1)",
],
)

View File

@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 20
x1 = (xindex // 20) % 20
x2 = (xindex // 400)
x0 = (xindex % 20)
x1 = ((xindex // 20) % 20)
x2 = xindex // 400
x3 = xindex
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp2 = tmp0 + tmp1
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
)

View File

@ -20,7 +20,9 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.utils._sympy.functions import (
FloorDiv,
Mod,
ModularIndexing,
PythonMod,
RoundDecimal,
RoundToInt,
)
@ -236,7 +238,7 @@ class TestIndexingSimplification(InductorTestCase):
triton_code = run_and_get_triton_code(f, x)
# Make sure the 2 load uses simpified indexing rather than something like
# tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),"))
if DO_PERF_TEST:
ms = benchmarker.benchmark_gpu(lambda: f(x))
print(f"{ms=:.03f}")
@ -313,6 +315,39 @@ class ExprPrinterTests(InductorTestCase):
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
def test_print_mod(self):
x = sympy.Symbol("x", integer=True)
expr = Mod(x - 1, 2)
self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""")
self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""")
self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""")
expr = (x - 10) % x
self.assertExpectedInline(pexpr(expr), """(-10) % x""")
self.assertExpectedInline(cexpr(expr), """(-10L) % x""")
self.assertExpectedInline(texpr(expr), """(-10) % x""")
def test_print_mod_index(self):
x = sympy.Symbol("x", integer=True)
ks = sympy.Symbol("ks", integer=True)
expr = ModularIndexing(x - 10, ks, ks)
self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""")
self.assertExpectedInline(
cexpr(expr),
"""(static_cast<int64_t>(c10::div_floor_integer("""
"""static_cast<int64_t>((-10L) + x), static_cast<int64_t>(ks))) % static_cast<int64_t>(ks))""",
)
self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""")
def test_print_python_mod(self):
x = sympy.Symbol("x", integer=True)
expr = PythonMod(x - 10, x)
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""")
self.assertExpectedInline(
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
)
@parametrize("ndigits", [-1, 0, 1])
def test_print_round_decimal(self, ndigits):
expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
@ -330,7 +365,7 @@ class ExprPrinterTests(InductorTestCase):
s1 = sympy.Symbol("s1", integer=True)
s2 = sympy.Symbol("s2", integer=True)
expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(s1 // s2)")
self.assertEqual(pexpr(expr), "s1 // s2")
self.assertEqual(
cexpr(expr),
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",

View File

@ -58,13 +58,11 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check(
"pool1 = empty_strided_"
+ GPU_TYPE
+ "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
"pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
).check_next(
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
).check(
"buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
"buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
).run(
code
)
@ -103,7 +101,7 @@ class TestMemoryPlanning(TestCase):
)
FileCheck().check(
"int64_t int_array_2[] = {24L + (align(12L*s0)), };"
"int64_t int_array_2[] = {24L + align(12L*s0), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;"
).check_next(

View File

@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
# make sure the load for softmax is aligned
self.assertTrue(
"tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
"tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -12505,8 +12505,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
)
@config.patch("triton.use_block_ptr", True)
@ -12538,7 +12538,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)

View File

@ -275,7 +275,7 @@ class TritonBlockPointerTest(InductorTestCase):
"\n".join(load_lines),
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
)
self.assertExpectedInline(
"\n".join(store_lines),

View File

@ -2792,8 +2792,8 @@ class TestGuardsExpressions(TestCase):
guard_int(sym_int(s0 / 2.0))
guards = shape_env.produce_guards_expression([s0])
self.assertIn("ToFloat", guards)
self.assertIn("FloatTrueDiv", guards)
self.assertIn("math.trunc(", guards)
self.assertIn("float(", guards)
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))

View File

@ -23,7 +23,6 @@ from typing import (
)
import sympy
from sympy.printing.printer import Printer
import torch
import torch.fx
@ -31,6 +30,7 @@ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
@ -609,12 +609,22 @@ class DataTypePropagation:
DataTypePropagation.propagate_loopbody(node._body)
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
class PythonPrinter(_PythonPrinter):
def doprint(self, expr, *, simplify: bool = True, p=True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
@staticmethod
def paren(string):
def all_in_parens(string):
def paren(string: str) -> str:
def all_in_parens(string: str) -> bool:
if string[0] != "(" or len(string) < 2:
return False
count = 1
@ -640,260 +650,6 @@ class ExprPrinter(Printer):
return string
return f"({string})"
def _print_Relational(self, expr):
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
def _print_Mul(self, expr):
return "*".join(map(self.paren, map(self._print, expr.args)))
def _print_Add(self, expr):
return " + ".join(map(self.paren, map(self._print, expr.args)))
# NB: this is OK to put here, because Mod is only defined for positive
# numbers, and so across C/Python its behavior is consistent
def _print_Mod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr)
def _print_Identity(self, expr):
return self._print(expr.args[0])
def _print_GreaterThan(self, expr):
# GreaterThan: >=
# StrictlyGreaterThan: >
# Go figure...
return " >= ".join(map(self.paren, map(self._print, expr.args)))
# NB: The C implementation is injected into codegen at
# torch/_inductor/codegen/wrapper.py
def _print_align(self, expr):
assert len(expr.args) == 1
return f"align({self._print(expr.args[0])})"
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
# any explicit intervention. We print it just like x * x, notably, we
# never generate sympy.Pow with floats.
#
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
def _print_Pow(self, expr):
base, exp = expr.args
base = self._print(base)
assert exp == int(exp), exp
exp = int(exp)
assert exp >= 0
if exp > 0:
return "*".join([self.paren(base)] * exp)
return "1"
# Explicit NotImplemented functions are to prevent default sympy printing
# behavior, which will just barf out ToFloat(...) to your IR. The error
# message is better here because it tells you which printer class it needs
# to go in.
def _print_ToFloat(self, expr):
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
def _print_Infinity(self, expr):
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
def _print_NegativeInfinity(self, expr):
raise NotImplementedError(
f"_print_NegativeInfinity not implemented for {type(self)}"
)
def _print_FloorDiv(self, expr):
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
def _print_PythonMod(self, expr):
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
def _print_IntTrueDiv(self, expr):
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatPow(self, expr):
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
def _print_TruncToInt(self, expr):
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
def _print_RoundToInt(self, expr):
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
def _print_RoundDecimal(self, expr):
raise NotImplementedError(
f"_print_RoundDecimal not implemented for {type(self)}"
)
# NB: Some float operations are INTENTIONALLY not implemented for
# printers. You can implement them as a quick unblock, but it is better
# to ask yourself why we haven't done this computation in the Tensor
# universe instead
def _print_TruncToFloat(self, expr):
raise NotImplementedError(
f"_print_TruncToFloat not implemented for {type(self)}"
)
def doprint(self, expr, *, simplify: bool = True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
class PythonPrinter(ExprPrinter):
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"float({self._print(expr.args[0])})"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
mod = self.paren(self.doprint(mod))
if div != "1":
x = f"({x} // {div})"
return f"{x} % {mod}"
def _print_Infinity(self, expr):
return "math.inf"
def _print_NegativeInfinity(self, expr):
return "-math.inf"
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
# does a special algorithm
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _helper_sqrt(self, expr):
return f"math.sqrt({self._print(expr)})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return self._helper_sqrt(expr.args[0])
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
# TODO: Not sure this works with Triton, even when base/exp are integral
def _print_PowByNatural(self, expr):
base, exp = expr.args
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
def _print_floor(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
# This also could have been int(), they'll do the same thing for float
return f"math.trunc({self._print(expr.args[0])})"
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"abs({self._print(expr.args[0])})"
# NB: It's expected that we've made explicit any promotion in the sympy
# expression, so it doesn't matter that Python max/min doesn't perform
# promotion
def _print_Max(self, expr):
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
def _print_Min(self, expr):
assert len(expr.args) >= 2
return f"min({', '.join(map(self._print, expr.args))})"
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"math.cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"math.cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"math.acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"math.sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"math.sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"math.asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"math.tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"math.tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2
number, ndigits = expr.args
assert isinstance(ndigits, sympy.Integer)
return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
def __getattr__(self, item):
return getattr(self._parent, item)
@ -982,31 +738,31 @@ class OpOverrides:
@staticmethod
def bitwise_not(x):
return f"~{ExprPrinter.paren(x)}"
return f"~{OpOverrides.paren(x)}"
@staticmethod
def logical_not(a):
return f"{ExprPrinter.paren(a)} == 0"
return f"{OpOverrides.paren(a)} == 0"
@staticmethod
def bitwise_and(x, y):
return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
@staticmethod
def bitwise_or(x, y):
return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
@staticmethod
def bitwise_xor(x, y):
return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
@staticmethod
def bitwise_left_shift(x, y):
return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
@staticmethod
def bitwise_right_shift(x, y):
return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
@staticmethod
def remainder(a, b):

View File

@ -13,6 +13,7 @@ import sympy
import torch
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
from .common import (
CSEVariable,
deduce_output_dtype_by_name,
ExprPrinter,
Kernel,
KernelArgs,
OptimizationContext,
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
return itervar in self.dependent_itervars
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr):
return (
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
)
def _print_Where(self, expr):
c = self.paren(self.doprint(expr.args[0]))
p = self.paren(self.doprint(expr.args[1]))
q = self.paren(self.doprint(expr.args[2]))
return f"{c} ? {p} : {q}"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
if div != 1:
div = self.paren(self.doprint(div))
if expr.is_integer:
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
else:
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
mod = self.paren(self.doprint(mod))
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
if expr.is_integer:
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
def _print_floor(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr):
assert len(expr.args) == 1
return f"std::trunc({self._print(expr.args[0])})"
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_CMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
# TODO: This is only accurate up to 2**53
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"std::pow({self._print(base)}, {self._print(exp)})"
def _print_Pow(self, expr):
# Uses float constants to perform FP div
base, exp = expr.args
base = self._print(base)
if exp == 0.5 or exp == -0.5:
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
if exp.is_integer:
exp = int(exp)
if exp > 0:
r = "*".join([self.paren(base)] * exp)
elif exp < 0:
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
else: # exp == 0
r = "1.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
else:
# TODO: float vs double
return f"std::pow({base}, {float(exp)})"
def _print_Rational(self, expr):
# Uses float constants to perform FP div
if expr.q == 1:
r = f"{expr.p}"
else:
r = f"{expr.p}.0/{expr.q}.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_ceiling(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Min(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
def _print_Max(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return f"std::sqrt({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
# ndigits < 0 should have been filtered by the sympy function
assert ndigits < 0
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
class CppPrinter(_CppPrinter):
def doprint(self, expr, *, simplify: bool = True, p=True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
# A function to print, useful for printing sympy symbols.

View File

@ -185,8 +185,8 @@ class HalidePrinter(PythonPrinter):
return super()._print_FloorDiv(expr)
x, div = expr.args
x = self.cast_float(self.paren(self.doprint(x)))
div = self.cast_float(self.paren(self.doprint(div)))
x = self.cast_float(self.doprint(x))
div = self.cast_float(self.doprint(div))
return self.cast_index(f"hl.floor({x} / {div})")
def _print_Round(self, expr):

View File

@ -27,6 +27,7 @@ from typing import (
)
import sympy
from sympy.printing.precedence import PRECEDENCE
import torch
import torch._logging
@ -504,30 +505,30 @@ class TritonPrinter(PythonPrinter):
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"
s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
return f"{s}.to(tl.float64)"
def _print_PythonMod(self, expr):
quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
quot_s = self._print(quot)
div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"{self.paren(quot_s)} % {self.paren(div_s)}"
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
def _print_FloorDiv(self, expr):
assert expr.is_integer
quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5)
quot_s = self._print(quot)
div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"({self.paren(quot_s)} // {self.paren(div_s)})"
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
# precision algorithm, which we would need to replicate here
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
# NB: sympy.floor/ceiling produce integers, so we have to do the
# conversion to index dtype
@ -646,7 +647,9 @@ class TritonPrinter(PythonPrinter):
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}"
texpr = TritonPrinter().doprint

View File

@ -35,13 +35,12 @@ from .autotune_process import (
TritonGPUBenchmarkRequest,
)
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import IndentedBuffer, KernelTemplate, WorkspaceArg
from .codegen.common import IndentedBuffer, KernelTemplate, OpOverrides, WorkspaceArg
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import (
gen_common_triton_imports,
texpr,
TritonKernel,
TritonPrinter,
TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
@ -562,7 +561,7 @@ class TritonTemplateKernel(TritonKernel):
assert isinstance(val, str)
assert isinstance(mask, (str, type(None)))
assert self.template_mask is None
indices = list(map(TritonPrinter.paren, indices))
indices = list(map(OpOverrides.paren, indices))
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
lengths = [
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
@ -630,7 +629,7 @@ class TritonTemplateKernel(TritonKernel):
assert isinstance(name, str)
assert isinstance(mask, str)
stride = self.named_input_nodes[name].get_stride()
indices = list(map(TritonPrinter.paren, indices))
indices = list(map(OpOverrides.paren, indices))
assert len(indices) == len(stride)
index = " + ".join(
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)

View File

@ -10,7 +10,6 @@ need to make use of these APIs to setup dynamic shapes support appropriately.
"""
import atexit
import builtins
import collections
import functools
import inspect
@ -82,6 +81,7 @@ from torch.utils._sympy.functions import (
PythonMod,
)
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
@ -109,8 +109,6 @@ log = logging.getLogger(__name__)
import sympy
from sympy import S
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter
class GuardOnDataDependentSymNode(RuntimeError):
@ -2032,45 +2030,9 @@ def cast_symbool_to_symint_guardless(
SYMPY_INTERP = {
"Abs": operator.abs,
"Eq": operator.eq,
"Ne": operator.ne,
"Gt": operator.gt,
"Lt": operator.lt,
"Le": operator.le,
"Ge": operator.ge,
"Min": min,
"Max": max,
"Mod": operator.mod,
"PythonMod": operator.mod,
"FloorDiv": operator.floordiv,
"TrueDiv": operator.truediv,
"PowByNatural": operator.pow,
"IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
"floor": math.floor,
"ceiling": math.ceil,
"FloorToInt": math.floor,
"FloatPow": math.pow,
"CeilToInt": math.ceil,
"cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
"RoundToInt": builtins.round,
"RoundDecimal": builtins.round,
"TruncToInt": math.trunc,
"IntTrueDiv": operator.truediv,
"FloatTrueDiv": operator.truediv,
"ToFloat": builtins.float,
"OpaqueUnaryFn_cos": math.cos,
"OpaqueUnaryFn_cosh": math.cosh,
"OpaqueUnaryFn_acos": math.acos,
"OpaqueUnaryFn_sin": math.sin,
"OpaqueUnaryFn_sinh": math.sinh,
"OpaqueUnaryFn_asin": math.asin,
"OpaqueUnaryFn_tan": math.tan,
"OpaqueUnaryFn_tanh": math.tanh,
"OpaqueUnaryFn_atan": math.atan,
"OpaqueUnaryFn_sqrt": math.sqrt,
"BitwiseFn_bitwise_and": operator.and_,
"BitwiseFn_bitwise_or": operator.or_,
"math": math,
}
@ -2141,12 +2103,12 @@ class RuntimeAssert:
# Used for printing SymExprs in compile_fx
class SymExprPrinter(StrPrinter):
class SymExprPrinter(PythonPrinter):
def _print_Float(self, expr: sympy.Float) -> str:
return str(float(expr))
class ShapeGuardPrinter(SymExprPrinter):
class ShapeGuardPrinter(PythonPrinter):
def __init__(
self,
symbol_to_source: Mapping[sympy.Symbol, List[Source]],
@ -2158,14 +2120,8 @@ class ShapeGuardPrinter(SymExprPrinter):
self.source_ref = source_ref
self.var_to_sources = var_to_sources
def _print_Not(self, expr: SympyBoolean) -> str:
return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"]))
def _print_And(self, expr: SympyBoolean) -> str:
return self.stringify(expr.args, " and ", PRECEDENCE["And"])
def _print_Or(self, expr: SympyBoolean) -> str:
return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
def _print_Float(self, expr: sympy.Float) -> str:
return str(float(expr))
def _print_Symbol(self, expr: sympy.Symbol) -> str:
assert isinstance(expr, sympy.Symbol), str(type(expr))
@ -2191,7 +2147,7 @@ class LoggingShapeGuardPrinter(ShapeGuardPrinter):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
class DynamicDimConstraintPrinter(StrPrinter):
class DynamicDimConstraintPrinter(PythonPrinter):
"""
Printer for dynamic dim constraints.
- Instead of symbol s_k it prints its source t.size()[i]
@ -2216,9 +2172,6 @@ class DynamicDimConstraintPrinter(StrPrinter):
), f"Unknown symbol {expr} created by constraints solver"
return self.symbol_to_source[expr][0].name()
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
return f"{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}" # type: ignore[attr-defined]
class DimConstraints:
"""
@ -6656,7 +6609,7 @@ def _blame_user_code(e: Exception, frame: types.FrameType) -> None:
e.args = (msg,)
class _PythonPrinter(sympy.printing.str.StrPrinter):
class _PythonMsgPrinter(PythonPrinter):
"""
Util printer that replaces sympy symbols with their source-level names
and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
@ -6670,13 +6623,6 @@ class _PythonPrinter(sympy.printing.str.StrPrinter):
def _print_Symbol(self, sym: sympy.Symbol) -> str:
return self.src_map[sym.name][0]
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr))
assert hasattr(expr, "rel_op")
rel_op = expr.rel_op
rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr))
return f"{lhs} {rel_op} {rhs}"
def _suggest_torch_checks(
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
@ -6687,7 +6633,7 @@ def _suggest_torch_checks(
if diff:
log.warning("Unable to find user code corresponding to {%s}", diff)
return
printer = _PythonPrinter(src_map)
printer = _PythonMsgPrinter(src_map)
msg = e.args[0]
msg += "\nTo fix the error, insert one of the following checks before this call:"
# suggested fixes to resolve `cond`` are to tell the compiler to assume

View File

@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
"""
nargs: Tuple[int, ...] = (2,)
precedence: int = 50 # precedence of mul # noqa: F811
precedence: int = 35 # lower precedence than add
is_integer: bool = True
@property
@ -291,6 +291,7 @@ class ModularIndexing(sympy.Function):
nargs: Tuple[int, ...] = (3,)
is_integer: bool = True
precedence: int = 35 # lower precedence than add
@classmethod
def eval(
@ -360,6 +361,7 @@ class Where(sympy.Function):
"""
nargs: Tuple[int, ...] = (3,)
precedence: int = 35 # lower precedence than add
def _eval_is_integer(self) -> Optional[bool]:
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
@ -389,6 +391,7 @@ class Where(sympy.Function):
class PythonMod(sympy.Function):
nargs: Tuple[int, ...] = (2,)
precedence: int = 35 # lower precedence than add
is_integer: bool = True
@classmethod
@ -447,6 +450,7 @@ class PythonMod(sympy.Function):
# Generic modulus: only defined on non-negative arguments
class Mod(sympy.Function):
nargs = (2,)
precedence: int = 35 # lower precedence than add
is_integer = True
is_nonnegative = True
@ -1014,6 +1018,8 @@ def _safe_pow(base, exponent):
class PowByNatural(sympy.Function):
is_integer = True
precedence: int = 50 # precedence of mul
@classmethod
def eval(cls, base, exp):
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
@ -1039,6 +1045,8 @@ class PowByNatural(sympy.Function):
class FloatPow(sympy.Function):
is_real = True
precedence: int = 60 # precedence of pow
@classmethod
def eval(cls, base, exp):
# NB: These test sympy.Number, not sympy.Float, because:
@ -1059,6 +1067,8 @@ class FloatPow(sympy.Function):
class FloatTrueDiv(sympy.Function):
is_real = True
precedence: int = 35 # lower precedence than add
@classmethod
def eval(cls, base, divisor):
# assert base.is_integer is not True, base
@ -1082,6 +1092,8 @@ class FloatTrueDiv(sympy.Function):
class IntTrueDiv(sympy.Function):
is_real = True
precedence: int = 35 # lower precedence than add
@classmethod
def eval(cls, base, divisor):
if divisor.is_zero:
@ -1254,6 +1266,8 @@ class Identity(sympy.Function):
Prevents expansion and other optimizations
"""
precedence = 10
def __repr__(self): # type: ignore[override]
return f"Identity({self.args[0]})"

View File

@ -0,0 +1,459 @@
import sys
from typing import Optional
import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter
INDEX_TYPE = "int64_t"
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(StrPrinter):
# override this so that _print_FloorDiv is used
printmethod = "_torch_sympystr"
def _print_Mul(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, "*", precedence(expr))
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
return self.stringify(expr.args, " + ", precedence(expr))
def _print_Relational(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " & ", PRECEDENCE["Atom"] - 0.5)
def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " | ", PRECEDENCE["Atom"] - 0.5)
# NB: this is OK to put here, because Mod is only defined for positive
# numbers, and so across C/Python its behavior is consistent
def _print_Mod(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
return f"({s})"
def _print_CleanDiv(self, expr: sympy.Expr) -> str:
return self._print_FloorDiv(expr)
def _print_Identity(self, expr: sympy.Expr) -> str:
return self._print(expr.args[0])
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
# any explicit intervention. We print it just like x * x, notably, we
# never generate sympy.Pow with floats.
#
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
assert exp == int(exp), exp
exp = int(exp)
assert exp >= 0
if exp > 0:
return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
return "1"
# Explicit NotImplemented functions are to prevent default sympy printing
# behavior, which will just barf out ToFloat(...) to your IR. The error
# message is better here because it tells you which printer class it needs
# to go in.
def _print_ToFloat(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
def _print_Infinity(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
raise NotImplementedError(
f"_print_NegativeInfinity not implemented for {type(self)}"
)
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
def _print_PythonMod(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatPow(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
raise NotImplementedError(
f"_print_RoundDecimal not implemented for {type(self)}"
)
# NB: Some float operations are INTENTIONALLY not implemented for
# printers. You can implement them as a quick unblock, but it is better
# to ask yourself why we haven't done this computation in the Tensor
# universe instead
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
raise NotImplementedError(
f"_print_TruncToFloat not implemented for {type(self)}"
)
class PythonPrinter(ExprPrinter):
def _print_ToFloat(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"float({self._print(expr.args[0])})"
def _print_And(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " and ", precedence(expr))
def _print_Or(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " or ", precedence(expr))
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = (
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
)
if div != "1":
x = f"({x} // {div})"
return f"({x} % {mod})"
def _print_Infinity(self, expr: sympy.Expr) -> str:
return "math.inf"
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
return "-math.inf"
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_PythonMod(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
return f"{x} // {div}"
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
# does a special algorithm
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
def _helper_sqrt(self, expr: sympy.Expr) -> str:
return f"math.sqrt({self._print(expr)})"
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
return self._helper_sqrt(expr.args[0])
def _print_FloatPow(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
# TODO: Not sure this works with Triton, even when base/exp are integral
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
def _print_floor(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
# This also could have been int(), they'll do the same thing for float
return f"math.trunc({self._print(expr.args[0])})"
def _print_ceiling(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_Abs(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"abs({self._print(expr.args[0])})"
# NB: It's expected that we've made explicit any promotion in the sympy
# expression, so it doesn't matter that Python max/min doesn't perform
# promotion
def _print_Max(self, expr: sympy.Expr) -> str:
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
def _print_Min(self, expr: sympy.Expr) -> str:
assert len(expr.args) >= 2
return f"min({', '.join(map(self._print, expr.args))})"
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 2
number, ndigits = expr.args
assert isinstance(ndigits, sympy.Integer)
return f"round({self._print(number)}, {ndigits})"
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str:
return (
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
)
def _print_Where(self, expr: sympy.Expr) -> str:
c, p, q = (
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
)
return f"{c} ? {p} : {q}"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = expr.args
x = self.doprint(x)
if div != 1:
div = self.doprint(div)
if expr.is_integer:
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
else:
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
mod = self.doprint(mod)
return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
x, div = expr.args
x = self.doprint(x)
div = self.doprint(div)
if expr.is_integer:
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
def _print_floor(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::trunc({self._print(expr.args[0])})"
def _print_ToFloat(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
lhs, rhs = expr.args
# TODO: This is only accurate up to 2**53
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatPow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
return f"std::pow({self._print(base)}, {self._print(exp)})"
def _print_Pow(self, expr: sympy.Expr) -> str:
# Uses float constants to perform FP div
base, exp = expr.args
if exp == 0.5 or exp == -0.5:
base = self._print(base)
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
if exp.is_integer:
exp = int(exp)
if exp > 0:
r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
elif exp < -1:
r = (
"1.0/("
+ self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
+ ")"
)
elif exp == -1:
r = "1.0/" + self._print(base)
else: # exp == 0
r = "1.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
else:
# TODO: float vs double
return f"std::pow({base}, {float(exp)})"
def _print_Rational(self, expr: sympy.Expr) -> str:
# Uses float constants to perform FP div
if expr.q == 1:
r = f"{expr.p}"
else:
r = f"{expr.p}.0/{expr.q}.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_ceiling(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Min(self, expr: sympy.Expr) -> str:
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
def _print_Max(self, expr: sympy.Expr) -> str:
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
def _print_Abs(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
return f"std::sqrt({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
# ndigits < 0 should have been filtered by the sympy function
assert ndigits < 0
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
return "true"
def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
return "false"