mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597 Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
29ca44839e
commit
44186a0a4e
@ -580,8 +580,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
.check_regex(
|
||||
"torch.ops._c10d_functional.all_to_all_single.default\\("
|
||||
"arg\\d+_\\d+, "
|
||||
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
|
||||
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
|
||||
"\\[s\\d+ // \\d, s\\d+ // \\d\\], "
|
||||
"\\[s\\d+ // \\d, s\\d+ // \\d\\]"
|
||||
)
|
||||
.run(code)
|
||||
)
|
||||
|
@ -3570,7 +3570,7 @@ class GraphModule(torch.nn.Module):
|
||||
"cast_symbool_to_symint_guardless(L['pred']) == 1",
|
||||
]
|
||||
false_guard_code = [
|
||||
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
|
||||
"cast_symbool_to_symint_guardless(L['pred']) != 1",
|
||||
]
|
||||
test_symbool_guards(
|
||||
f,
|
||||
|
@ -668,7 +668,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
||||
"""\
|
||||
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
|
||||
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
|
||||
)
|
||||
|
||||
|
@ -10457,7 +10457,7 @@ ShapeEnv not equal: field values don't match:
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
|
||||
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
|
||||
> Right: {}
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(s0, 3)}
|
||||
@ -10576,7 +10576,7 @@ ShapeEnv not equal: field values don't match:
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True}
|
||||
> Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False}
|
||||
> Right: {}
|
||||
==> deferred_runtime_asserts: values don't match.
|
||||
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
||||
|
@ -3259,9 +3259,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
(torch.tensor(20),),
|
||||
fixes=[
|
||||
# Could not guard on data-dependent expression Eq((u0//2), 0)
|
||||
"torch._check(((i//2)) != 0)",
|
||||
"torch._check((i // 2) != 0)",
|
||||
# Could not guard on data-dependent expression Eq((u0//2), 1)
|
||||
"torch._check(((i//2)) != 1)",
|
||||
"torch._check((i // 2) != 1)",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
||||
xoffset = tl.program_id(0) * XBLOCK
|
||||
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
||||
xmask = xindex < xnumel
|
||||
x0 = xindex % 20
|
||||
x1 = (xindex // 20) % 20
|
||||
x2 = (xindex // 400)
|
||||
x0 = (xindex % 20)
|
||||
x1 = ((xindex // 20) % 20)
|
||||
x2 = xindex // 400
|
||||
x3 = xindex
|
||||
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
||||
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
||||
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
|
||||
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
|
||||
tmp2 = tmp0 + tmp1
|
||||
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
|
||||
)
|
||||
|
@ -20,7 +20,9 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||
from torch.utils._sympy.functions import (
|
||||
FloorDiv,
|
||||
Mod,
|
||||
ModularIndexing,
|
||||
PythonMod,
|
||||
RoundDecimal,
|
||||
RoundToInt,
|
||||
)
|
||||
@ -236,7 +238,7 @@ class TestIndexingSimplification(InductorTestCase):
|
||||
triton_code = run_and_get_triton_code(f, x)
|
||||
# Make sure the 2 load uses simpified indexing rather than something like
|
||||
# tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
|
||||
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
|
||||
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),"))
|
||||
if DO_PERF_TEST:
|
||||
ms = benchmarker.benchmark_gpu(lambda: f(x))
|
||||
print(f"{ms=:.03f}")
|
||||
@ -313,6 +315,39 @@ class ExprPrinterTests(InductorTestCase):
|
||||
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
|
||||
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
|
||||
|
||||
def test_print_mod(self):
|
||||
x = sympy.Symbol("x", integer=True)
|
||||
expr = Mod(x - 1, 2)
|
||||
self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""")
|
||||
self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""")
|
||||
self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""")
|
||||
|
||||
expr = (x - 10) % x
|
||||
self.assertExpectedInline(pexpr(expr), """(-10) % x""")
|
||||
self.assertExpectedInline(cexpr(expr), """(-10L) % x""")
|
||||
self.assertExpectedInline(texpr(expr), """(-10) % x""")
|
||||
|
||||
def test_print_mod_index(self):
|
||||
x = sympy.Symbol("x", integer=True)
|
||||
ks = sympy.Symbol("ks", integer=True)
|
||||
expr = ModularIndexing(x - 10, ks, ks)
|
||||
self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""")
|
||||
self.assertExpectedInline(
|
||||
cexpr(expr),
|
||||
"""(static_cast<int64_t>(c10::div_floor_integer("""
|
||||
"""static_cast<int64_t>((-10L) + x), static_cast<int64_t>(ks))) % static_cast<int64_t>(ks))""",
|
||||
)
|
||||
self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""")
|
||||
|
||||
def test_print_python_mod(self):
|
||||
x = sympy.Symbol("x", integer=True)
|
||||
expr = PythonMod(x - 10, x)
|
||||
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
|
||||
self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""")
|
||||
self.assertExpectedInline(
|
||||
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
|
||||
)
|
||||
|
||||
@parametrize("ndigits", [-1, 0, 1])
|
||||
def test_print_round_decimal(self, ndigits):
|
||||
expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
|
||||
@ -330,7 +365,7 @@ class ExprPrinterTests(InductorTestCase):
|
||||
s1 = sympy.Symbol("s1", integer=True)
|
||||
s2 = sympy.Symbol("s2", integer=True)
|
||||
expr = FloorDiv(s1, s2)
|
||||
self.assertEqual(pexpr(expr), "(s1 // s2)")
|
||||
self.assertEqual(pexpr(expr), "s1 // s2")
|
||||
self.assertEqual(
|
||||
cexpr(expr),
|
||||
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",
|
||||
|
@ -58,13 +58,11 @@ class TestMemoryPlanning(TestCase):
|
||||
result, code = run_and_get_cpp_code(compiled, *args)
|
||||
|
||||
FileCheck().check(
|
||||
"pool1 = empty_strided_"
|
||||
+ GPU_TYPE
|
||||
+ "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
|
||||
"pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
|
||||
).check_next(
|
||||
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
|
||||
).check(
|
||||
"buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
|
||||
"buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
|
||||
).run(
|
||||
code
|
||||
)
|
||||
@ -103,7 +101,7 @@ class TestMemoryPlanning(TestCase):
|
||||
)
|
||||
|
||||
FileCheck().check(
|
||||
"int64_t int_array_2[] = {24L + (align(12L*s0)), };"
|
||||
"int64_t int_array_2[] = {24L + align(12L*s0), };"
|
||||
).check_next("int64_t int_array_3[] = {1L, };").check_next(
|
||||
"AtenTensorHandle pool1_handle;"
|
||||
).check_next(
|
||||
|
@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
|
||||
|
||||
# make sure the load for softmax is aligned
|
||||
self.assertTrue(
|
||||
"tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
|
||||
"tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
|
||||
f"forward_wrapper: {forward_wrapper}",
|
||||
)
|
||||
|
||||
|
@ -12505,8 +12505,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||
self.assertExpectedInline(
|
||||
"\n".join(lines),
|
||||
"""\
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
|
||||
)
|
||||
|
||||
@config.patch("triton.use_block_ptr", True)
|
||||
@ -12538,7 +12538,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||
self.assertExpectedInline(
|
||||
"\n".join(lines),
|
||||
"""\
|
||||
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
|
||||
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
|
||||
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
|
||||
)
|
||||
|
||||
|
@ -275,7 +275,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
||||
"\n".join(load_lines),
|
||||
"""\
|
||||
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
|
||||
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
|
||||
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
"\n".join(store_lines),
|
||||
|
@ -2792,8 +2792,8 @@ class TestGuardsExpressions(TestCase):
|
||||
guard_int(sym_int(s0 / 2.0))
|
||||
guards = shape_env.produce_guards_expression([s0])
|
||||
|
||||
self.assertIn("ToFloat", guards)
|
||||
self.assertIn("FloatTrueDiv", guards)
|
||||
self.assertIn("math.trunc(", guards)
|
||||
self.assertIn("float(", guards)
|
||||
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
||||
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
||||
|
||||
|
@ -23,7 +23,6 @@ from typing import (
|
||||
)
|
||||
|
||||
import sympy
|
||||
from sympy.printing.printer import Printer
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
@ -31,6 +30,7 @@ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
|
||||
@ -609,12 +609,22 @@ class DataTypePropagation:
|
||||
DataTypePropagation.propagate_loopbody(node._body)
|
||||
|
||||
|
||||
# This printer contains rules that are supposed to be generic for both C/C++ and
|
||||
# Python
|
||||
class ExprPrinter(Printer):
|
||||
class PythonPrinter(_PythonPrinter):
|
||||
def doprint(self, expr, *, simplify: bool = True, p=True):
|
||||
# TODO: why are people passing strings to the printer here :think:
|
||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||
expr = V.graph.sizevars.simplify(expr)
|
||||
return super().doprint(expr)
|
||||
|
||||
|
||||
class OpOverrides:
|
||||
def __init__(self, parent):
|
||||
super().__init__()
|
||||
self._parent = parent
|
||||
|
||||
@staticmethod
|
||||
def paren(string):
|
||||
def all_in_parens(string):
|
||||
def paren(string: str) -> str:
|
||||
def all_in_parens(string: str) -> bool:
|
||||
if string[0] != "(" or len(string) < 2:
|
||||
return False
|
||||
count = 1
|
||||
@ -640,260 +650,6 @@ class ExprPrinter(Printer):
|
||||
return string
|
||||
return f"({string})"
|
||||
|
||||
def _print_Relational(self, expr):
|
||||
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_Mul(self, expr):
|
||||
return "*".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_Add(self, expr):
|
||||
return " + ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
# NB: this is OK to put here, because Mod is only defined for positive
|
||||
# numbers, and so across C/Python its behavior is consistent
|
||||
def _print_Mod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_FloatTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
|
||||
def _print_CleanDiv(self, expr):
|
||||
return self._print_FloorDiv(expr)
|
||||
|
||||
def _print_Identity(self, expr):
|
||||
return self._print(expr.args[0])
|
||||
|
||||
def _print_GreaterThan(self, expr):
|
||||
# GreaterThan: >=
|
||||
# StrictlyGreaterThan: >
|
||||
# Go figure...
|
||||
return " >= ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
# NB: The C implementation is injected into codegen at
|
||||
# torch/_inductor/codegen/wrapper.py
|
||||
def _print_align(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"align({self._print(expr.args[0])})"
|
||||
|
||||
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
|
||||
# any explicit intervention. We print it just like x * x, notably, we
|
||||
# never generate sympy.Pow with floats.
|
||||
#
|
||||
# NB: this pow by natural, you should never have used builtin sympy.pow
|
||||
# for FloatPow, and a symbolic exponent should be PowByNatural. These
|
||||
# means exp is guaranteed to be integer.
|
||||
def _print_Pow(self, expr):
|
||||
base, exp = expr.args
|
||||
base = self._print(base)
|
||||
assert exp == int(exp), exp
|
||||
exp = int(exp)
|
||||
assert exp >= 0
|
||||
if exp > 0:
|
||||
return "*".join([self.paren(base)] * exp)
|
||||
return "1"
|
||||
|
||||
# Explicit NotImplemented functions are to prevent default sympy printing
|
||||
# behavior, which will just barf out ToFloat(...) to your IR. The error
|
||||
# message is better here because it tells you which printer class it needs
|
||||
# to go in.
|
||||
|
||||
def _print_ToFloat(self, expr):
|
||||
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
|
||||
|
||||
def _print_Infinity(self, expr):
|
||||
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
|
||||
|
||||
def _print_NegativeInfinity(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_NegativeInfinity not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PythonMod(self, expr):
|
||||
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
|
||||
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PowByNatural(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatPow(self, expr):
|
||||
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
|
||||
|
||||
def _print_TruncToInt(self, expr):
|
||||
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundToInt(self, expr):
|
||||
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundDecimal(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_RoundDecimal not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
# NB: Some float operations are INTENTIONALLY not implemented for
|
||||
# printers. You can implement them as a quick unblock, but it is better
|
||||
# to ask yourself why we haven't done this computation in the Tensor
|
||||
# universe instead
|
||||
|
||||
def _print_TruncToFloat(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_TruncToFloat not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def doprint(self, expr, *, simplify: bool = True):
|
||||
# TODO: why are people passing strings to the printer here :think:
|
||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||
expr = V.graph.sizevars.simplify(expr)
|
||||
return super().doprint(expr)
|
||||
|
||||
|
||||
class PythonPrinter(ExprPrinter):
|
||||
def _print_ToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"float({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ModularIndexing(self, expr):
|
||||
x, div, mod = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
div = self.paren(self.doprint(div))
|
||||
mod = self.paren(self.doprint(mod))
|
||||
if div != "1":
|
||||
x = f"({x} // {div})"
|
||||
return f"{x} % {mod}"
|
||||
|
||||
def _print_Infinity(self, expr):
|
||||
return "math.inf"
|
||||
|
||||
def _print_NegativeInfinity(self, expr):
|
||||
return "-math.inf"
|
||||
|
||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||
def _print_PythonMod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||
def _print_FloorDiv(self, expr):
|
||||
x, div = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
div = self.paren(self.doprint(div))
|
||||
return f"({x} // {div})"
|
||||
|
||||
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
|
||||
# does a special algorithm
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
|
||||
def _helper_sqrt(self, expr):
|
||||
return f"math.sqrt({self._print(expr)})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sqrt(self, expr):
|
||||
return self._helper_sqrt(expr.args[0])
|
||||
|
||||
def _print_FloatPow(self, expr):
|
||||
base, exp = expr.args
|
||||
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
|
||||
|
||||
# TODO: Not sure this works with Triton, even when base/exp are integral
|
||||
def _print_PowByNatural(self, expr):
|
||||
base, exp = expr.args
|
||||
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
|
||||
|
||||
def _print_floor(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.floor({self._print(expr.args[0])})"
|
||||
|
||||
def _print_FloorToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.floor({self._print(expr.args[0])})"
|
||||
|
||||
def _print_TruncToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
# This also could have been int(), they'll do the same thing for float
|
||||
return f"math.trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ceiling(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.ceil({self._print(expr.args[0])})"
|
||||
|
||||
def _print_CeilToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.ceil({self._print(expr.args[0])})"
|
||||
|
||||
def _print_Abs(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"abs({self._print(expr.args[0])})"
|
||||
|
||||
# NB: It's expected that we've made explicit any promotion in the sympy
|
||||
# expression, so it doesn't matter that Python max/min doesn't perform
|
||||
# promotion
|
||||
def _print_Max(self, expr):
|
||||
assert len(expr.args) >= 2
|
||||
return f"max({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_Min(self, expr):
|
||||
assert len(expr.args) >= 2
|
||||
return f"min({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"math.atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"round({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundDecimal(self, expr):
|
||||
assert len(expr.args) == 2
|
||||
number, ndigits = expr.args
|
||||
assert isinstance(ndigits, sympy.Integer)
|
||||
return f"round({self._print(number)}, {ndigits})"
|
||||
|
||||
|
||||
class OpOverrides:
|
||||
def __init__(self, parent):
|
||||
super().__init__()
|
||||
self._parent = parent
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._parent, item)
|
||||
|
||||
@ -982,31 +738,31 @@ class OpOverrides:
|
||||
|
||||
@staticmethod
|
||||
def bitwise_not(x):
|
||||
return f"~{ExprPrinter.paren(x)}"
|
||||
return f"~{OpOverrides.paren(x)}"
|
||||
|
||||
@staticmethod
|
||||
def logical_not(a):
|
||||
return f"{ExprPrinter.paren(a)} == 0"
|
||||
return f"{OpOverrides.paren(a)} == 0"
|
||||
|
||||
@staticmethod
|
||||
def bitwise_and(x, y):
|
||||
return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
|
||||
return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def bitwise_or(x, y):
|
||||
return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
|
||||
return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def bitwise_xor(x, y):
|
||||
return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
|
||||
return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def bitwise_left_shift(x, y):
|
||||
return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
|
||||
return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def bitwise_right_shift(x, y):
|
||||
return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
|
||||
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def remainder(a, b):
|
||||
|
@ -13,6 +13,7 @@ import sympy
|
||||
import torch
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
|
||||
from .common import (
|
||||
CSEVariable,
|
||||
deduce_output_dtype_by_name,
|
||||
ExprPrinter,
|
||||
Kernel,
|
||||
KernelArgs,
|
||||
OptimizationContext,
|
||||
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
|
||||
return itervar in self.dependent_itervars
|
||||
|
||||
|
||||
class CppPrinter(ExprPrinter):
|
||||
def _print_Integer(self, expr):
|
||||
return (
|
||||
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
|
||||
)
|
||||
|
||||
def _print_Where(self, expr):
|
||||
c = self.paren(self.doprint(expr.args[0]))
|
||||
p = self.paren(self.doprint(expr.args[1]))
|
||||
q = self.paren(self.doprint(expr.args[2]))
|
||||
return f"{c} ? {p} : {q}"
|
||||
|
||||
def _print_ModularIndexing(self, expr):
|
||||
x, div, mod = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
if div != 1:
|
||||
div = self.paren(self.doprint(div))
|
||||
if expr.is_integer:
|
||||
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
else:
|
||||
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
mod = self.paren(self.doprint(mod))
|
||||
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
x, div = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
div = self.paren(self.doprint(div))
|
||||
if expr.is_integer:
|
||||
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
|
||||
def _print_floor(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_FloorToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_TruncToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::trunc({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})"
|
||||
|
||||
def _print_TruncToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"static_cast<double>({self._print(expr.args[0])})"
|
||||
|
||||
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
||||
# tickle though, as the inputs are typically positive (and if we can prove
|
||||
# they are positive, we will have used Mod instead, for which this codegen
|
||||
# is right).
|
||||
def _print_PythonMod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_CMod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
# TODO: This is only accurate up to 2**53
|
||||
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
|
||||
|
||||
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
||||
# use std::pow, that operates on floats
|
||||
def _print_PowByNatural(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
|
||||
def _print_FloatPow(self, expr):
|
||||
base, exp = expr.args
|
||||
return f"std::pow({self._print(base)}, {self._print(exp)})"
|
||||
|
||||
def _print_Pow(self, expr):
|
||||
# Uses float constants to perform FP div
|
||||
base, exp = expr.args
|
||||
base = self._print(base)
|
||||
|
||||
if exp == 0.5 or exp == -0.5:
|
||||
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
|
||||
if exp.is_integer:
|
||||
exp = int(exp)
|
||||
if exp > 0:
|
||||
r = "*".join([self.paren(base)] * exp)
|
||||
elif exp < 0:
|
||||
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
|
||||
else: # exp == 0
|
||||
r = "1.0"
|
||||
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
else:
|
||||
# TODO: float vs double
|
||||
return f"std::pow({base}, {float(exp)})"
|
||||
|
||||
def _print_Rational(self, expr):
|
||||
# Uses float constants to perform FP div
|
||||
if expr.q == 1:
|
||||
r = f"{expr.p}"
|
||||
else:
|
||||
r = f"{expr.p}.0/{expr.q}.0"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_ceiling(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_CeilToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_Min(self, expr):
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::min({il})"
|
||||
|
||||
def _print_Max(self, expr):
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::max({il})"
|
||||
|
||||
def _print_Abs(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::abs({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sqrt(self, expr):
|
||||
return f"std::sqrt({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
# TODO: dispatch to llrint depending on index type
|
||||
return f"std::lrint({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundDecimal(self, expr):
|
||||
assert len(expr.args) == 2
|
||||
number, ndigits = expr.args
|
||||
if number.is_integer:
|
||||
# ndigits < 0 should have been filtered by the sympy function
|
||||
assert ndigits < 0
|
||||
raise ValueError(
|
||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||
)
|
||||
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
|
||||
|
||||
def _print_BooleanTrue(self, expr):
|
||||
return "true"
|
||||
|
||||
def _print_BooleanFalse(self, expr):
|
||||
return "false"
|
||||
class CppPrinter(_CppPrinter):
|
||||
def doprint(self, expr, *, simplify: bool = True, p=True):
|
||||
# TODO: why are people passing strings to the printer here :think:
|
||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||
expr = V.graph.sizevars.simplify(expr)
|
||||
return super().doprint(expr)
|
||||
|
||||
|
||||
# A function to print, useful for printing sympy symbols.
|
||||
|
@ -185,8 +185,8 @@ class HalidePrinter(PythonPrinter):
|
||||
return super()._print_FloorDiv(expr)
|
||||
|
||||
x, div = expr.args
|
||||
x = self.cast_float(self.paren(self.doprint(x)))
|
||||
div = self.cast_float(self.paren(self.doprint(div)))
|
||||
x = self.cast_float(self.doprint(x))
|
||||
div = self.cast_float(self.doprint(div))
|
||||
return self.cast_index(f"hl.floor({x} / {div})")
|
||||
|
||||
def _print_Round(self, expr):
|
||||
|
@ -27,6 +27,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import sympy
|
||||
from sympy.printing.precedence import PRECEDENCE
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
@ -504,30 +505,30 @@ class TritonPrinter(PythonPrinter):
|
||||
|
||||
def _print_ToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"
|
||||
s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
|
||||
return f"{s}.to(tl.float64)"
|
||||
|
||||
def _print_PythonMod(self, expr):
|
||||
quot, div = expr.args
|
||||
if quot.is_nonnegative and div.is_nonnegative:
|
||||
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||
quot_s = self._print(quot)
|
||||
div_s = self._print(div)
|
||||
if quot.is_nonnegative and div.is_nonnegative:
|
||||
return f"{self.paren(quot_s)} % {self.paren(div_s)}"
|
||||
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
assert expr.is_integer
|
||||
quot, div = expr.args
|
||||
if quot.is_nonnegative and div.is_nonnegative:
|
||||
return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5)
|
||||
quot_s = self._print(quot)
|
||||
div_s = self._print(div)
|
||||
if quot.is_nonnegative and div.is_nonnegative:
|
||||
return f"({self.paren(quot_s)} // {self.paren(div_s)})"
|
||||
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"
|
||||
|
||||
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
|
||||
# precision algorithm, which we would need to replicate here
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
# NB: sympy.floor/ceiling produce integers, so we have to do the
|
||||
# conversion to index dtype
|
||||
@ -646,7 +647,9 @@ class TritonPrinter(PythonPrinter):
|
||||
raise ValueError(
|
||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||
)
|
||||
return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"
|
||||
|
||||
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
|
||||
return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}"
|
||||
|
||||
|
||||
texpr = TritonPrinter().doprint
|
||||
|
@ -35,13 +35,12 @@ from .autotune_process import (
|
||||
TritonGPUBenchmarkRequest,
|
||||
)
|
||||
from .codecache import code_hash, PersistentCache, PyCodeCache
|
||||
from .codegen.common import IndentedBuffer, KernelTemplate, WorkspaceArg
|
||||
from .codegen.common import IndentedBuffer, KernelTemplate, OpOverrides, WorkspaceArg
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
from .codegen.triton import (
|
||||
gen_common_triton_imports,
|
||||
texpr,
|
||||
TritonKernel,
|
||||
TritonPrinter,
|
||||
TritonScheduling,
|
||||
)
|
||||
from .codegen.triton_utils import config_of, signature_to_meta
|
||||
@ -562,7 +561,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
assert isinstance(val, str)
|
||||
assert isinstance(mask, (str, type(None)))
|
||||
assert self.template_mask is None
|
||||
indices = list(map(TritonPrinter.paren, indices))
|
||||
indices = list(map(OpOverrides.paren, indices))
|
||||
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
|
||||
lengths = [
|
||||
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
|
||||
@ -630,7 +629,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
assert isinstance(name, str)
|
||||
assert isinstance(mask, str)
|
||||
stride = self.named_input_nodes[name].get_stride()
|
||||
indices = list(map(TritonPrinter.paren, indices))
|
||||
indices = list(map(OpOverrides.paren, indices))
|
||||
assert len(indices) == len(stride)
|
||||
index = " + ".join(
|
||||
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
|
||||
|
@ -10,7 +10,6 @@ need to make use of these APIs to setup dynamic shapes support appropriately.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import builtins
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
@ -82,6 +81,7 @@ from torch.utils._sympy.functions import (
|
||||
PythonMod,
|
||||
)
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.printers import PythonPrinter
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._sympy.solve import try_solve
|
||||
from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
|
||||
@ -109,8 +109,6 @@ log = logging.getLogger(__name__)
|
||||
|
||||
import sympy
|
||||
from sympy import S
|
||||
from sympy.printing.precedence import PRECEDENCE, precedence
|
||||
from sympy.printing.str import StrPrinter
|
||||
|
||||
|
||||
class GuardOnDataDependentSymNode(RuntimeError):
|
||||
@ -2032,45 +2030,9 @@ def cast_symbool_to_symint_guardless(
|
||||
|
||||
|
||||
SYMPY_INTERP = {
|
||||
"Abs": operator.abs,
|
||||
"Eq": operator.eq,
|
||||
"Ne": operator.ne,
|
||||
"Gt": operator.gt,
|
||||
"Lt": operator.lt,
|
||||
"Le": operator.le,
|
||||
"Ge": operator.ge,
|
||||
"Min": min,
|
||||
"Max": max,
|
||||
"Mod": operator.mod,
|
||||
"PythonMod": operator.mod,
|
||||
"FloorDiv": operator.floordiv,
|
||||
"TrueDiv": operator.truediv,
|
||||
"PowByNatural": operator.pow,
|
||||
"IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
|
||||
"floor": math.floor,
|
||||
"ceiling": math.ceil,
|
||||
"FloorToInt": math.floor,
|
||||
"FloatPow": math.pow,
|
||||
"CeilToInt": math.ceil,
|
||||
"cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
|
||||
"RoundToInt": builtins.round,
|
||||
"RoundDecimal": builtins.round,
|
||||
"TruncToInt": math.trunc,
|
||||
"IntTrueDiv": operator.truediv,
|
||||
"FloatTrueDiv": operator.truediv,
|
||||
"ToFloat": builtins.float,
|
||||
"OpaqueUnaryFn_cos": math.cos,
|
||||
"OpaqueUnaryFn_cosh": math.cosh,
|
||||
"OpaqueUnaryFn_acos": math.acos,
|
||||
"OpaqueUnaryFn_sin": math.sin,
|
||||
"OpaqueUnaryFn_sinh": math.sinh,
|
||||
"OpaqueUnaryFn_asin": math.asin,
|
||||
"OpaqueUnaryFn_tan": math.tan,
|
||||
"OpaqueUnaryFn_tanh": math.tanh,
|
||||
"OpaqueUnaryFn_atan": math.atan,
|
||||
"OpaqueUnaryFn_sqrt": math.sqrt,
|
||||
"BitwiseFn_bitwise_and": operator.and_,
|
||||
"BitwiseFn_bitwise_or": operator.or_,
|
||||
"math": math,
|
||||
}
|
||||
|
||||
|
||||
@ -2141,12 +2103,12 @@ class RuntimeAssert:
|
||||
|
||||
|
||||
# Used for printing SymExprs in compile_fx
|
||||
class SymExprPrinter(StrPrinter):
|
||||
class SymExprPrinter(PythonPrinter):
|
||||
def _print_Float(self, expr: sympy.Float) -> str:
|
||||
return str(float(expr))
|
||||
|
||||
|
||||
class ShapeGuardPrinter(SymExprPrinter):
|
||||
class ShapeGuardPrinter(PythonPrinter):
|
||||
def __init__(
|
||||
self,
|
||||
symbol_to_source: Mapping[sympy.Symbol, List[Source]],
|
||||
@ -2158,14 +2120,8 @@ class ShapeGuardPrinter(SymExprPrinter):
|
||||
self.source_ref = source_ref
|
||||
self.var_to_sources = var_to_sources
|
||||
|
||||
def _print_Not(self, expr: SympyBoolean) -> str:
|
||||
return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"]))
|
||||
|
||||
def _print_And(self, expr: SympyBoolean) -> str:
|
||||
return self.stringify(expr.args, " and ", PRECEDENCE["And"])
|
||||
|
||||
def _print_Or(self, expr: SympyBoolean) -> str:
|
||||
return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
|
||||
def _print_Float(self, expr: sympy.Float) -> str:
|
||||
return str(float(expr))
|
||||
|
||||
def _print_Symbol(self, expr: sympy.Symbol) -> str:
|
||||
assert isinstance(expr, sympy.Symbol), str(type(expr))
|
||||
@ -2191,7 +2147,7 @@ class LoggingShapeGuardPrinter(ShapeGuardPrinter):
|
||||
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
|
||||
|
||||
|
||||
class DynamicDimConstraintPrinter(StrPrinter):
|
||||
class DynamicDimConstraintPrinter(PythonPrinter):
|
||||
"""
|
||||
Printer for dynamic dim constraints.
|
||||
- Instead of symbol s_k it prints its source t.size()[i]
|
||||
@ -2216,9 +2172,6 @@ class DynamicDimConstraintPrinter(StrPrinter):
|
||||
), f"Unknown symbol {expr} created by constraints solver"
|
||||
return self.symbol_to_source[expr][0].name()
|
||||
|
||||
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
|
||||
return f"{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class DimConstraints:
|
||||
"""
|
||||
@ -6656,7 +6609,7 @@ def _blame_user_code(e: Exception, frame: types.FrameType) -> None:
|
||||
e.args = (msg,)
|
||||
|
||||
|
||||
class _PythonPrinter(sympy.printing.str.StrPrinter):
|
||||
class _PythonMsgPrinter(PythonPrinter):
|
||||
"""
|
||||
Util printer that replaces sympy symbols with their source-level names
|
||||
and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
|
||||
@ -6670,13 +6623,6 @@ class _PythonPrinter(sympy.printing.str.StrPrinter):
|
||||
def _print_Symbol(self, sym: sympy.Symbol) -> str:
|
||||
return self.src_map[sym.name][0]
|
||||
|
||||
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
|
||||
lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr))
|
||||
assert hasattr(expr, "rel_op")
|
||||
rel_op = expr.rel_op
|
||||
rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr))
|
||||
return f"{lhs} {rel_op} {rhs}"
|
||||
|
||||
|
||||
def _suggest_torch_checks(
|
||||
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
|
||||
@ -6687,7 +6633,7 @@ def _suggest_torch_checks(
|
||||
if diff:
|
||||
log.warning("Unable to find user code corresponding to {%s}", diff)
|
||||
return
|
||||
printer = _PythonPrinter(src_map)
|
||||
printer = _PythonMsgPrinter(src_map)
|
||||
msg = e.args[0]
|
||||
msg += "\nTo fix the error, insert one of the following checks before this call:"
|
||||
# suggested fixes to resolve `cond`` are to tell the compiler to assume
|
||||
|
@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
|
||||
"""
|
||||
|
||||
nargs: Tuple[int, ...] = (2,)
|
||||
precedence: int = 50 # precedence of mul # noqa: F811
|
||||
precedence: int = 35 # lower precedence than add
|
||||
is_integer: bool = True
|
||||
|
||||
@property
|
||||
@ -291,6 +291,7 @@ class ModularIndexing(sympy.Function):
|
||||
|
||||
nargs: Tuple[int, ...] = (3,)
|
||||
is_integer: bool = True
|
||||
precedence: int = 35 # lower precedence than add
|
||||
|
||||
@classmethod
|
||||
def eval(
|
||||
@ -360,6 +361,7 @@ class Where(sympy.Function):
|
||||
"""
|
||||
|
||||
nargs: Tuple[int, ...] = (3,)
|
||||
precedence: int = 35 # lower precedence than add
|
||||
|
||||
def _eval_is_integer(self) -> Optional[bool]:
|
||||
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
||||
@ -389,6 +391,7 @@ class Where(sympy.Function):
|
||||
class PythonMod(sympy.Function):
|
||||
nargs: Tuple[int, ...] = (2,)
|
||||
|
||||
precedence: int = 35 # lower precedence than add
|
||||
is_integer: bool = True
|
||||
|
||||
@classmethod
|
||||
@ -447,6 +450,7 @@ class PythonMod(sympy.Function):
|
||||
# Generic modulus: only defined on non-negative arguments
|
||||
class Mod(sympy.Function):
|
||||
nargs = (2,)
|
||||
precedence: int = 35 # lower precedence than add
|
||||
|
||||
is_integer = True
|
||||
is_nonnegative = True
|
||||
@ -1014,6 +1018,8 @@ def _safe_pow(base, exponent):
|
||||
class PowByNatural(sympy.Function):
|
||||
is_integer = True
|
||||
|
||||
precedence: int = 50 # precedence of mul
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
|
||||
@ -1039,6 +1045,8 @@ class PowByNatural(sympy.Function):
|
||||
class FloatPow(sympy.Function):
|
||||
is_real = True
|
||||
|
||||
precedence: int = 60 # precedence of pow
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
# NB: These test sympy.Number, not sympy.Float, because:
|
||||
@ -1059,6 +1067,8 @@ class FloatPow(sympy.Function):
|
||||
class FloatTrueDiv(sympy.Function):
|
||||
is_real = True
|
||||
|
||||
precedence: int = 35 # lower precedence than add
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
# assert base.is_integer is not True, base
|
||||
@ -1082,6 +1092,8 @@ class FloatTrueDiv(sympy.Function):
|
||||
class IntTrueDiv(sympy.Function):
|
||||
is_real = True
|
||||
|
||||
precedence: int = 35 # lower precedence than add
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
if divisor.is_zero:
|
||||
@ -1254,6 +1266,8 @@ class Identity(sympy.Function):
|
||||
Prevents expansion and other optimizations
|
||||
"""
|
||||
|
||||
precedence = 10
|
||||
|
||||
def __repr__(self): # type: ignore[override]
|
||||
return f"Identity({self.args[0]})"
|
||||
|
||||
|
459
torch/utils/_sympy/printers.py
Normal file
459
torch/utils/_sympy/printers.py
Normal file
@ -0,0 +1,459 @@
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import sympy
|
||||
from sympy.printing.precedence import PRECEDENCE, precedence
|
||||
from sympy.printing.str import StrPrinter
|
||||
|
||||
|
||||
INDEX_TYPE = "int64_t"
|
||||
|
||||
|
||||
# This printer contains rules that are supposed to be generic for both C/C++ and
|
||||
# Python
|
||||
class ExprPrinter(StrPrinter):
|
||||
# override this so that _print_FloorDiv is used
|
||||
printmethod = "_torch_sympystr"
|
||||
|
||||
def _print_Mul(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, "*", precedence(expr))
|
||||
|
||||
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
|
||||
return self.stringify(expr.args, " + ", precedence(expr))
|
||||
|
||||
def _print_Relational(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
|
||||
|
||||
def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " & ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " | ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
# NB: this is OK to put here, because Mod is only defined for positive
|
||||
# numbers, and so across C/Python its behavior is consistent
|
||||
def _print_Mod(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
|
||||
s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||
return f"({s})"
|
||||
|
||||
def _print_CleanDiv(self, expr: sympy.Expr) -> str:
|
||||
return self._print_FloorDiv(expr)
|
||||
|
||||
def _print_Identity(self, expr: sympy.Expr) -> str:
|
||||
return self._print(expr.args[0])
|
||||
|
||||
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
|
||||
# any explicit intervention. We print it just like x * x, notably, we
|
||||
# never generate sympy.Pow with floats.
|
||||
#
|
||||
# NB: this pow by natural, you should never have used builtin sympy.pow
|
||||
# for FloatPow, and a symbolic exponent should be PowByNatural. These
|
||||
# means exp is guaranteed to be integer.
|
||||
def _print_Pow(self, expr: sympy.Expr) -> str:
|
||||
base, exp = expr.args
|
||||
assert exp == int(exp), exp
|
||||
exp = int(exp)
|
||||
assert exp >= 0
|
||||
if exp > 0:
|
||||
return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
|
||||
return "1"
|
||||
|
||||
# Explicit NotImplemented functions are to prevent default sympy printing
|
||||
# behavior, which will just barf out ToFloat(...) to your IR. The error
|
||||
# message is better here because it tells you which printer class it needs
|
||||
# to go in.
|
||||
|
||||
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
|
||||
|
||||
def _print_Infinity(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
|
||||
|
||||
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(
|
||||
f"_print_NegativeInfinity not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
|
||||
|
||||
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
|
||||
|
||||
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
|
||||
|
||||
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
|
||||
|
||||
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(
|
||||
f"_print_RoundDecimal not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
# NB: Some float operations are INTENTIONALLY not implemented for
|
||||
# printers. You can implement them as a quick unblock, but it is better
|
||||
# to ask yourself why we haven't done this computation in the Tensor
|
||||
# universe instead
|
||||
|
||||
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(
|
||||
f"_print_TruncToFloat not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
|
||||
class PythonPrinter(ExprPrinter):
|
||||
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"float({self._print(expr.args[0])})"
|
||||
|
||||
def _print_And(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " and ", precedence(expr))
|
||||
|
||||
def _print_Or(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " or ", precedence(expr))
|
||||
|
||||
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||
x, div, mod = (
|
||||
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
|
||||
)
|
||||
if div != "1":
|
||||
x = f"({x} // {div})"
|
||||
return f"({x} % {mod})"
|
||||
|
||||
def _print_Infinity(self, expr: sympy.Expr) -> str:
|
||||
return "math.inf"
|
||||
|
||||
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
|
||||
return "-math.inf"
|
||||
|
||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||
x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
|
||||
return f"{x} // {div}"
|
||||
|
||||
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
|
||||
# does a special algorithm
|
||||
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
def _helper_sqrt(self, expr: sympy.Expr) -> str:
|
||||
return f"math.sqrt({self._print(expr)})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
|
||||
return self._helper_sqrt(expr.args[0])
|
||||
|
||||
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
|
||||
|
||||
# TODO: Not sure this works with Triton, even when base/exp are integral
|
||||
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
|
||||
|
||||
def _print_floor(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.floor({self._print(expr.args[0])})"
|
||||
|
||||
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.floor({self._print(expr.args[0])})"
|
||||
|
||||
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
# This also could have been int(), they'll do the same thing for float
|
||||
return f"math.trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ceiling(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.ceil({self._print(expr.args[0])})"
|
||||
|
||||
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.ceil({self._print(expr.args[0])})"
|
||||
|
||||
def _print_Abs(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"abs({self._print(expr.args[0])})"
|
||||
|
||||
# NB: It's expected that we've made explicit any promotion in the sympy
|
||||
# expression, so it doesn't matter that Python max/min doesn't perform
|
||||
# promotion
|
||||
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) >= 2
|
||||
return f"max({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_Min(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) >= 2
|
||||
return f"min({', '.join(map(self._print, expr.args))})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"round({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 2
|
||||
number, ndigits = expr.args
|
||||
assert isinstance(ndigits, sympy.Integer)
|
||||
return f"round({self._print(number)}, {ndigits})"
|
||||
|
||||
|
||||
class CppPrinter(ExprPrinter):
|
||||
def _print_Integer(self, expr: sympy.Expr) -> str:
|
||||
return (
|
||||
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
|
||||
)
|
||||
|
||||
def _print_Where(self, expr: sympy.Expr) -> str:
|
||||
c, p, q = (
|
||||
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
|
||||
)
|
||||
return f"{c} ? {p} : {q}"
|
||||
|
||||
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||
x, div, mod = expr.args
|
||||
x = self.doprint(x)
|
||||
if div != 1:
|
||||
div = self.doprint(div)
|
||||
if expr.is_integer:
|
||||
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
else:
|
||||
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
mod = self.doprint(mod)
|
||||
return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
|
||||
|
||||
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||
x, div = expr.args
|
||||
x = self.doprint(x)
|
||||
div = self.doprint(div)
|
||||
if expr.is_integer:
|
||||
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
|
||||
def _print_floor(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::trunc({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})"
|
||||
|
||||
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"static_cast<double>({self._print(expr.args[0])})"
|
||||
|
||||
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
||||
# tickle though, as the inputs are typically positive (and if we can prove
|
||||
# they are positive, we will have used Mod instead, for which this codegen
|
||||
# is right).
|
||||
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||
|
||||
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||
lhs, rhs = expr.args
|
||||
# TODO: This is only accurate up to 2**53
|
||||
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
|
||||
|
||||
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
||||
# use std::pow, that operates on floats
|
||||
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||
base, exp = expr.args
|
||||
return f"std::pow({self._print(base)}, {self._print(exp)})"
|
||||
|
||||
def _print_Pow(self, expr: sympy.Expr) -> str:
|
||||
# Uses float constants to perform FP div
|
||||
base, exp = expr.args
|
||||
|
||||
if exp == 0.5 or exp == -0.5:
|
||||
base = self._print(base)
|
||||
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
|
||||
if exp.is_integer:
|
||||
exp = int(exp)
|
||||
if exp > 0:
|
||||
r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
|
||||
elif exp < -1:
|
||||
r = (
|
||||
"1.0/("
|
||||
+ self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
|
||||
+ ")"
|
||||
)
|
||||
elif exp == -1:
|
||||
r = "1.0/" + self._print(base)
|
||||
else: # exp == 0
|
||||
r = "1.0"
|
||||
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
else:
|
||||
# TODO: float vs double
|
||||
return f"std::pow({base}, {float(exp)})"
|
||||
|
||||
def _print_Rational(self, expr: sympy.Expr) -> str:
|
||||
# Uses float constants to perform FP div
|
||||
if expr.q == 1:
|
||||
r = f"{expr.p}"
|
||||
else:
|
||||
r = f"{expr.p}.0/{expr.q}.0"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_ceiling(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_Min(self, expr: sympy.Expr) -> str:
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::min({il})"
|
||||
|
||||
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::max({il})"
|
||||
|
||||
def _print_Abs(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::abs({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"std::atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
|
||||
return f"std::sqrt({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
# TODO: dispatch to llrint depending on index type
|
||||
return f"std::lrint({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 2
|
||||
number, ndigits = expr.args
|
||||
if number.is_integer:
|
||||
# ndigits < 0 should have been filtered by the sympy function
|
||||
assert ndigits < 0
|
||||
raise ValueError(
|
||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||
)
|
||||
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
|
||||
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
|
||||
|
||||
def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
|
||||
return "true"
|
||||
|
||||
def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
|
||||
return "false"
|
Reference in New Issue
Block a user