mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-30 19:54:53 +08:00
Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597 Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
29ca44839e
commit
44186a0a4e
@ -13,6 +13,7 @@ import sympy
|
||||
import torch
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
|
||||
from .common import (
|
||||
CSEVariable,
|
||||
deduce_output_dtype_by_name,
|
||||
ExprPrinter,
|
||||
Kernel,
|
||||
KernelArgs,
|
||||
OptimizationContext,
|
||||
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
|
||||
return itervar in self.dependent_itervars
|
||||
|
||||
|
||||
class CppPrinter(ExprPrinter):
|
||||
def _print_Integer(self, expr):
|
||||
return (
|
||||
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
|
||||
)
|
||||
|
||||
def _print_Where(self, expr):
|
||||
c = self.paren(self.doprint(expr.args[0]))
|
||||
p = self.paren(self.doprint(expr.args[1]))
|
||||
q = self.paren(self.doprint(expr.args[2]))
|
||||
return f"{c} ? {p} : {q}"
|
||||
|
||||
def _print_ModularIndexing(self, expr):
|
||||
x, div, mod = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
if div != 1:
|
||||
div = self.paren(self.doprint(div))
|
||||
if expr.is_integer:
|
||||
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
else:
|
||||
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
mod = self.paren(self.doprint(mod))
|
||||
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
x, div = expr.args
|
||||
x = self.paren(self.doprint(x))
|
||||
div = self.paren(self.doprint(div))
|
||||
if expr.is_integer:
|
||||
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||
|
||||
def _print_floor(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_FloorToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::floor({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_TruncToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::trunc({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})"
|
||||
|
||||
def _print_TruncToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::trunc({self._print(expr.args[0])})"
|
||||
|
||||
def _print_ToFloat(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"static_cast<double>({self._print(expr.args[0])})"
|
||||
|
||||
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
||||
# tickle though, as the inputs are typically positive (and if we can prove
|
||||
# they are positive, we will have used Mod instead, for which this codegen
|
||||
# is right).
|
||||
def _print_PythonMod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_CMod(self, expr):
|
||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
||||
|
||||
def _print_IntTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
# TODO: This is only accurate up to 2**53
|
||||
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
|
||||
|
||||
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
||||
# use std::pow, that operates on floats
|
||||
def _print_PowByNatural(self, expr):
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
||||
def _print_FloatTrueDiv(self, expr):
|
||||
lhs, rhs = expr.args
|
||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
||||
|
||||
def _print_FloatPow(self, expr):
|
||||
base, exp = expr.args
|
||||
return f"std::pow({self._print(base)}, {self._print(exp)})"
|
||||
|
||||
def _print_Pow(self, expr):
|
||||
# Uses float constants to perform FP div
|
||||
base, exp = expr.args
|
||||
base = self._print(base)
|
||||
|
||||
if exp == 0.5 or exp == -0.5:
|
||||
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
|
||||
if exp.is_integer:
|
||||
exp = int(exp)
|
||||
if exp > 0:
|
||||
r = "*".join([self.paren(base)] * exp)
|
||||
elif exp < 0:
|
||||
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
|
||||
else: # exp == 0
|
||||
r = "1.0"
|
||||
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
else:
|
||||
# TODO: float vs double
|
||||
return f"std::pow({base}, {float(exp)})"
|
||||
|
||||
def _print_Rational(self, expr):
|
||||
# Uses float constants to perform FP div
|
||||
if expr.q == 1:
|
||||
r = f"{expr.p}"
|
||||
else:
|
||||
r = f"{expr.p}.0/{expr.q}.0"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_ceiling(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_CeilToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
r = f"std::ceil({self._print(expr.args[0])})"
|
||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||
|
||||
def _print_Min(self, expr):
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::min({il})"
|
||||
|
||||
def _print_Max(self, expr):
|
||||
args = [self._print(a) for a in expr.args]
|
||||
if len(args) == 2:
|
||||
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||
else:
|
||||
# Initializer list overload
|
||||
il = "{" + ", ".join(args) + "}"
|
||||
return f"std::max({il})"
|
||||
|
||||
def _print_Abs(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::abs({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"std::atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sqrt(self, expr):
|
||||
return f"std::sqrt({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
# TODO: dispatch to llrint depending on index type
|
||||
return f"std::lrint({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundDecimal(self, expr):
|
||||
assert len(expr.args) == 2
|
||||
number, ndigits = expr.args
|
||||
if number.is_integer:
|
||||
# ndigits < 0 should have been filtered by the sympy function
|
||||
assert ndigits < 0
|
||||
raise ValueError(
|
||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||
)
|
||||
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
|
||||
|
||||
def _print_BooleanTrue(self, expr):
|
||||
return "true"
|
||||
|
||||
def _print_BooleanFalse(self, expr):
|
||||
return "false"
|
||||
class CppPrinter(_CppPrinter):
|
||||
def doprint(self, expr, *, simplify: bool = True, p=True):
|
||||
# TODO: why are people passing strings to the printer here :think:
|
||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||
expr = V.graph.sizevars.simplify(expr)
|
||||
return super().doprint(expr)
|
||||
|
||||
|
||||
# A function to print, useful for printing sympy symbols.
|
||||
|
||||
Reference in New Issue
Block a user