Files
pytorch/torch/_inductor/codegen/common.py
PyTorch MergeBot dbba1d4bf5 Revert "Some minor type stub improvements (#118529)"
This reverts commit c978f38bd4aedeff4ee9ae693349217daea01412.

Reverted https://github.com/pytorch/pytorch/pull/118529 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118529#issuecomment-1922362331))
2024-02-01 22:18:36 +00:00

1433 lines
47 KiB
Python

import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from collections import namedtuple
from itertools import chain
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
import sympy
from sympy.printing.printer import Printer
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
from ..utils import (
DeferredLineBase,
do_bench,
free_symbol_startswith,
IndentedBuffer,
sympy_dot,
sympy_index_symbol,
sympy_subs,
unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from ..ir import TensorBox
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
def data_type_logger(msg):
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Data type propagation: %s", msg)
# Kernel-argument descriptors (tensor pointer args and size args).
TensorArg = namedtuple("TensorArg", "name buffer dtype check_alignment")
SizeArg = namedtuple("SizeArg", "name expr")

# A device's (Scheduling class, wrapper codegen class) pair.
DeviceCodegen = namedtuple("DeviceCodegen", "scheduling wrapper_codegen")
# Registry keyed by device string, filled in by register_backend_for_device().
device_codegens: Dict[str, DeviceCodegen] = {}

# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is driven by a device-specific Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes used for code generation, such as CppKernel and TritonKernel, typically form part
# of the logic for either Scheduling or WrapperCodeGen. The Scheduling and WrapperCodeGen interfaces
# therefore give the backend flexibility: a backend can implement these classes from scratch,
# or reuse them by extending and overriding as necessary. Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9


def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    """Register the Scheduling and wrapper-codegen classes for ``device``.

    Registering the same device again replaces the previous entry.
    """
    device_codegens[device] = DeviceCodegen(
        scheduling=device_scheduling, wrapper_codegen=device_wrapper_codegen
    )


def get_scheduling_for_device(device: str):
    """Return the Scheduling class registered for ``device``, or None."""
    entry = device_codegens.get(device)
    return None if entry is None else entry.scheduling


def get_wrapper_codegen_for_device(device: str):
    """Return the wrapper codegen class registered for ``device``, or None."""
    entry = device_codegens.get(device)
    return None if entry is None else entry.wrapper_codegen
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` extended with a contiguous-layout index expression.

    The extra term (the dot product of ``index_vars`` with the contiguous
    strides of ``sizes``) prevents the indexing from being reordered.
    """
    from ..ir import FlexibleLayout

    contiguous_term = sympy_dot(
        index_vars, FlexibleLayout.contiguous_strides(sizes)
    )
    return list(index) + [contiguous_term]
def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides implementation for ``device``.

    "cuda" gets CUDADeviceOpOverrides; any other device gets the base
    DeviceOpOverrides, whose hooks raise NotImplementedError.
    """
    assert isinstance(device, str)
    if device != "cuda":
        return DeviceOpOverrides()

    from .cuda.device_op_overrides import CUDADeviceOpOverrides

    return CUDADeviceOpOverrides()
@functools.lru_cache(None)
def boolean_ops():
    """Return the op names that DataTypePropagation treats as producing torch.bool.

    The tuple is built once and cached.
    """
    return tuple(
        "is_inf is_nan bitwise_xor logical_not signbit le lt ge gt eq ne".split()
    )
DTYPE_TO_COMPUTATION_DTYPE = {
torch.bfloat16: torch.float,
torch.float16: torch.float,
**{
dtype: dtype
for dtype in [
torch.bool,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
]
},
}
class DataTypePropagation:
    """Deduce dtypes for the nodes of a LoopBody's FX graphs.

    The deduced dtype of each node is stored in
    ``node.meta[OptimizationContext.key].dtype``.
    """

    def __init__(self, body) -> None:
        self.body = body
        # "root" is the body's main graph; subblocks are keyed by their names.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of ``node``'s non-placeholder inputs.

        Returns None if there are no such inputs, or if any of them has no
        dtype recorded yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Return the output dtype of the subgraph named by ``node.target``."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype produced by ``node``, or None if it cannot be deduced."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer output node if it only have 1 arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # These ops carry their result dtype as the last argument.
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # "index_expr" is already handled above, so only get_index lands here.
        if node.target == "get_index":
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Assign a dtype to every node of ``graph`` and return the output dtype."""
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Propagate dtypes through the root graph.

        Returns the dtype deduced for the root graph's output node, or None.
        """
        return self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Propagate dtypes through ``body`` and return the root output dtype."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Propagate dtypes through the LoopBody of a SchedulerNode."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        # Dispatch through ``cls`` so that subclass overrides are honoured.
        return cls.propagate_loopbody(node._body)
class ExprPrinter(Printer):
    """Base sympy printer that renders index expressions as source text.

    Language-specific pieces (such as FloorDiv) are left to subclasses.
    """

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is already atomic.

        The input is returned unchanged if it is:
        - a CSEVariable;
        - made only of ``[a-z0-9_.]`` characters (case-insensitive);
        - of the form ``(...)`` with no inner ``)``;
        - the empty string;
        - wrapped in one pair of parentheses that spans the whole text.
        """

        def all_in_parens(string):
            # True iff the opening "(" is closed by the very last character.
            if string[0] != "(" or len(string) < 2:
                return False
            depth = 1
            last_index = len(string) - 2
            for pos, ch in enumerate(string[1:]):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and pos != last_index:
                    # The first group closed early, e.g. "(a) + (b)".
                    return False
            assert depth == 0
            return True

        already_atomic = (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        )
        # don't put extra parens for strings that are already wrapped in parens
        if already_atomic or all_in_parens(string):
            return string
        return f"({string})"

    def _join_printed_args(self, separator, expr):
        # Print each argument of ``expr``, parenthesize it if needed, then join.
        return separator.join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return self._join_printed_args(f" {expr.rel_op} ", expr)

    def _print_Mul(self, expr):
        return self._join_printed_args("*", expr)

    def _print_Add(self, expr):
        return self._join_printed_args(" + ", expr)

    def _print_Mod(self, expr):
        return self._join_printed_args(" % ", expr)

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        # Printed exactly like FloorDiv.
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # In sympy, GreaterThan is ">=" while StrictGreaterThan is ">".
        return self._join_printed_args(" >= ", expr)

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"
class PythonPrinter(ExprPrinter):
    """Print sympy index expressions as Python source, using ``math`` helpers."""

    def _unary_call(self, func_name, expr):
        # Render a one-argument sympy function application as ``func_name(arg)``.
        assert len(expr.args) == 1
        return f"{func_name}({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        # Printed as ``(x // div) % mod``; the division is omitted when div is 1.
        x, div, mod = expr.args
        x, div, mod = (self.paren(self.doprint(arg)) for arg in (x, div, mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        numerator = self.paren(self.doprint(x))
        denominator = self.paren(self.doprint(div))
        return f"({numerator} // {denominator})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation! You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute. Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        if exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        printed_base = self._print(base)
        assert exp == int(exp), exp
        power = int(exp)
        if power == 0:
            return "1"
        # Integer powers are expanded into repeated multiplication.
        product = "*".join([self.paren(printed_base)] * abs(power))
        if power > 0:
            return product
        return "1/" + self.paren(product)

    def _print_floor(self, expr):
        return self._unary_call("math.floor", expr)

    def _print_ceiling(self, expr):
        return self._unary_call("math.ceil", expr)

    def _print_Abs(self, expr):
        return self._unary_call("abs", expr)

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        printed = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({printed})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        printed = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({printed})"

    def _print_cos(self, expr):
        return self._unary_call("math.cos", expr)

    def _print_cosh(self, expr):
        return self._unary_call("math.cosh", expr)

    def _print_acos(self, expr):
        return self._unary_call("math.acos", expr)

    def _print_sin(self, expr):
        return self._unary_call("math.sin", expr)

    def _print_sinh(self, expr):
        return self._unary_call("math.sinh", expr)

    def _print_asin(self, expr):
        return self._unary_call("math.asin", expr)

    def _print_tan(self, expr):
        return self._unary_call("math.tan", expr)

    def _print_tanh(self, expr):
        return self._unary_call("math.tanh", expr)

    def _print_atan(self, expr):
        return self._unary_call("math.atan", expr)

    def _print_Round(self, expr):
        return self._unary_call("round", expr)

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
    """Default string-building implementations for a set of common ops.

    Any attribute not defined here is looked up on the wrapped ``parent``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Reached only when normal lookup fails: delegate to the parent handler.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        operand = ExprPrinter.paren(x)
        return f"~{operand}"

    @staticmethod
    def logical_not(a):
        operand = ExprPrinter.paren(a)
        return f"{operand} == 0"

    @staticmethod
    def bitwise_and(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} & {rhs}"

    @staticmethod
    def bitwise_or(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} | {rhs}"

    @staticmethod
    def bitwise_xor(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} ^ {rhs}"

    @staticmethod
    def bitwise_left_shift(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} << {rhs}"

    @staticmethod
    def bitwise_right_shift(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} >> {rhs}"

    @staticmethod
    def remainder(a, b):
        # When ``r`` is non-zero and its sign differs from ``b``'s, add ``b`` so
        # the result follows the divisor's sign (assumes ops.mod truncates --
        # TODO confirm).
        r = ops.mod(a, b)
        needs_fixup = f"(({r} != 0) & (({r} < 0) != ({b} < 0)))"
        return ops.where(needs_fixup, ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Static check that OpOverrides conforms to OpsHandler[str].

    The annotations let a type checker verify the protocol. At runtime the
    function just returns its argument.
    """
    return h
class DeviceOpOverrides:
    """Per-device code-generation hooks.

    Every hook raises NotImplementedError here; device-specific subclasses
    (e.g. the CUDA one returned by get_device_op_overrides) override them.
    """

    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError
class DeferredLine(DeferredLineBase):
    """A line that renders to nothing once its buffer ``name`` is removed.

    The line is dropped if ``name`` is in the removed or inplaced-to-remove
    sets of either V.graph or V.kernel at render time.
    """

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in members for members in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        return DeferredLine(self.name, line)
class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indent() also writes C-style ``{`` / ``}`` lines."""

    def indent(self, offset=1):
        """Return a context manager that shifts indentation by ``offset`` levels.

        With a positive offset, ``{`` is written and the indent increased on
        entry; on exit the indent is decreased and ``}`` is written. A negative
        offset does the reverse. Exit work is skipped if the body raises.
        """

        def open_braces(count):
            for _ in range(count):
                self.writeline("{")
                self._indent += 1

        def close_braces(count):
            for _ in range(count):
                self._indent -= 1
                self.writeline("}")

        @contextlib.contextmanager
        def ctx():
            open_braces(offset)
            close_braces(-offset)
            yield
            open_braces(-offset)
            close_braces(offset)

        return ctx()
# A kernel argument shared by several buffers updated in place: ``inner_name``
# is the kernel-side argument name, ``other_names`` the outer buffer names
# that alias it.
InplacedBuffer = NamedTuple(
    "InplacedBuffer",
    [("inner_name", str), ("other_names", List[str])],
)
class KernelArgs:
    """Map outer (graph-level) buffers and sizes to a kernel's argument names.

    Inner names are allocated on first use: ``in_ptrN``/``seedN`` for inputs,
    ``out_ptrN`` for outputs, ``in_out_ptrN`` for in-place buffers and
    ``ksN`` for size variables.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        # Return the inner name for ``name``, allocating ``{prefix}{N}`` if new.
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = dict()
        self.output_buffers = dict()
        self.inplace_buffers = dict()
        # Only create a fresh dict when none was passed: a caller-supplied
        # *empty* dict must be used (and shared) too, which ``or`` would break.
        self.sizevars = dict() if sizevars is None else sizevars

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        # Entries whose value is a string starting with "REMOVED" count as removed.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the inner argument name used to read buffer ``name``.

        A buffer that is already an output or in-place argument of this
        kernel is read through that existing argument.
        """
        if V.graph.scheduler:
            # Use the scheduler's real name for this buffer (mutation_real_name).
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Return the inner argument name used to write buffer ``name``."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share an ``in_out_ptr`` argument with ``input_name``."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            # Extend the existing in-place group with the new output.
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def seed_offset(self, name, value):
        """Return the size-variable name for ``value``, registering it as ``name``.

        If ``name`` is already taken by another value, a numeric suffix is
        appended to keep it unique.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the inner name for size variable ``name`` ("seed" maps to itself)."""
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        """Iterate the outer names of inputs, outputs and size variables."""
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Return ``(arg_defs, call_args, arg_types)`` for a C++ kernel signature.

        Order: in-place buffers, then inputs, then outputs, then sizes.
        Removed buffers are skipped.
        """
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Return ``(arg_defs, call_args, precompile_args)`` for a Python-launched kernel.

        Order: in-place buffers, then inputs and outputs, then sizes.
        Removed buffers are skipped.
        """
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            precompile_args.append(
                TensorArg(
                    inplaced.inner_name,
                    inplaced.other_names[-1],
                    V.graph.get_dtype(inplaced.other_names[-1]),
                    True,
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(inner, outer, V.graph.get_dtype(outer), True)
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        return arg_defs, call_args, precompile_args

    def aliases(self):
        """Yield ``(input/output inner name, in_out inner name)`` alias pairs."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        """True if ``name`` is absent or removed in both output and in-place maps."""

        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data? Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
class CSEVariable:
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
To do so, the backends can simply overload `Kernel.create_cse_var`
The "CSEVariable.update_on_args" method gives you a hook for annotations
See example of TritonCSEVariable in triton.py
"""
def __init__(self, name, bounds: ValueRanges):
assert isinstance(bounds, ValueRanges)
self.name = name
self.bounds = bounds
def __str__(self):
return self.name
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other) -> bool:
return type(other) == type(self) and other.name == self.name
def update_on_args(self, name, args, kwargs):
pass
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs whose pointer call arguments are spelled as C++ expressions."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if not config.aot_inductor.abi_compatible:
            # Cast the tensor's data pointer to the kernel's element pointer type.
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        return "{}".format(size)
class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        # expression text -> variable holding its value
        self.cache = {}
        self.name_prefix = name_prefix
        # Test with ``is None`` rather than ``or``: clone() passes the parent's
        # containers so they are shared, and an *empty* container must be
        # shared too instead of being replaced by a fresh one.
        self.store_cache = {} if store_cache is None else store_cache
        self.reduction_cache = {} if reduction_cache is None else reduction_cache
        self.iter_buffer_ids = (
            itertools.count() if iter_buffers is None else iter_buffers
        )
        self.invalidated_stores = set()
        self.varname_map = {} if varname_map is None else varname_map

    def invalidate(self, keep_vars: Set[str]):
        """Drop cached stores and expressions whose value is not in ``keep_vars``.

        The names of dropped stores are recorded in ``invalidated_stores``.
        """
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Return a CSE with an empty expression cache.

        The clone shares this CSE's name counter, store cache and varname map.
        """
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the variable holding ``expr``, emitting code only on first sight.

        - An existing CSEVariable is returned as-is, with its bounds tightened.
        - A previously cached expression returns the cached variable (bounds
          tightened) without writing anything.
        - Otherwise a new variable is allocated (when ``assignment``) and, if
          ``write``, the code is written to ``buffer``.
        - With ``assignment=False`` the expression is written as a bare
          statement and None is returned.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable:
        """Allocate a fresh ``{name_prefix}{N}`` variable via the current kernel."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
class IndirectAssertLine(DeferredLineBase):
    """Deferred bounds-check line for an indirect-indexing variable.

    When rendered, it looks up the size currently recorded in ``size_map``
    for ``(var, mask)``. It emits checks only for the sides (``0 <= var`` and
    ``var < size``) that ``var``'s bounds cannot prove.
    """

    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]

        # A side needs a runtime check unless the known bounds prove it.
        needs_lower = (self.var.bounds.lower >= 0) != sympy.true
        needs_upper = (self.var.bounds.upper < size) != sympy.true

        if needs_lower and needs_upper:
            # The conditions need to be in parens because of Python's operator
            # precedence. It'd be less error-prone to use and/or/not, which is
            # supported by triton.
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif needs_lower:
            cond = f"0 <= {self.var}"
            cond_print = cond
        elif needs_upper:
            cond = f"{self.var} < {size_str}"
            cond_print = cond
        else:
            # Both sides proven: no assertion needed.
            return None

        if self.mask:
            # Wherever the mask is false the condition is satisfied.
            cond = f"({cond}) | ~{self.mask}"

        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )
class CodeGen:
    """Base for code generators that own a ``contextlib.ExitStack``.

    Entering the generator enters the stack; exiting unwinds it. Exceptions
    are never suppressed, because ``__exit__`` returns None.
    """

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        stack = self.exit_stack
        stack.__exit__(exc_type, exc_val, exc_tb)
class Kernel(CodeGen):
newvar_prefix = ""
suffix = ""
overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
# TODO: these look dead, but with all the getattr it's hard to tell...
load_format: None = None
store_format: None = None
def __init__(self, args=None, increase_kernel_count=True):
    super().__init__()
    if increase_kernel_count:
        metrics.generated_kernel_count += 1
    self.args = args or KernelArgs()

    # Output sections, each accumulated separately.
    self.loads = IndentedBuffer()
    self.compute = IndentedBuffer()
    self.stores = IndentedBuffer()
    self.cse: CSE = CSE(self.newvar_prefix, self.suffix)

    self.must_keep_buffers = set()
    self.store_buffer_names = set()
    self._load_mask = None

    # set in set_current_node
    self.current_node = None
    self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None

    # Upper bounds for indirect_indexing and their str representation
    # NB: None, None is never stored in map, but it is the assumed
    # "not set" value for the dict
    self.indirect_max_sizes: Dict[
        Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
    ] = {}

    self.removed_buffers = set()
    self.inplaced_to_remove = set()

    # key: the buffer to write
    # value: the buffer to read and whose memory can be reused for
    # the buffer specified by key
    self.inplace_update_buffers = {}

    # Set minimum number of elements processed per thread.
    self.min_elem_per_thread = 1
    self.kernel_name = None
@contextlib.contextmanager
def set_current_node(self, node):
    """Make ``node`` the current node for the duration of the ``with`` body.

    Also loads ``node``'s bounds into ``node_to_bounds``. ``current_node`` is
    restored on exit. NOTE(review): ``node_to_bounds`` is not restored on
    exit -- confirm callers rely on that.
    """
    # Compute the bounds before mutating any state, so a failure here does not
    # leave ``current_node`` pointing at ``node``.
    bounds = node._body.bounds().get_bounds()
    prior = self.current_node
    self.current_node = node
    self.node_to_bounds = bounds
    try:
        yield
    finally:
        self.current_node = prior
@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
    """Temporarily redirect the loads/compute/stores sections and clone the CSE.

    ``cb`` defaults to ``lb``. ``sb`` is used as-is, even when it is None.
    The previous sections and CSE are restored on exit.
    """
    saved = (self.loads, self.compute, self.stores, self.cse)
    self.loads = lb
    self.compute = lb if cb is None else cb
    self.stores = sb
    self.cse = saved[3].clone()
    try:
        yield
    finally:
        self.loads, self.compute, self.stores, self.cse = saved
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
    """Emit a load of buffer ``name`` at ``index``; subclasses must override."""
    raise NotImplementedError
def indirect_load(self, name: str, index: sympy.Expr):
    """A load that depends on an index we have read."""
    saved_loads = self.loads
    # put the load in the compute section as it might have deps
    self.loads = self.compute
    try:
        return self.load(name, index)
    finally:
        self.loads = saved_loads
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
    """Emit a store of a reduction result; subclasses must override."""
    raise NotImplementedError
def store(
    self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
    """Emit a store of ``value`` into buffer ``name``; subclasses must override."""
    raise NotImplementedError
def reduction(
    self,
    dtype: torch.dtype,
    src_dtype: torch.dtype,
    reduction_type: ReductionType,
    value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
    """Emit a reduction of ``value``; subclasses must override."""
    raise NotImplementedError
def scan(
    self,
    dtype: torch.dtype,
    combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
    value: CSEVariable,
    init: int,
) -> CSEVariable:
    """Emit a scan of ``value`` using ``combine_fn``; subclasses must override."""
    raise NotImplementedError
def bucketize(
    self,
    values: CSEVariable,
    offsets_name: str,
    offsets_size: sympy.Expr,
    indexing_dtype: torch.dtype,
    right: bool,
) -> CSEVariable:
    """
    See [Note: Inductor bucketize op]

    Subclasses must override.
    """
    raise NotImplementedError
@property
def assert_function(self) -> str:
    """Name of the assert function used in generated code; subclasses must override."""
    raise NotImplementedError
def index_to_str(self, index: sympy.Expr) -> str:
    """Render ``index`` as source text; subclasses must override."""
    raise NotImplementedError
def __enter__(self):
# TODO: hoist this to top level
class CSEProxy:
self.name = "CSEProxy"
@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
# TritonTemplateKernel has no current_node
buf_bounds = ValueRanges.unknown()
if hasattr(V.interpreter, "current_node"):
fx_node = V.interpreter.current_node
assert isinstance(self.node_to_bounds, dict)
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
csevar = self.cse.generate(
self.compute,
getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type]
bounds=buf_bounds,
)
csevar.update_on_args(name, args, kwargs)
return csevar
return inner
@staticmethod
def indirect_indexing(
var: CSEVariable, size: sympy.Expr, check: bool = True
):
# Skip CSE since this doesn't return an expression
if var.bounds.lower < 0: # type: ignore[operator]
new_bounds = ValueRanges.unknown()
if var.bounds != ValueRanges.unknown() and isinstance(
size, sympy.Number
):
# Take the negative part of the bound and add size to it
# Then take union of that and the positive part
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
neg = var.bounds & ValueRanges(-sympy.oo, -1)
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
# We don't have a good way of representing the empty range
if var.bounds.upper >= 0: # type: ignore[operator]
pos = var.bounds & ValueRanges(0, sympy.oo)
new_bounds = new_bounds | pos
stm = ops.add(var, self.rename_indexing(size))
# Mixed negative and non-negative
if var.bounds.upper >= 0: # type: ignore[operator]
lt = ops.lt(var, "0")
stm = ops.where(lt, stm, var)
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
new_var.update_on_args("index_wrap", (var,), {})
var = new_var
if self.generate_assert(check):
mask = self.load_mask(var)
# An assertion line may have been written already, if so just
# update the max size.
map_key = (var, mask)
existing_size, _ = self.indirect_max_sizes.get(
map_key, (None, None)
)
if existing_size is not None:
size = sympy.Min(size, existing_size)
else:
line = (
'{assert_fn}({cond}, "index out of bounds: {cond_print}")'
)
self.compute.writeline(
IndirectAssertLine(
line,
self.assert_function,
var,
mask,
self.indirect_max_sizes,
)
)
self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
return sympy_index_symbol(str(var))
@staticmethod
def load(name: str, index: sympy.Expr) -> CSEVariable:
if name in self.cse.invalidated_stores:
# A load from an invalidated store requires us to
# keep the actual buffer around
V.kernel.must_keep_buffers.add(name)
if free_symbol_startswith(index, "tmp"):
return self.indirect_load(name, index)
store_cache = self.cse.store_cache
if name in store_cache:
return store_cache[name]
return self.load(name, index)
@staticmethod
def store(
name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
self.store_buffer_names.add(name)
if mode is None:
self.cse.store_cache[name] = value
if self.current_node:
for other_name in self.current_node.get_mutations():
self.cse.store_cache[other_name] = value
if name not in V.graph.removed_buffers:
return self.store(name, index, value, mode=mode)
else:
return None # type: ignore[return-value]
@staticmethod
def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
    """
    Record the store of a reduction result and forward it to the kernel.

    ``self`` is the enclosing kernel (closure variable). ``name`` is added
    to ``self.store_buffer_names``. ``value`` is cached unconditionally
    under ``name`` and under every buffer reported by the current node's
    ``get_mutations()``. Nothing is emitted (None is returned) when ``name``
    is in ``V.graph.removed_buffers``.
    """
    self.store_buffer_names.add(name)
    cache = self.cse.store_cache
    cache[name] = value
    node = self.current_node
    if node:
        for aliased_name in node.get_mutations():
            cache[aliased_name] = value
    if name in V.graph.removed_buffers:
        return None
    return self.store_reduction(name, index, value)
@staticmethod
def reduction(
    dtype: torch.dtype,
    src_dtype: torch.dtype,
    reduction_type: ReductionType,
    value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
    """
    Forward a reduction unchanged to the enclosing kernel's ``reduction``.

    ``self`` is the enclosing kernel (closure variable). This proxy layer
    does no caching of its own for reductions.
    """
    result = self.reduction(dtype, src_dtype, reduction_type, value)
    return result
@staticmethod
def scan(
    dtype: torch.dtype,
    combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
    value: CSEVariable,
    init: int,
) -> CSEVariable:
    """
    Forward a scan unchanged to the enclosing kernel's ``scan``.

    ``self`` is the enclosing kernel (closure variable). ``combine_fn`` is
    passed through untouched.
    """
    result = self.scan(dtype, combine_fn, value, init)
    return result
@staticmethod
def bucketize(
    values: CSEVariable,
    offsets_name: str,
    offsets_size: sympy.Expr,
    indexing_dtype: torch.dtype,
    right: bool,
) -> CSEVariable:
    """
    [Note: Inductor bucketize op]

    Given values (tensor) and offsets_name (reference to the name of a 1D
    tensor), calculate the bucket that each value belongs to.

    e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
    return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

    When right == False, bucket i refers to range (offsets[i-1], offsets[i]].
    When right == True,  bucket i refers to range [offsets[i-1], offsets[i]).
    For the end buckets, read offsets[i-1] as -inf when i == 0 and offsets[i]
    as +inf when i == len(offsets). This is not Python negative indexing.
    This matches torch.bucketize's ``right`` convention. Checking against the
    example (right=True): 0 lands in bucket 1 = [offsets[0], offsets[1]) =
    [0, 4), and 4 lands in bucket 3 = [offsets[2], offsets[3]) = [4, 8).

    Offsets must be non-decreasing or the result is undefined.

    Args:
        values: the values to bucketize.
        offsets_name: name of the offsets buffer.
        offsets_size: size of the offsets buffer. Presumably its length;
            confirm against callers.
        indexing_dtype: dtype used for the generated indexing. Presumably
            also the result dtype; confirm.
        right: selects which end of each bucket range is closed (see above).

    ``self`` is the enclosing kernel (closure variable). This proxy forwards
    all arguments unchanged to its ``bucketize``.
    """
    return self.bucketize(
        values, offsets_name, offsets_size, indexing_dtype, right
    )
# Use mypy to check protocol implemented correctly
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
    """Exist purely for static type checking.

    Binding ``h`` to an ``OpsHandler[CSEVariable]``-annotated name makes mypy
    verify that CSEProxy structurally satisfies the OpsHandler protocol.
    """
    as_handler: OpsHandler[CSEVariable] = h
    return as_handler
super().__enter__()
assert self.overrides
parent_handler = self.overrides(V.get_ops_handler())
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Leave the kernel context.

    First asks the graph's scheduler to remove kernel-local buffers, then
    defers to the base class's ``__exit__``. V.graph.scheduler can be None
    when codegening triton template kernels; the cleanup step is skipped in
    that case.
    """
    scheduler = V.graph.scheduler
    if scheduler:
        scheduler.remove_kernel_local_buffers()
    super().__exit__(exc_type, exc_val, exc_tb)
def generate_assert(self, check):
    """
    Decide (truthy/falsy) whether an index-bounds assert should be emitted.

    An assert is wanted when the caller asked for one (``check``) or when
    ``config.debug_index_asserts`` forces it. Even then, it is only emitted
    when ``config.assert_indirect_indexing`` is enabled. The raw operand
    value is returned, not a coerced bool.
    """
    wants_assert = check or config.debug_index_asserts
    return wants_assert and config.assert_indirect_indexing
def load_mask(self, var) -> str:
    """
    Return the mask expression guarding loads that go through ``var``.

    The base kernel never masks, so this always returns the empty string.
    Only the triton kernel requires a mask (presumably it overrides this).
    """
    del var  # unused by the base implementation
    return ""
def rename_indexing(self, index) -> sympy.Expr:
    """
    Simplify ``index`` and substitute kernel-argument names for its symbols.

    A list/tuple is handled elementwise and returned as a list. For a single
    expression, each free symbol whose name starts with "s", "u" or "ps", or
    starts with "i" but not "idx", is replaced by ``self.args.size(sym)``.
    That call adds the necessary kernel arg and yields its name. All other
    symbols are left untouched.
    """
    if isinstance(index, (list, tuple)):
        return list(map(self.rename_indexing, index))  # type: ignore[return-value]

    simplified = V.graph.sizevars.simplify(index)

    def _needs_kernel_arg(sym) -> bool:
        # True for symbols that must be routed through self.args.size().
        sym_name = sym.name
        if sym_name.startswith(("s", "u", "ps")):
            return True
        return sym_name.startswith("i") and not sym_name.startswith("idx")

    # Visit symbols in name order, presumably so that any args registered by
    # self.args.size() are added in a deterministic order.
    ordered_symbols = sorted(
        simplified.free_symbols, key=operator.attrgetter("name")
    )
    replacements = {
        sym: self.args.size(sym)
        for sym in ordered_symbols
        if _needs_kernel_arg(sym)
    }
    return sympy_subs(simplified, replacements)
def create_cse_var(self, *args, **kwargs):
    """
    Build a CSE variable for this kernel.

    The base implementation forwards all arguments to ``CSEVariable``.
    Backends presumably override this to return their own variable
    subclass; confirm against subclasses.
    """
    variable_cls = CSEVariable
    return variable_cls(*args, **kwargs)
@dataclasses.dataclass
class OptimizationContext:
    """Bag of optimization hints; every field defaults to "no hint"."""

    # Class-level constant (not a dataclass field). Presumably the attribute
    # or metadata key under which a context is attached; confirm callers.
    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = dataclasses.field(default=False)
    # Optional dtype hint; None means unset (meaning defined by callers).
    dtype: Optional[torch.dtype] = dataclasses.field(default=None)
    # Op name hint; the empty string means unset.
    ops_name: str = dataclasses.field(default="")
    # Load uint8 value as float32
    is_load_uint8_as_float: bool = dataclasses.field(default=False)
@functools.lru_cache(None)
def jinja2_env():
    """
    Return the shared jinja2 Environment for kernel templates.

    Returns None when jinja2 is not installed. The environment uses
    ``jinja2.StrictUndefined``, so rendering a template that references an
    undefined variable raises instead of rendering an empty string.
    ``lru_cache`` on this zero-argument function means the import probe and
    construction happen once, and the same result (including None) is
    returned on every call.
    """
    try:
        import jinja2
    except ImportError:
        # jinja2 is optional; KernelTemplate._template_from_string maps a
        # None environment to "no template".
        return None
    # Built outside the try so that an ImportError raised while constructing
    # the Environment is not misreported as "jinja2 is not installed".
    return jinja2.Environment(
        undefined=jinja2.StrictUndefined,
    )
class ChoiceCaller:
    """
    One candidate implementation considered during autotuning
    (autotune_process.py).

    Autotuning first calls ``benchmark()`` to time the candidate. If this
    candidate is selected, ``output_node()`` is then called to obtain the
    output node. Concrete subclasses: TritonTemplateCaller,
    CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        """Time ``self.to_callable()(*args, out=out)`` with ``do_bench``."""
        run = self.to_callable()

        def _invoke():
            return run(*args, out=out)

        return do_bench(_invoke)

    def call_name(self) -> str:
        """Abstract hook; must be implemented by subclasses."""
        raise NotImplementedError()

    def to_callable(self):
        """Abstract hook: return the callable that ``benchmark`` times."""
        raise NotImplementedError()

    def hash_key(self) -> str:
        """Abstract hook; must be implemented by subclasses."""
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":
        """Abstract hook: build the output node once this choice is selected."""
        raise NotImplementedError()
class KernelTemplate:
"""
Base class for defining kernel templates.
Children classes: TritonTemplate, CUDATemplate
"""
@staticmethod
def _template_from_string(source):
env = jinja2_env()
if env is not None:
return env.from_string(source)
return None
@staticmethod
def _fake_get_dtype(fake_out):
_get_dtype_real = V.graph.get_dtype
def get_dtype(name):
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
def __init__(self, name: str):
self.name = name
def maybe_append_choice(self, choices, **kwargs):
"""
Maybe generates a new ChoiceCaller and appends it into existing choices.
choices: A list of ChoiceCallers.
kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
"""
try:
choices.append(self.generate(**kwargs))
except NotImplementedError:
pass
def generate(self, **kwargs) -> ChoiceCaller:
"""
Generates a ChoiceCaller instance from the given arguments.
"""
raise NotImplementedError()