Files
pytorch/torch/_inductor/codegen/cpp_utils.py
PyTorch MergeBot 5ae9daa4a2 Revert "[AOTI] support freezing for MKLDNN (#124350)"
This reverts commit 654afb6f3ae3ddbd926a753f9af95a6f6e22131c.

Reverted https://github.com/pytorch/pytorch/pull/124350 on behalf of https://github.com/clee2000 due to Seems to have broken inductor/test_aot_inductor.py::AOTInductorTestNonABICompatibleCpu::test_freezing_non_abi_compatible_cpu 654afb6f3a https://github.com/pytorch/pytorch/actions/runs/9224838183/job/25382780192 ([comment](https://github.com/pytorch/pytorch/pull/124350#issuecomment-2129889809))
2024-05-24 16:03:07 +00:00

369 lines
13 KiB
Python

import contextlib
import copy
import math
from collections import namedtuple
from typing import Dict, List
from unittest.mock import patch
import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import ir
from ..virtualized import V
from .common import ExprPrinter, Kernel
DTYPE_TO_CPP = {
torch.float32: "float",
torch.float64: "double",
torch.float16: "half",
torch.int64: "int64_t",
torch.int32: "int",
torch.int16: "short",
torch.int8: "signed char",
torch.uint64: "uint64_t",
torch.uint32: "unsigned int",
torch.uint16: "unsigned short",
torch.uint8: "unsigned char",
torch.bool: "bool",
torch.bfloat16: "bfloat16",
torch.complex64: "complex64",
torch.float8_e4m3fn: "float8_e4m3fn",
torch.float8_e5m2: "float8_e5m2",
}
DTYPE_TO_ATEN = {
torch.float32: "at::kFloat",
torch.float64: "at::kDouble",
torch.float16: "at::kHalf",
torch.int64: "at::kLong",
torch.int32: "at::kInt",
torch.int16: "at::kShort",
torch.int8: "at::kChar",
torch.uint64: "at::kUInt64",
torch.uint32: "at::kUInt32",
torch.uint16: "at::kUInt16",
torch.uint8: "at::kByte",
torch.uint32: "at::kUInt32",
torch.uint64: "at::kUInt64",
torch.bool: "at::kBool",
torch.bfloat16: "at::kBFloat16",
torch.complex32: "at::kComplexHalf",
torch.complex64: "at::kComplexFloat",
torch.complex128: "at::kComplexDouble",
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
torch.float8_e5m2: "at::kFloat8_e5m2",
torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
}
DEVICE_TO_ATEN = {
"cpu": "at::kCPU",
"cuda": "at::kCUDA",
}
INDEX_TYPE = "long"
GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
class CppPrinter(ExprPrinter):
    """Printer that renders sympy expressions as C++ expression strings.

    Every ``_print_<Name>`` method returns the C++ text for one kind of
    expression node. Rounding, power and rational expressions that are
    integer-valued (``expr.is_integer``) are wrapped in
    ``static_cast<INDEX_TYPE>(...)``.
    """

    # ---- private helpers shared by the printers below ----

    def _index_cast(self, text, expr):
        """Wrap ``text`` in a cast to ``INDEX_TYPE`` when ``expr.is_integer`` is truthy."""
        return f"static_cast<{INDEX_TYPE}>({text})" if expr.is_integer else text

    def _unary_call(self, func, expr):
        """Return ``func(<printed arg>)`` for an expression with exactly one argument."""
        assert len(expr.args) == 1
        return f"{func}({self._print(expr.args[0])})"

    def _min_max_call(self, func, expr):
        """Return a ``std::min``/``std::max`` style call over all arguments of ``expr``."""
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"{func}({args[0]}, {args[1]})"
        # Initializer list overload
        il = "{" + ", ".join(args) + "}"
        return f"{func}({il})"

    # ---- literals and structural nodes ----

    def _print_Integer(self, expr):
        # The trailing "L" makes the emitted literal a C++ long literal.
        return f"{int(expr)}L"

    def _print_Where(self, expr):
        # Emitted as a C++ conditional expression; every operand goes through
        # self.paren before being spliced in.
        c = self.paren(self.doprint(expr.args[0]))
        p = self.paren(self.doprint(expr.args[1]))
        q = self.paren(self.doprint(expr.args[2]))
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr):
        # Renders (x // div) % mod; the division is skipped when div == 1.
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        if div != 1:
            div = self.paren(self.doprint(div))
            if expr.is_integer:
                x = f"c10::div_floor_integer({x}, {div})"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.paren(self.doprint(mod))
        return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        if expr.is_integer:
            return f"c10::div_floor_integer({x}, {div})"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    # ---- rounding ----

    def _print_floor(self, expr):
        return self._index_cast(self._unary_call("std::floor", expr), expr)

    def _print_Trunc(self, expr):
        return self._index_cast(self._unary_call("std::trunc", expr), expr)

    def _print_ceiling(self, expr):
        return self._index_cast(self._unary_call("std::ceil", expr), expr)

    def _print_Round(self, expr):
        return self._unary_call("std::lrint", expr)

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"

    # ---- arithmetic ----

    def _print_Pow(self, expr):
        # Uses float constants to perform FP div
        base, exp = expr.args
        base = self._print(base)

        if exp == 0.5:
            return f"std::sqrt({base})"
        if exp == -0.5:
            return f"1.0/std::sqrt({base})"
        if exp.is_integer:
            # Integer powers are expanded into repeated multiplication.
            exp = int(exp)
            if exp > 0:
                r = "*".join([self.paren(base)] * exp)
            elif exp < 0:
                r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
            else:  # exp == 0
                r = "1.0"
            return self._index_cast(r, expr)
        # TODO: float vs double
        return f"std::pow({base}, {float(exp)})"

    def _print_Rational(self, expr):
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return self._index_cast(r, expr)

    def _print_Min(self, expr):
        return self._min_max_call("std::min", expr)

    def _print_Max(self, expr):
        return self._min_max_call("std::max", expr)

    def _print_Abs(self, expr):
        return self._unary_call("std::abs", expr)

    # ---- opaque unary math functions ----

    def _print_OpaqueUnaryFn_cos(self, expr):
        return self._unary_call("std::cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr):
        return self._unary_call("std::cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr):
        return self._unary_call("std::acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr):
        return self._unary_call("std::sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr):
        return self._unary_call("std::sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr):
        return self._unary_call("std::asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr):
        return self._unary_call("std::tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr):
        return self._unary_call("std::tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr):
        return self._unary_call("std::atan", expr)

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        # Goes through the same single-argument check as the other unary printers.
        return self._unary_call("std::sqrt", expr)

    # ---- booleans ----

    def _print_BooleanTrue(self, expr):
        return "true"

    def _print_BooleanFalse(self, expr):
        return "false"
# `cexpr(expr)` renders a sympy expression as a C++ expression string. It is the
# bound `doprint` of one CppPrinter instance created at import time; useful for
# printing sympy symbols.
cexpr = CppPrinter().doprint
def cexpr_index(index):
    """Render ``index`` with ``cexpr`` and wrap the result in a cast to ``INDEX_TYPE``."""
    rendered = cexpr(index)
    return f"static_cast<{INDEX_TYPE}>({rendered})"
def value_to_cpp(value, cpp_type):
    """Return C++ source text for the Python scalar ``value`` typed as ``cpp_type``.

    Infinities and NaN are spelled through ``std::numeric_limits<cpp_type>``;
    booleans become ``true``/``false``; anything else is emitted as
    ``repr(value)``. Every non-special value is wrapped in
    ``static_cast<cpp_type>(...)``.
    """
    if value == -math.inf:
        return f"-std::numeric_limits<{cpp_type}>::infinity()"
    if value == math.inf:
        return f"std::numeric_limits<{cpp_type}>::infinity()"
    if isinstance(value, bool):
        # Python prints booleans capitalised; the emitted C++ uses lower-case.
        return f"static_cast<{cpp_type}>({str(value).lower()})"
    if math.isnan(value):
        return f"std::numeric_limits<{cpp_type}>::quiet_NaN()"
    return f"static_cast<{cpp_type}>({repr(value)})"
class LocalBufferScope:
    """
    Context manager for generating code over Inductor IR that uses function-local
    buffers. Such buffers are created during codegen to hold intermediate results
    (for example local accumulators). They are neither registered in `V.graph`
    nor passed as function arguments; instead, while the scope is active, the
    relevant codegen lookups are patched so that these buffers resolve without
    being exposed to the outside world.
    """

    def __init__(self, kernel: Kernel):
        self.kernel = kernel
        self.exit_stack = contextlib.ExitStack()
        self.local_buffers: Dict[str, ir.Buffer] = {}

    def __enter__(self):
        self.exit_stack.__enter__()

        # dtype lookups: answer from the local buffer when the name is local,
        # otherwise defer to the graph's own lookup.
        graph_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            if name in self.local_buffers:
                return self.local_buffers[name].get_dtype()
            return graph_get_dtype(name)

        self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))

        def passthrough_for_locals(fallback):
            # Local buffer names are returned unchanged; every other name is
            # forwarded to the unpatched callable.
            def lookup(name):
                if name in self.local_buffers:
                    return name
                return fallback(name)

            return lookup

        # `input` is patched before the original `output` is captured.
        for attr in ("input", "output"):
            unpatched = getattr(self.kernel.args, attr)
            self.exit_stack.enter_context(
                patch.object(self.kernel.args, attr, passthrough_for_locals(unpatched))
            )

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Forget the local buffers first, then undo the patches.
        self.local_buffers.clear()
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def add_local_buffer(self, buffer: ir.Buffer):
        buf_name = buffer.get_name()
        assert buf_name not in self.local_buffers
        self.local_buffers[buf_name] = buffer

    def localize_buffer(
        self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
    ) -> List[ir.IRNode]:
        """
        Return new IR nodes equivalent to `nodes` except that every load and store
        targeting `global_buf` is redirected to `local_buf`. Working on the
        smaller local buffer gives the fused loops better data locality.

        `local_buf` must already be registered in this scope, and accesses are
        assumed to be contiguous and in the same order as for `global_buf`.
        """
        assert local_buf.get_name() in self.local_buffers
        assert len(global_buf.get_size()) == len(local_buf.get_size())
        assert len(nodes) > 0

        class LocalizeBufferHandler(V.WrapperHandler):  # type: ignore[name-defined]
            def localize(self, name: str, index: sympy.Expr):
                # Accesses to any other buffer pass through untouched.
                if name != global_buf.get_name():
                    return name, index
                # Rebuild the index for the local layout from the index
                # symbols of the original expression, ordered by name.
                index_symbols = sorted(
                    (
                        sym
                        for sym in index.free_symbols
                        if symbol_is_type(sym, SymT.INDEX)
                    ),
                    key=str,
                )
                local_index = local_buf.layout.make_indexer()(index_symbols)
                return local_buf.get_name(), local_index

            def load(self, name: str, index: sympy.Expr):
                target, target_index = self.localize(name, index)
                return self._inner.load(target, target_index)

            def store(self, name, index, value, mode=None):
                target, target_index = self.localize(name, index)
                return self._inner.store(target, target_index, value, mode)

            def store_reduction(self, name, index, value):
                target, target_index = self.localize(name, index)
                return self._inner.store_reduction(target, target_index, value)

        def redirecting(inner_fn):
            # Run `inner_fn` with the localizing handler installed on top of
            # whichever ops handler is active at call time.
            def inner(index):
                handler = LocalizeBufferHandler(V.get_ops_handler())
                with V.set_ops_handler(handler):
                    return inner_fn(index)

            return inner

        def rewrap(node: ir.IRNode):
            is_computed = isinstance(node, ir.ComputedBuffer)
            loops = node.data if is_computed else node
            assert isinstance(loops, ir.Loops)
            # Shallow copy so the caller's node keeps its original inner_fn.
            loops_copy = copy.copy(loops)
            if is_computed:
                result = ir.ComputedBuffer(
                    node.get_name(), node.get_layout(), loops_copy
                )
            else:
                result = loops_copy  # type: ignore[assignment]
            loops_copy.inner_fn = redirecting(loops_copy.inner_fn)
            return result

        return [rewrap(node) for node in nodes]