Move Sympy printers to torch/utils/_sympy/printers.py (#140597)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
Isuru Fernando
2024-11-25 22:22:07 +00:00
committed by PyTorch MergeBot
parent 29ca44839e
commit 44186a0a4e
20 changed files with 589 additions and 579 deletions

View File

@ -13,6 +13,7 @@ import sympy
import torch
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
from .common import (
CSEVariable,
deduce_output_dtype_by_name,
ExprPrinter,
Kernel,
KernelArgs,
OptimizationContext,
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
return itervar in self.dependent_itervars
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr):
return (
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
)
def _print_Where(self, expr):
c = self.paren(self.doprint(expr.args[0]))
p = self.paren(self.doprint(expr.args[1]))
q = self.paren(self.doprint(expr.args[2]))
return f"{c} ? {p} : {q}"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
if div != 1:
div = self.paren(self.doprint(div))
if expr.is_integer:
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
else:
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
mod = self.paren(self.doprint(mod))
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
if expr.is_integer:
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
def _print_floor(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr):
assert len(expr.args) == 1
return f"std::trunc({self._print(expr.args[0])})"
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_CMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
# TODO: This is only accurate up to 2**53
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"std::pow({self._print(base)}, {self._print(exp)})"
def _print_Pow(self, expr):
# Uses float constants to perform FP div
base, exp = expr.args
base = self._print(base)
if exp == 0.5 or exp == -0.5:
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
if exp.is_integer:
exp = int(exp)
if exp > 0:
r = "*".join([self.paren(base)] * exp)
elif exp < 0:
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
else: # exp == 0
r = "1.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
else:
# TODO: float vs double
return f"std::pow({base}, {float(exp)})"
def _print_Rational(self, expr):
# Uses float constants to perform FP div
if expr.q == 1:
r = f"{expr.p}"
else:
r = f"{expr.p}.0/{expr.q}.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_ceiling(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Min(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
def _print_Max(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return f"std::sqrt({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
# ndigits < 0 should have been filtered by the sympy function
assert ndigits < 0
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
class CppPrinter(_CppPrinter):
def doprint(self, expr, *, simplify: bool = True, p=True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
# A function to print, useful for printing sympy symbols.