import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
from typing import Dict, List, Set

import sympy

import torch

import torch._logging
from ..._dynamo import config as dynamo_config
from .. import config, ir, scheduler
from ..codecache import get_code_path
from ..ir import ReductionHint
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..utils import (
    get_fused_kernel_name,
    get_kernel_metadata,
    instance_descriptor,
    next_power_of_2,
    sympy_product,
    sympy_subs,
    sympy_symbol,
)
from ..virtualized import ops, V

from .common import (
    CSEVariable,
    DeferredLine,
    free_symbol_startswith,
    IndentedBuffer,
    index_prevent_reordering,
    Kernel,
    OpOverrides,
    PythonPrinter,
    SizeArg,
    TensorArg,
)

log = logging.getLogger(__name__)
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def signature_of(arg):
    from triton.runtime.jit import JITFunction

    if isinstance(arg, TensorArg):
        tye = JITFunction._type_of(arg.dtype)
        if V.graph.is_unspec_arg(arg.buffer):
            # had unwrapped 0d tensor as scalar
            new_tye = tye.lstrip("*")
            if new_tye in ["fp16", "bf16"]:
                return "fp32"
            else:
                return new_tye
        else:
            return tye
    if isinstance(arg, SizeArg):
        return JITFunction._key_of(V.graph.sizevars.size_hint(arg.expr))
    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")


def config_of(args):
    from ..compile_fx import ALIGNMENT

    def is_aligned(x):
        if isinstance(x, TensorArg):
            return x.buffer not in V.graph.unaligned_buffers
        if isinstance(x, SizeArg):
            return V.graph.sizevars.maybe_guard_multiple_of(x.expr, ALIGNMENT)
        raise NotImplementedError(f"unhandled {type(x)}: {x}")

    divisible_by_16 = [i for i, arg in enumerate(args) if is_aligned(arg)]
    return instance_descriptor(tuple(divisible_by_16), ())
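
# Illustrative note (hypothetical args, not executed here): if positions 0
# and 1 of `args` are aligned TensorArgs and position 2 is a SizeArg that is
# not a known multiple of ALIGNMENT, config_of(args) returns
# instance_descriptor((0, 1), ()).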


class TritonPrinter(PythonPrinter):
    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"tl.libdevice.floor({self.paren(self._print(expr.args[0]))})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"tl.libdevice.ceil({self.paren(self._print(expr.args[0]))})"


texpr = TritonPrinter().doprint
pexpr = PythonPrinter().doprint


def triton_compute_type(dtype):
    triton_type_name = str(dtype).split(".")[-1]
    if triton_type_name == "bool":
        triton_type_name = "int1"
    if triton_type_name in ("float16", "bfloat16"):
        # float16 math is done in float32 inside the kernel
        triton_type_name = "float32"
    return f"tl.{triton_type_name}"


def triton_constant(value):
    if value == float("inf"):
        return 'float("inf")'
    elif value == float("-inf"):
        return 'float("-inf")'
    elif math.isnan(value):
        return 'float("nan")'
    return repr(value)
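
# For example: triton_constant(float("inf")) == 'float("inf")' and
# triton_constant(2) == "2", so the result can be spliced directly into
# generated Triton source.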


class TritonCSEVariable(CSEVariable):
    def __init__(self, name):
        super().__init__(name)
        # We'll use this to track which masks the variable needs when used for indirect indexing
        self.mask_vars: Set[str] = set()

    def update_on_args(self, name, args, kwargs):
        # When making a variable that is going to be used in indirect indexing,
        # if a where clause is used it should mean that the result is always a
        # valid index, so you shouldn't include any of the dependent variables
        # in the resulting load mask
        if name == "where":
            return
        for arg in args:
            if isinstance(arg, TritonCSEVariable):
                self.mask_vars.update(arg.mask_vars)


class TritonOverrides(OpOverrides):
    """Map element-wise ops to Triton"""

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        if dtype == torch.bool:
            return f"({x} != 0)"
        elif dtype == torch.uint8:
            # to work around llvm uint conversion semantics
            # that produces 0's for negative values
            return f"{x}.to(tl.int8).to(tl.uint8)"
        return f"{x}.to({triton_compute_type(dtype)})"

    @staticmethod
    def constant(value, dtype):
        type_ = torch._prims_common.dtype_to_type(dtype)
        return triton_constant(type_(value))

    @staticmethod
    def abs(x):
        return f"tl.abs({x})"

    @staticmethod
    def libdevice_abs(x):
        return f"tl.libdevice.abs({x})"

    @staticmethod
    def exp(x):
        return f"tl.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"tl.libdevice.exp({x})"

    @staticmethod
    def exp2(x):
        return f"tl.libdevice.exp2({x})"

    @staticmethod
    def expm1(x):
        return f"tl.libdevice.expm1({x})"

    @staticmethod
    def sqrt(x):
        return f"tl.sqrt({x})"

    @staticmethod
    def libdevice_sqrt(x):
        return f"tl.libdevice.sqrt({x})"

    @staticmethod
    def relu(x):
        return ops.maximum("0", x)

    @staticmethod
    def minimum(a, b):
        return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))"

    @staticmethod
    def maximum(a, b):
        return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))"

    @staticmethod
    def where(a, b, c):
        return f"tl.where({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"tl.cos({x})"

    @staticmethod
    def libdevice_cos(x):
        return f"tl.libdevice.cos({x})"

    @staticmethod
    def sin(x):
        return f"tl.sin({x})"

    @staticmethod
    def libdevice_sin(x):
        return f"tl.libdevice.sin({x})"

    @staticmethod
    def index_expr(expr, dtype):
        return V.kernel.indexing(expr)[0]

    @staticmethod
    def masked(mask, body, other):
        with V.kernel.mask_loads(mask) as new_mask:
            result = body()
        return ops.where(new_mask, result, triton_constant(other))

    @staticmethod
    def lgamma(x):
        return f"tl.libdevice.lgamma({x})"

    @staticmethod
    def erf(x):
        return f"tl.libdevice.erf({x})"

    @staticmethod
    def cosh(x):
        return f"tl.libdevice.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"tl.libdevice.sinh({x})"

    @staticmethod
    def acos(x):
        return f"tl.libdevice.acos({x})"

    @staticmethod
    def acosh(x):
        return f"tl.libdevice.acosh({x})"

    @staticmethod
    def asin(x):
        return f"tl.libdevice.asin({x})"

    @staticmethod
    def asinh(x):
        return f"tl.libdevice.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"tl.libdevice.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"tl.libdevice.atan({x})"

    @staticmethod
    def atanh(x):
        return f"tl.libdevice.atanh({x})"

    @staticmethod
    def copysign(x, y):
        return f"tl.libdevice.copysign({x}, {y})"

    @staticmethod
    def erfc(x):
        return f"tl.libdevice.erfc({x})"

    @staticmethod
    def hypot(x, y):
        return f"tl.libdevice.hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"tl.libdevice.log10({x})"

    @staticmethod
    def nextafter(x, y):
        return f"tl.libdevice.nextafter({x}, {y})"

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand op
        return f"tl.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
        return f"tl.randn({seed}, {offset})"

    @staticmethod
    def rsqrt(x):
        return f"tl.libdevice.rsqrt({x})"

    @staticmethod
    def log1p(x):
        return f"tl.libdevice.log1p({x})"

    @staticmethod
    def tan(x):
        return f"tl.libdevice.tan({x})"

    @staticmethod
    def tanh(x):
        return f"tl.libdevice.tanh({x})"

    @staticmethod
    def sigmoid(x):
        return f"tl.sigmoid({x})"

    @staticmethod
    def libdevice_sigmoid(x):
        return f"1/(1 + tl.libdevice.exp(-({x})))"

    @staticmethod
    def signbit(x):
        # XX: This is wrong for the value -0.0 in floating point
        return f"tl.libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"

    @staticmethod
    def fmod(a, b):
        return f"tl.libdevice.fmod({a}, {b})"

    @staticmethod
    def pow(a, b):
        return f"tl.libdevice.pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"tl.log({x})"

    @staticmethod
    def libdevice_log(x):
        return f"tl.libdevice.log({x})"

    @staticmethod
    def isinf(x):
        return f"tl.libdevice.isinf({x})"

    @staticmethod
    def isnan(x):
        return f"tl.libdevice.isnan({x})"

    @staticmethod
    def round(x):
        return f"tl.libdevice.nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"tl.libdevice.floor({x})"

    @staticmethod
    def floordiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Similar to div_floor_kernel_cuda in pytorch core.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        quot = f"{a} // {b}"
        rem = f"{a} % {b}"
        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"

    @staticmethod
    def trunc(x):
        return f"tl.libdevice.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        return f"{a} // {b}"

    @staticmethod
    def ceil(x):
        return f"tl.libdevice.ceil({x})"


@dataclasses.dataclass
class IterationRanges:
    """
    Each range tree represents multiple sets of iteration indexing
    in a single tiled dimension in the output kernel.

    If you have two loop ranges, one (4, 3, 2) and another (4, 6),
    then the range tree will be:
            4 (i0)
        3 (i1)    6 (i3)
        2 (i2)
    where i0 is shared between both loops but is then split into
    different indexing vars.  All loop ranges must iterate over the
    same number of elements (here 4 * 3 * 2 == 4 * 6 == 24).
    """

    def __init__(
        self,
        name: str,
        var_list: List[sympy.Symbol],
        var_ranges: Dict[sympy.Symbol, sympy.Expr],
        numel: sympy.Expr,
        prefix: str,
        *,
        kernel: "Kernel",
        divisor=sympy.Integer(1),
        length=sympy.Integer(1),
    ):
        super().__init__()
        self.name = name
        self.var_list = var_list
        self.var_ranges = var_ranges
        self.numel = numel
        self.prefix = prefix
        self.divisor = divisor
        self.length = length
        self.kernel = kernel

    def is_loop(self):
        return self.prefix == "r" and not self.kernel.persistent_reduction


class IterationRangesRoot(IterationRanges):
    def __init__(
        self,
        name: str,
        numel: sympy.Expr,
        prefix: str,
        index: int,
        kernel: "Kernel",
        pid_cache=None,
    ):
        if pid_cache is None:
            pid_cache = {}
        super().__init__(
            name=name,
            var_list=[],
            var_ranges={},
            numel=numel,
            prefix=prefix,
            kernel=kernel,
        )
        self.index = index
        # Store all the nodes in one flat list
        self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
        # This is for re-ordering program ID in triton mm template
        # pid_cache["tl.program_id(0)"] = pid_m
        self.pid_cache: Dict[str, str] = pid_cache

    def cache_clear(self):
        for node in self.nodes.values():
            node.cache_clear()

    def lookup(self, divisor, length):
        """
        Lookup a given IterationRangesEntry, creating it if needed
        """
        if V.graph.sizevars.maybe_guard_equals(divisor * length, self.numel):
            expr = ir.FloorDiv(sympy_symbol(f"{self.prefix}index"), divisor)
        else:
            expr = ir.ModularIndexing(
                sympy_symbol(f"{self.prefix}index"), divisor, length
            )

        if expr not in self.nodes:
            node = IterationRangesEntry(
                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
                divisor,
                length,
                expr,
                self,
            )
            V.kernel.range_tree_nodes[node.symbol()] = node
            self.var_list.append(node.symbol())
            self.var_ranges[node.symbol()] = length
            self.nodes[expr] = node
        return self.nodes[expr]

    def construct_entries(self, lengths: List[sympy.Expr]):
        divisor = sympy.Integer(1)
        itervars = []
        for length in reversed(lengths):
            itervars.append(self.lookup(divisor, length))
            divisor = divisor * length
        return list(reversed(itervars))

    def construct(self, lengths: List[sympy.Expr]):
        return [e.symbol() for e in self.construct_entries(lengths)]

    def vars_and_sizes(self, index: sympy.Expr):
        """Figure out vars from this tree used in index"""
        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
        nodes = [n for n in nodes if n and n.prefix == self.prefix]
        nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor))
        divisor = sympy.Integer(1)
        index_vars = []
        sizes = []

        def add(node):
            nonlocal divisor
            index_vars.append(node.symbol())
            sizes.append(node.length)
            divisor = divisor * node.length

        for node in nodes:
            if not V.graph.sizevars.maybe_guard_equals(node.divisor, divisor):
                # fill in unused index var
                add(self.lookup(divisor, ir.FloorDiv(node.divisor, divisor)))
                divisor = node.divisor
            add(node)
        if not V.graph.sizevars.maybe_guard_equals(self.numel, divisor):
            # fill in unused index var
            add(self.lookup(divisor, ir.FloorDiv(self.numel, divisor)))

        return list(reversed(index_vars)), list(reversed(sizes))

    def ranges_code(self):
        size = self.kernel.indexing_size_str(self.index, self.prefix)
        return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}"

    def pid_cache_lookup(self, key):
        if key in self.pid_cache:
            return self.pid_cache[key]
        return key

    def codegen_header(self, code):
        x = self.prefix
        if self.is_loop():
            code.writeline(f"{self.name} = {x}offset + {x}base")
        elif x == "r" and self.kernel.persistent_reduction:
            # no need for "roffset = ..." in a persistent (single-pass) reduction
            code.writeline(
                f"{self.name} = {self.ranges_code()}",
            )
        else:
            pid = self.pid_cache_lookup(f"tl.program_id({self.index})")
            code.writelines(
                [
                    f"{x}offset = {pid} * {x.upper()}BLOCK",
                    f"{self.name} = {x}offset + {self.ranges_code()}",
                ]
            )
        code.writeline(f"{x}mask = {self.name} < {x}numel")


class IterationRangesEntry(IterationRanges):
    def __init__(
        self,
        name: str,
        divisor: sympy.Expr,
        length: sympy.Expr,
        expr: sympy.Expr,
        parent: IterationRanges,
    ):
        super().__init__(
            name=name,
            numel=parent.numel / length,
            var_list=parent.var_list,
            var_ranges=parent.var_ranges,
            prefix=parent.prefix,
            divisor=divisor,
            length=length,
            kernel=parent.kernel,
        )
        self.parent = parent
        self.codegen = functools.lru_cache(None)(self._codegen)
        self.expr = expr

    def set_name(self, name):
        self.codegen = lambda: name
        self.codegen.cache_clear = lambda: None
        self.name = name

    def cache_clear(self):
        self.codegen.cache_clear()

    def writeline(self, line):
        if self.is_loop():
            V.kernel.indexing_code.writeline(line)
        else:
            # lift non-reduction stores outside loop
            V.kernel.body.writeline(line)

    def _codegen(self):
        self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr)))
        return self.name

    def precomputed_args(self):
        # for dynamic shapes, find parts of indexing expressions that have to be precomputed
        precomputed_args = []
        if isinstance(self.expr, sympy.Symbol):
            return precomputed_args
        assert isinstance(self.expr, (ir.FloorDiv, ir.ModularIndexing)), type(self.expr)
        for arg in self.expr.args[1:]:
            if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
                symbols = arg.free_symbols
                if len(symbols) > 0 and all(s.name.startswith("s") for s in symbols):
                    precomputed_args.append(arg)
        return precomputed_args

    def symbol(self):
        return sympy_symbol(self.name)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return self.name == other.name


class TritonKernel(Kernel):
    overrides = TritonOverrides
    sexpr = pexpr

    def __init__(
        self,
        *groups,
        mutations=None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
    ):
        if pid_cache is None:
            pid_cache = {}
        super().__init__()
        self.numels = [V.graph.sizevars.simplify(s) for s in groups]
        self.mutations = mutations
        self.range_trees = []
        self.range_tree_nodes = {}
        self.iter_vars_count = itertools.count()
        self.inside_reduction = self.numels[-1] != 1
        self._load_mask = None
        self.body = IndentedBuffer()
        self.indexing_code = IndentedBuffer()
        self.suffix = IndentedBuffer()
        self.outside_loop_vars = set()
        self.reduction_hint = reduction_hint
        self.persistent_reduction = self.should_use_persistent_reduction()
        self.initialize_range_tree(pid_cache)

        # define this in a closure to make cache local to object
        @functools.lru_cache(None)
        def simplify_indexing(index: sympy.Expr):
            index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
            for tree in self.range_trees:
                index = self.combine_contiguous_dims(index, tree)
            return index

        self.simplify_indexing = simplify_indexing

    def should_use_persistent_reduction(self):
        """
        Heuristic to set self.persistent_reduction and add guards
        if needed.
        """
        if not (self.inside_reduction and config.triton.persistent_reductions):
            return False
        threshold = {
            ReductionHint.INNER: 1024,
        }.get(self.reduction_hint, 64)
        hint = V.graph.sizevars.size_hint(self.numels[-1])
        if hint > threshold:
            return False
        # will need to recompile if we cross a larger power of 2 boundary
        V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
        return True
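
    # Worked example: with reduction_hint=ReductionHint.INNER the threshold
    # is 1024 (otherwise 64), so a reduction whose size hint is 1000 becomes
    # persistent and a guard numel <= next_power_of_2(1000) == 1024 is added.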

    def initialize_range_tree(self, pid_cache):
        names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
        for i in range(len(self.numels)):
            self.range_trees.append(
                IterationRangesRoot(
                    names[i], self.numels[i], names[i][0], i, self, pid_cache
                )
            )
        for tree in self.range_trees:
            # reduction indexing goes inside a loop
            if not tree.is_loop():
                tree.codegen_header(self.body)
        if self.inside_reduction and self.range_trees[-1].is_loop():
            # workaround for this issue:
            # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
            self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}")

    def disable_reduction(self):
        @contextlib.contextmanager
        def ctx():
            if self.numels[-1] == 1:
                assert not self.inside_reduction
                yield
                return
            if not self.persistent_reduction:
                # calling codegen_body() will flush all the pending buffers
                # and write out a reduction loop
                self.codegen_body()
            self.inside_reduction = False
            try:
                yield
                if not self.persistent_reduction:
                    # flush out any code before opening the next loop
                    self.codegen_body()
            finally:
                self.inside_reduction = True

        return ctx()

    def set_ranges(self, *lengths):
        assert len(lengths) == len(self.range_trees)
        return [
            ranges.construct(length)
            for length, ranges in zip(lengths, self.range_trees)
        ]

    @staticmethod
    def _split_iteration_ranges(
        groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]
    ):
        sv = V.graph.sizevars
        new_ranges = [[] for _ in groups]
        remaining = [sv.simplify(g) for g in groups]
        var_count = itertools.count()

        def add_range(i, expr):
            expr = sv.simplify(expr)
            if not sv.maybe_guard_multiple_of(remaining[i], expr):
                raise CantSplit()
            # guard on the last item out
            sv.maybe_guard_equals(remaining[i], expr)
            remaining[i] = ir.FloorDiv(remaining[i], expr)
            new_ranges[i].append(expr)
            return next(var_count)

        def make_combined(size, idx1, idx2):
            def getter(flat_vars):
                return size * flat_vars[idx1] + flat_vars[idx2]

            return getter

        return_getters_groups = []
        current_group = 0
        for length_group in lengths:
            return_getters = []
            for size in length_group:
                if sv.maybe_guard_equals(size, 1):
                    return_getters.append(lambda _: sympy.Integer(0))
                    continue

                while (
                    current_group < len(remaining)
                    and sv.size_hint(remaining[current_group]) == 1
                ):
                    # scroll to next group with remaining elements
                    current_group += 1

                if sv.size_hint(size) > sv.size_hint(remaining[current_group]):
                    # need to break size in two
                    if not sv.maybe_guard_multiple_of(size, remaining[current_group]):
                        raise CantSplit()
                    size1 = remaining[current_group]
                    size2 = ir.FloorDiv(size, remaining[current_group])
                    return_getters.append(
                        make_combined(
                            size2,
                            add_range(current_group, size1),
                            add_range(current_group + 1, size2),
                        )
                    )
                else:
                    return_getters.append(
                        operator.itemgetter(add_range(current_group, size))
                    )
            return_getters_groups.append(return_getters)

        assert all(
            V.graph.sizevars.size_hint(s) == 1 for s in remaining
        ), f"failed to set ranges {remaining} {lengths}"

        return new_ranges, return_getters_groups
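
    # Worked example (illustrative): splitting groups (2, 6) against
    # lengths [[12]] produces new_ranges == [[2], [6]] and one combined
    # getter computing 6 * flat_vars[0] + flat_vars[1], i.e. the original
    # flat index i0 = i1 * 6 + i2.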

    @classmethod
    def is_compatible(cls, groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]):
        try:
            cls._split_iteration_ranges(groups, lengths)
            return True
        except CantSplit:
            return False

    def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
        """
        We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).

        To do this we need to split up the iteration space of i0 into something like:
            for i1 in s0:
                for i2 in s1:
                    i0 = i1*s1 + i2
                    ....

        This function matches and resplits lengths to the groups of
        this kernel to enable tiled + non-tiled fusions.
        """
        groups = [rt.numel for rt in self.range_trees]
        if not self.inside_reduction:
            groups[-1] = sympy.Integer(1)

        if len(lengths) == len(self.range_trees) and all(
            V.graph.sizevars.simplify(sympy_product(x) - g) == 0
            for x, g in zip(lengths, groups)
        ):
            return self.set_ranges(*lengths)

        new_ranges, return_getters_groups = self._split_iteration_ranges(
            groups, lengths
        )
        itervars = list(itertools.chain(*self.set_ranges(*new_ranges)))
        return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

    def is_indirect_indexing(self, index: sympy.Expr):
        # tmpX means indirect indexing
        return free_symbol_startswith(index, "tmp")

    def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
        """
        More aggressive simplification to merge contiguous dims
        """
        if isinstance(index, (sympy.Integer, sympy.Symbol)):
            return index
        index_vars, sizes = tree.vars_and_sizes(index)
        if len(sizes) <= 1:
            return index
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
        )
        if new_sizes == sizes:
            return index
        new_index_vars = tree.construct(new_sizes)
        new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
        return new_index

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
        override_mask=None,
    ):
        """
        Compute the index and mask to pass to tl.load() or tl.store()
        """
        index = self.simplify_indexing(index)
        index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
        # if simple replacements didn't get rid of floor/ceil, try full subs
        if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)):
            index = index.subs(V.graph.sizevars.precomputed_replacements)
        index_vars = index.free_symbols
        index_str = texpr(self.rename_indexing(self.codegen_indexing(index)))

        mask_vars: Set[str] = set()
        for var in index_vars:
            if override_mask:
                pass
            elif var.name.startswith("tmp"):
                # indirect indexing
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)
            elif var.name.startswith(("s", "ps")):
                pass
            else:
                # var is one of xN, yN or rN
                assert var.name[0] in "xyr", var.name
                mask_vars.add(f"{var.name[0]}mask")

        need_dense = (
            config.triton.dense_indexing
            or dense_indexing
            or self._load_mask is not None
        ) and index != 0

        have_dense = True
        have_loop_vars = False
        dense_mask_vars = set()

        for tree in self.range_trees:
            if tree.prefix == "r" and not self.inside_reduction:
                continue
            if index_vars.intersection(tree.var_list):
                have_loop_vars = True
                have_dense = False
            dense_mask_vars.add(f"{tree.prefix}mask")

        if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
            if copy_shape:
                index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
            else:
                index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
            if isinstance(index, sympy.Integer):
                return index_str, set(), "None"
            else:
                mask_vars = dense_mask_vars
        elif not have_loop_vars and copy_shape:
            mask_vars = dense_mask_vars
            index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"

        if override_mask:
            mask_vars = {override_mask}

        if self._load_mask:
            mask_vars.add(self._load_mask)

        self.filter_masks(mask_vars)

        mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
        return index_str, mask_vars, mask_str

    def filter_masks(self, mask_vars):
        for tree in self.range_trees:
            # Masks are superfluous if we only have one element
            if V.graph.sizevars.maybe_guard_equals(tree.numel, 1):
                mask_vars.discard(f"{tree.prefix}mask")
                continue
            # Masks are superfluous if numel is a multiple of BLOCK
            # (We use the fact that BLOCK is required by triton to be a power of 2)
            if tree.prefix.upper() not in config.triton.max_block:
                continue
            max_block = config.triton.max_block[tree.prefix.upper()]
            if V.graph.sizevars.maybe_guard_multiple_of(tree.numel, max_block):
                mask_vars.discard(f"{tree.prefix}mask")
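
    # Worked example: if xnumel == 2048 and config.triton.max_block["X"] were
    # 1024 (an assumed config value), any power-of-2 XBLOCK <= 1024 divides
    # 2048 evenly, every block is full, and "xmask" can be discarded.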

    def var_ranges(self):
        return dict(
            itertools.chain.from_iterable(
                tree.var_ranges.items() for tree in self.range_trees
            )
        )

    def codegen_indexing(self, expr: sympy.Expr):
        expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
        for sym in sorted(expr.free_symbols, key=str):
            if sym in self.range_tree_nodes:
                # if indexing expression is complicated, we precompute it on the host side
                # and send the result as a kernel argument
                replacements = {}
                for ps in self.range_tree_nodes[sym].precomputed_args():
                    replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
                if len(replacements) > 0:
                    self.range_tree_nodes[sym].expr = sympy_subs(
                        self.range_tree_nodes[sym].expr, replacements
                    )
                self.range_tree_nodes[sym].codegen()
        return expr

    @contextlib.contextmanager
    def mask_loads(self, mask):
        """Context manager to add an additional mask to tl.load/store"""
        prior = self._load_mask
        if prior:
            mask = self.cse.generate(self.compute, f"{mask} & {prior}")

        self._load_mask = mask
        try:
            with self.swap_buffers(self.compute, self.compute):
                # TODO(jansel): do we need a reshape here?
                yield mask
        finally:
            self._load_mask = prior

    def load(self, name: str, index: sympy.Expr):
        var = self.args.input(name)
        indirect_indexing = self.is_indirect_indexing(index)
        original_index = index
        index, mask_vars, mask = self.indexing(index)

        if "rmask" in mask and not self.persistent_reduction:
            # This eviction policy heuristic is untested.
            # ptillet suggested we should try only doing this for
            # the first N-1 loops and not for the final loop.
            ep = ", eviction_policy='evict_last'"
        else:
            ep = ""

        # "other" below is a workaround for https://github.com/openai/triton/issues/737
        # for bool, even though it's likely subject to the same bug, setting `other` leads
        # to LLVM errors so we are skipping it for now
        if ("tmp" in mask or "rmask" in mask) and V.graph.get_dtype(name) != torch.bool:
            other = ", other=0"
        else:
            other = ""

        append_broadcast = None
        if V.graph.is_unspec_arg(name):
            line = var
        else:
            if isinstance(original_index, sympy.Integer):
                dense_size = self.dense_size_str()
                line = f"tl.load({var} + ({original_index}))"
                append_broadcast = dense_size
            else:
                line = f"tl.load({var} + ({index}), {mask}{ep}{other})"
            if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
                line += ".to(tl.float32)"

        if (
            self.inside_reduction
            and not self.persistent_reduction
            and "rmask" not in mask
            and "tmp" not in mask
            and not indirect_indexing
        ):
            # can lift a common load outside of reduction loop
            # One exception is when this is an indirect load.
            result_var = self.cse.generate(
                self.body, line, append_broadcast=append_broadcast
            )
        else:
            result_var = self.cse.generate(
                self.loads, line, append_broadcast=append_broadcast
            )

        result_var.mask_vars = mask_vars

        if not self.inside_reduction or "rmask" not in mask:
            self.outside_loop_vars.add(result_var)

        return result_var

    def store(self, name, index, value, mode=None):
        var = self.args.output(name)
        index, mask_vars, mask = self.indexing(index, dense_indexing=True)
        if mode is None:
            line = f"tl.store({var} + ({index}), {value}, {mask})"
        elif mode == "atomic_add":
            line = f"tl.atomic_add({var} + ({index}), {value}, {mask})"
        else:
            raise NotImplementedError(f"store mode={mode}")
        self.stores.writeline(name, line)
        if not self.inside_reduction:
            self.outside_loop_vars.add(value)

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        assert self.inside_reduction
        default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype))
        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        if self._load_mask:
            masks.append(self._load_mask)
        sizes = [":" for _ in self.range_trees]
        sizes[-1] = "None"
        reduction_range_prefix = self.range_trees[-1].prefix
        reduction_sizes = ["None" for _ in self.range_trees]
        reduction_sizes[-1] = ":"

        if reduction_type == "any":
            reduction_type = "max"

        dim = len(self.range_trees) - 1
        result_var = self.cse.newvar()
        result_var.mask_vars = {var for var in masks if var[0] != "r"}
        if self.persistent_reduction:
            cond = " & ".join(masks)
            masked_value = self.cse.generate(
                self.compute, f"tl.where({cond}, {value}, {default})"
            )
            result_var = self.cse.generate(
                self.compute,
                f"tl.{reduction_type}({masked_value}, {dim})[{', '.join(sizes)}]",
            )
        elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
            self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
            accumulator = f"_{result_var}"
            default_value = f" + {default}" if default != 0 else ""
            self.body.writeline(
                f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
            )
            accumulator_index = None
            if reduction_type in {"argmax", "argmin"}:
                accumulator_index = f"_{result_var}_index"
                self.body.writeline(
                    f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
                )

            updated = value
            if reduction_type in {"min", "argmin"}:
                masks.append(f"({accumulator} > {value})")
            elif reduction_type in {"max", "argmax"}:
                masks.append(f"({accumulator} < {value})")
            elif reduction_type == "sum":
                updated = f"{accumulator} + {value}"
            else:
                raise NotImplementedError(f"reduction_type {reduction_type}")

            cond = " & ".join(masks)

            if accumulator_index:
                # argmax or argmin
                self.compute.writeline(
                    f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
                )
            self.compute.writeline(
                f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
            )

            if accumulator_index:
                # argmax, argmin
                self.suffix.writelines(
                    [
                        f"{accumulator_index}_reduce = "
                        f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
                        f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
                        f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
                        f"{result_var} = tl.sum("
                        f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
                    ]
                )
            else:
                self.suffix.writeline(
                    f"{result_var} = tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}]"
                )
        else:
            var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
            self.suffix.writeline(f"{result_var} = {var_name}")
            result_var.mask_vars = var_name.mask_vars
        self.inside_reduction = False
        index, mask_vars, mask = self.indexing(index)
        assert "rmask" not in index
        self.inside_reduction = True
        self.outside_loop_vars.add(result_var)
        self.cse.store_cache[name] = result_var
        if name not in V.graph.removed_buffers:
            var = self.args.output(name)
            self.suffix.writeline(
                DeferredLine(name, f"tl.store({var} + {index}, {result_var}, {mask})")
            )

    def codegen_body(self):
        """
        Concat output code from indexing_code, loads, compute, stores,
        suffix into self.body.

        For pointwise kernels, this is called just once at the end.

        For reduction kernels, this generates a loop over the reduction
        axis.
        """
        if not (
            self.indexing_code
            or self.loads
            or self.stores
            or self.compute
            or self.suffix
        ):
            return

        if self.inside_reduction and not self.persistent_reduction:
            self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
            with self.body.indent():
                # last range tree is always reduction
                self.range_trees[-1].codegen_header(self.body)
                self.body.splice(self.indexing_code)
                self.body.splice(self.loads)
                self.body.splice(self.compute)
                self.body.splice(self.stores)

            # invalidate any caches that came from inside the reduction loop
            self.cse.invalidate(self.outside_loop_vars)
            self.range_trees[-1].cache_clear()
        else:
            self.body.splice(self.indexing_code)
            self.body.splice(self.loads)
            self.body.splice(self.compute)
            self.body.splice(self.stores)
        self.body.splice(self.suffix)
        self.indexing_code.clear()
        self.loads.clear()
        self.compute.clear()
        self.stores.clear()
        self.suffix.clear()
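
    # Illustrative shape of the generated body for a non-persistent
    # reduction (an example, not verbatim output):
    #     for roffset in range(0, rnumel, RBLOCK):
    #         rindex = roffset + rbase
    #         rmask = rindex < rnumel
    #         ... loads / compute / stores ...
    #     ... suffix: reduction results and final stores ...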

    def codegen_kernel_benchmark(self):
        result = IndentedBuffer()
        argdefs, call_args, signature = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            for arg_name in call_args:
                buf = V.graph.get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{arg_name} = rand_strided({tuple(buf.get_size())}, {tuple(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{arg_name} = rand_strided({tuple(const_tensor.size())}, {tuple(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # noqa: B950 line too long
                    )
                else:
                    raise KeyError(
                        f"Could not find the buffer or const tensor for {arg_name}"
                    )
            result.writeline(f"return {', '.join(call_args)},")

        result.writelines(["\n", "\n", "def call(args):"])
        grid = []
        extra_args = []
        extra_args_str = None
        index = V.graph.scheduler.current_device.index
        with result.indent():
            result.writeline(f"with torch.cuda._DeviceGuard({index}):")
            with result.indent():
                result.writeline(
                    f"torch.cuda.set_device({index})"
                )  # no-op to ensure context
                for tree in self.range_trees:
                    expr = pexpr(tree.numel)
                    if tree.prefix != "r" or self.inside_reduction:
                        extra_args.append(expr)
                    if tree.prefix != "r":
                        grid.append(expr)

                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_cuda_stream({index})")
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
                result.writeline(
                    f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with torch.cuda._DeviceGuard({index}):")
            with result.indent():
                result.writeline(
                    f"torch.cuda.set_device({index})"
                )  # no-op to ensure context
                result.writeline(
                    f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline("from torch._inductor.utils import get_num_bytes")
            result.writeline("from triton.testing import do_bench")
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]"
            )
            result.writeline("num_gb = get_num_bytes(*args) / 1e9")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
            )

        return result

    def codegen_kernel(self, name=None):
        from triton import next_power_of_2

        code = IndentedBuffer()
        size_hints = [
            next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels
        ]
        if self.persistent_reduction:
            assert self.inside_reduction
            heuristics = "persistent_reduction"
        elif self.inside_reduction:
            heuristics = "reduction"
        else:
            size_hints.pop()
            heuristics = "pointwise"

        if name is None:
            code.splice(
                f"""
                    import triton
                    import triton.language as tl
                    from torch._inductor.ir import ReductionHint
                    from torch._inductor.ir import TileHint
                    from torch._inductor.triton_ops.autotune import {heuristics}
                    from torch._inductor.utils import instance_descriptor
                """
            )
            if config.benchmark_kernel:
                code.splice(
                    """
                        from torch._dynamo.testing import rand_strided
                        from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                        import torch
                        from torch._inductor.triton_ops.autotune import grid
                    """
                )

        argdefs, _, signature = self.args.python_argdefs()
        # maps actual expression to SizeArg if it's in sizevars replacements
        for i, arg in enumerate(signature):
            if (
                isinstance(arg, SizeArg)
                and arg.expr in V.graph.sizevars.inv_precomputed_replacements
            ):
                signature[i] = SizeArg(
                    arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr]
                )

        mutated_args = set()
        for mutation in self.mutations:
            if mutation in self.args.input_buffers:
                mutated_args.add(self.args.input_buffers[mutation])
            if mutation in self.args.inplace_buffers:
                mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
            if mutation in self.args.output_buffers:
                mutated_args.add(self.args.output_buffers[mutation])
        mutated_args = sorted(mutated_args)

        triton_meta = {
            "signature": dict(enumerate(map(signature_of, signature))),
            "device": V.graph.scheduler.current_device.index,
            "constants": {},
            "mutated_arg_names": mutated_args,
        }

        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
                signature.append(sizearg)
                triton_meta["signature"][len(argdefs)] = signature_of(sizearg)
                argdefs.append(f"{tree.prefix}numel")
                # constexpr version causes issues, see
                # https://github.com/pytorch/torchdynamo/pull/1362
                # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
                #     tree.numel
                # )
                # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
        triton_meta["configs"] = [config_of(signature)]

        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")

        if self.inside_reduction:
            reduction_hint = self.reduction_hint
            heuristics_line = f"""
                @{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    meta={triton_meta!r}
                )
                @triton.jit
            """
        else:
            tile_hint = ""
            if len(size_hints) == 2:
                if len(signature) == 4:  # input, output and 2 args
                    tile_hint = "tile_hint=TileHint.SQUARE,"
                else:
                    tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r})
                @triton.jit
            """
        code.splice(heuristics_line)
        code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):")
        self.codegen_body()
        with code.indent():
            if not dynamo_config.dynamic_shapes:
                self.codegen_static_numels(code)
            for old, new in self.args.aliases():
                code.writeline(f"{old} = {new}")
            code.splice(self.body)

        if config.benchmark_kernel:
            code.splice(self.codegen_kernel_benchmark())

        return code.getvalue()

    def codegen_static_numels(self, code):
        """
        We get a small speedup from hard coding numels if they are static.
        """
        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer):
                    code.writeline(
                        f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}"
                    )
                elif not dynamo_config.dynamic_shapes:
                    code.writeline(
                        f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}  # dynamic_shapes=False"
                    )

    def indexing_size_str(self, i=None, x=None):
        sizes = ["None"] * (len(self.range_trees) - int(self.numels[-1] == 1))
        if i is not None:
            sizes[i] = ":"
        return f"[{', '.join(sizes)}]"

    def dense_size_str(self):
        sizes = []
        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                sizes.append(f"{tree.prefix.upper()}BLOCK")
            elif tree.prefix == "r" and tree.numel != 1:
                sizes.append("1")
        return f"[{', '.join(sizes)}]"

    def call_kernel(self, code, name: str):
        _, call_args, _ = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
        grid = []
        # TODO(jansel): if there are constants, we shouldn't bother passing them as args
        for tree in self.range_trees:
            if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                expr = pexpr(tree.numel)
            else:
                expr = f"{name}_{tree.prefix}numel"
                code.writeline(f"{expr} = {pexpr(tree.numel)}")
            if tree.prefix != "r" or self.inside_reduction:
                call_args.append(expr)
            if tree.prefix != "r":
                grid.append(expr)
        call_args = ", ".join(call_args)
        stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
        code.writeline(
            f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})"
        )

    def create_cse_var(self, *args, **kwargs):
        return TritonCSEVariable(*args, **kwargs)


class TritonScheduling:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def group_fn(self, sizes):
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    def can_fuse(self, node1, node2):
        """
        Hook called by Scheduler to determine if the Triton backend
        can fuse node1 and node2.  These nodes might already be
        FusedSchedulerNodes.
        """
        _, (numel1, rnumel1) = node1.group
        _, (numel2, rnumel2) = node2.group

        if node1.is_reduction() and node2.is_reduction():
            return numel1 == numel2 and rnumel1 == rnumel2

        if not node1.is_reduction() and not node2.is_reduction():
            if not (numel1 == numel2 and rnumel1 == rnumel2):
                return False

            if node1.is_template():
                return True  # skip checks for compatible tiling

            # check for a bad combined tiling
            tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
            tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
            tiling3 = self.select_tiling(
                node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
            )
            if config.triton.tiling_prevents_pointwise_fusion:
                if len(tiling1) > 2:
                    if len(tiling2) > 2:
                        return tiling1 == tiling2 == tiling3
                    else:
                        return tiling1 == tiling3
                elif len(tiling2) > 2:
                    return tiling2 == tiling3

            return True

        if not node1.is_reduction() and node2.is_reduction():
            assert rnumel1 == 1 and rnumel2 != 1
            if numel1 == numel2 * rnumel2:
                if not all(
                    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
                    for n in node1.get_nodes()
                ):
                    return False
                if (
                    config.triton.tiling_prevents_reduction_fusion
                    and not node1.is_template()
                ):
                    return self.select_tiling(node1.get_nodes(), numel1) in (
                        (numel1, 1),
                        (numel2, rnumel2, 1),
                    )
                return True

            return numel1 == numel2

        assert node1.is_reduction() and not node2.is_reduction()
        # swap args to hit the case above
        return self.can_fuse_horizontal(node2, node1)

    can_fuse_vertical = can_fuse
    can_fuse_horizontal = can_fuse

    def codegen_nodes(self, nodes):
        """
        Given a set of pre-fused nodes, generate a Triton kernel.
        """
        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
        node_schedule = []
        current_loop_writes = set()
        is_current_reductions = set()
        done = set()

        def fits_in_main_body(n):
            _, (node_numel, node_rnumel) = n.group
            return (node_numel == numel and node_rnumel == rnumel) or (
                node_numel == numel * rnumel and node_rnumel == 1
            )

        def fits_outside_reduction(n):
            _, (node_numel, node_rnumel) = n.group
            return node_numel == numel and node_rnumel == 1 and rnumel != 1

        @contextlib.contextmanager
        def end_current_reduction_loop():
            if current_loop_writes:
                # flush out any other runnable nodes to reduce number of loops
                for other_node in nodes[index + 1 :]:
                    if (
                        other_node not in done
                        and fits_in_main_body(other_node)
                        and not (
                            current_loop_writes & other_node.recursive_predecessors
                        )
                    ):
                        done.add(other_node)
                        current_loop_writes.add(other_node.get_name())
                        is_current_reductions.add(other_node.is_reduction())
                        node_schedule.append(other_node)

            if node_schedule and node_schedule[-1] is EnableReduction:
                node_schedule.pop()
            else:
                node_schedule.append(DisableReduction)
            yield
            node_schedule.append(EnableReduction)
            current_loop_writes.clear()
            is_current_reductions.clear()

        for index, node in enumerate(nodes):
            if node in done:
                continue
            done.add(node)

            def requires_closing_previous_reduction(node, node_schedule):
                if rnumel == 1:
                    return False
                if not current_loop_writes & node.recursive_predecessors:
                    return False
                assert node_schedule and not isinstance(
                    node_schedule[-1], (EnableReduction, DisableReduction)
                )
                return True in is_current_reductions

            if fits_in_main_body(node):
                if requires_closing_previous_reduction(node, node_schedule):
                    with end_current_reduction_loop():
                        pass  # need to start a new reduction loop
                current_loop_writes.add(node.get_name())
                is_current_reductions.add(node.is_reduction())
                node_schedule.append(node)
            elif fits_outside_reduction(node):
                with end_current_reduction_loop():
                    node_schedule.append(node)
            else:
                raise NotImplementedError(
                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
                )

        if schedule_log.isEnabledFor(logging.DEBUG):
            schedule_log.debug(f"Schedule:\n {node_schedule}")
        return self.codegen_node_schedule(node_schedule, numel, rnumel)

    @staticmethod
    def reduction_hint(node):
        assert node.is_reduction()
        if all(
            dep.is_contiguous()
            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
        ):
            return ReductionHint.INNER
        else:
            return node.node.data.reduction_hint

    def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
        tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
        reductions = list(
            filter(
                lambda n: n not in (EnableReduction, DisableReduction)
                and n.is_reduction(),
                node_schedule,
            )
        )
        if len(reductions) > 0:
            hints = [self.reduction_hint(n) for n in reductions]
            if hints.count(hints[0]) == len(hints):
                reduction_hint_val = hints[0]
            else:
                reduction_hint_val = ReductionHint.DEFAULT
        else:
            reduction_hint_val = ReductionHint.DEFAULT

        mutations = set()
        for node in node_schedule:
            if hasattr(node, "get_mutations"):
                mutations.update(node.get_mutations())

        with TritonKernel(
            *tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations
        ) as kernel:
            stack = contextlib.ExitStack()
            for node in node_schedule:
                if node not in (EnableReduction, DisableReduction):
                    node.mark_run()
            for node in node_schedule:
                if node is DisableReduction:
                    stack.enter_context(kernel.disable_reduction())
                elif node is EnableReduction:
                    stack.close()
                else:
                    # TODO - mostly works but needs a couple fixes
                    if not dynamo_config.dynamic_shapes:
                        # TODO - use split ranges ?
                        indexing_dtype_strength_reduction(node._body)
                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
                    node.codegen(index_vars)

        src_code = kernel.codegen_kernel()
        kernel_name = self.define_kernel(src_code, node_schedule)
        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
        self.scheduler.free_buffers()

    def define_kernel(self, src_code, node_schedule):
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.kernels:
            kernel_name = wrapper.kernels[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule)
                if config.triton.descriptive_names
                else ""
            )
            kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()])
            wrapper.kernels[src_code] = kernel_name
            subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
            src_code = src_code.replace("KERNEL_NAME", subs_name)

            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
            src_code = src_code.replace("#pragma CMT", "#")

            _, _, kernel_path = get_code_path(src_code, "py", extra="")
            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline("async_compile.triton('''")
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''')")

            metadata_comment = f"# kernel path: {kernel_path}"
            metadata_comment += "\n" + get_kernel_metadata(node_schedule)
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name

    def codegen_template(self, template_node, epilogue_nodes):
        """
        Codegen a triton template
        """
        _, (numel, rnumel) = template_node.group
        assert rnumel == 1
        kernel, render = template_node.node.make_kernel_render(template_node.node)
        with kernel:
            for node in [template_node, *epilogue_nodes]:
                node.mark_run()
            render()  # warmup run to get the args right
            for node in epilogue_nodes:
                node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

        src_code = render()
        kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
        self.scheduler.free_buffers()

    def codegen_sync(self):
        V.graph.wrapper_code.writeline("torch.cuda.synchronize()")

    @staticmethod
    @functools.lru_cache(32)
    def candidate_tilings(node):
        ranges, reduction_ranges = node.get_ranges()
        if len(ranges) <= 1:
            return ()

        rw = node.pointwise_read_writes()
        assert len(rw.range_vars) == len(ranges)

        deps = [
            dep
            for dep in itertools.chain(rw.reads, rw.writes)
            if dep.name not in V.graph.removed_buffers
        ]
        write_names = {dep.name for dep in rw.writes}

        tilings = []

        for dep in deps:
            strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
            assert len(strides) == len(ranges)
            try:
                split = strides.index(1) + 1
                if split == len(ranges):
                    continue
                if all(s == 0 for s in strides[split:]):
                    # if this is a broadcasted tensor and all dimensions after split are broadcast,
                    # this is not a real split
                    continue

            except ValueError:
                continue
            tiled_groups = (
                V.graph.sizevars.simplify(sympy_product(ranges[:split])),
                V.graph.sizevars.simplify(sympy_product(ranges[split:])),
            )
            # score by number of elements
            score = V.graph.sizevars.size_hint(
                sympy_product(
                    size for size, stride in zip(ranges, strides) if stride != 0
                )
            )
            if dep.name in write_names:
                # ngimel said contiguous writes are more important than reads
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[0]):
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[1]):
                score *= 2

            if (
                V.graph.sizevars.size_hint(
                    score - sympy_product(itertools.chain(ranges, reduction_ranges))
                )
                >= 0
            ):
                tilings.append(CandidateTiling(tiled_groups, score, dep.name))
        return tilings

    @classmethod
    def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
        """
        Heuristics to decide how to tile kernels.
        Currently, we tile based on stride-1 dimensions.

        Returns:
            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`

        """
        if reduction_numel != 1 or config.triton.max_tiles <= 1:
            # TODO(jansel): should we tile reductions?
            return (numel, reduction_numel)

        seen_names = set()
        candidate_tiles = collections.Counter()
        for node in EnableReduction.filter(node_schedule):
            for tiling in cls.candidate_tilings(node):
                if tiling.name in seen_names:
                    continue
                seen_names.add(tiling.name)
                candidate_tiles[tiling.tiling] += tiling.score

        ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]

        if config.triton.max_tiles >= 3:
            # Add one 3D tiling choice
            for i in range(1, len(ranked_tilings)):
                a0, a1 = ranked_tilings[0]
                b0, b1 = ranked_tilings[i]
                if V.graph.sizevars.size_hint(a1 - b1) == 0:
                    continue
                if V.graph.sizevars.size_hint(a1 - b1) < 0:
                    # swap so a0 is bigger
                    a0, a1 = ranked_tilings[i]
                    b0, b1 = ranked_tilings[0]
                assert V.graph.sizevars.size_hint(a1 - b1) > 0
                if V.graph.sizevars.maybe_guard_multiple_of(a1, b1):
                    tiling = (a0, ir.FloorDiv(a1, b1), b1)
                    ranked_tilings = [tiling] + ranked_tilings
                    break  # only 1 choice for now

        for tiled_groups in ranked_tilings:
            new_groups = (*tiled_groups, reduction_numel)
            if all(
                TritonKernel.is_compatible(new_groups, node.get_ranges())
                for node in node_schedule
                if isinstance(node, scheduler.SchedulerNode)
            ):
                return new_groups

        return (numel, reduction_numel)
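
    # Illustrative example (hypothetical sizes): for a pointwise kernel over
    # 4096 elements whose highest-scoring candidate splits the ranges into
    # (64, 64), select_tiling returns (64, 64, 1) provided every
    # SchedulerNode's ranges are compatible with that split.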

    def flush(self):
        pass


@dataclasses.dataclass
class CandidateTiling:
    tiling: List[sympy.Expr]
    score: int  # higher is better
    name: str = None

    @staticmethod
    def is_good_size(s):
        """Somewhat arbitrary heuristic used to boost scores for some sizes"""
        s = V.graph.sizevars.size_hint(s)
        return s >= 32 and (s % 32 == 0)


class DisableReduction:
    """
    Marker to invoke `kernel.disable_reduction()`.  This closes a
    reduction loop and allows for pointwise ops to occur on the output
    of a reduction.
    """


class EnableReduction:
    """
    Marker to end a DisableReduction block.
    """

    @staticmethod
    def filter(node_schedule):
        """
        Get the nodes from node_schedule skipping those in a
        DisableReduction block.
        """
        disabled = False
        for node in node_schedule:
            if node in (EnableReduction, DisableReduction):
                # Don't tile stuff outside the main reduction loop
                disabled = node is DisableReduction
            elif disabled:
                pass
            else:
                yield node


class CantSplit(Exception):
    pass