mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00

Compare commits: ciflow/vll...new-codege (3 commits)
Commit SHAs: 25e1e99fcd, f803a57a7a, 39287c561b
@@ -3573,6 +3573,7 @@ class CommonTemplate:
self.assertEqual(actual, expect)

@skip_if_halide  # only 32-bit indexing
@skipIfRocm
@largeTensorTest("4GB", inductor=True)
def test_large_pointwise(self):
def fn(a):

@@ -12962,6 +12963,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
if name not in {"softmax", "log_softmax", "logsumexp"}
],
)
@skipIfRocm
def test_pointwise(self, name, op):
dtype = torch.float32
check_lowp = True

@@ -34,7 +34,6 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
from ..debug import set_kernel_post_grad_provenance_tracing
from ..ops_handler import DefaultHandler
from ..runtime import triton_heuristics
from ..runtime.benchmarking import benchmarker

@@ -48,7 +47,6 @@ from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
cache_on_self,
DelayMaybeLine,
DelayReplaceLine,
get_bounds_index_expr,
get_fused_kernel_name,

@@ -75,7 +73,6 @@ from .common import (
DeferredLine,
IndentedBuffer,
InplacedBuffer,
is_buffer_removed,
OpOverrides,
PythonPrinter,
RemovedArg,

@@ -641,13 +638,6 @@ def triton_reshape(
return f"{value}[{', '.join(expand)}]"


def enable_pdl_codegen():
if not torch._inductor.config.triton.enable_pdl:
return False
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
return major >= 9

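
For context, the enable_pdl_codegen helper shown in this hunk (dropped by the diff in favor of a direct config.triton.enable_pdl check in later hunks) gated programmatic dependent launch (PDL) codegen on both the inductor config flag and an SM90+ (Hopper or newer) GPU. A minimal standalone sketch of the same capability check, under a hypothetical helper name:

    import torch

    def pdl_supported() -> bool:
        # Hypothetical helper: PDL codegen needs the inductor flag plus
        # a device of compute capability 9.0 or higher.
        if not torch._inductor.config.triton.enable_pdl:
            return False
        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        return major >= 9
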
# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We

@@ -1607,6 +1597,10 @@ class TritonKernelOverrides(TritonOverrides):
V.kernel.cse.put(cache_key, (mantissa, exponent))
return (mantissa, exponent)

@staticmethod
def device_assert_async(cond, msg):
return f"tl.device_assert({cond}, {repr(msg)})"
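
As a worked example of the Triton source fragment this new override produces (cond being the name of a boolean CSE variable):

    def device_assert_async(cond, msg):
        # Same f-string as the override above.
        return f"tl.device_assert({cond}, {repr(msg)})"

    assert device_assert_async("tmp3", "index out of bounds") == \
        "tl.device_assert(tmp3, 'index out of bounds')"
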
class HelperFunctions:
"""An ordered set of helper functions."""

@@ -1950,8 +1944,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.fixed_config = fixed_config
super().__init__(tiling, **kwargs)
self.cse = TritonCSE(self.newvar_prefix, self.suffix)
# Cache of values that can be reused for the prologue.
self.prologue_cache: dict[str, str] = {}
self.prologue: IndentedBuffer = IndentedBuffer()
self.post_loop_combine: IndentedBuffer = IndentedBuffer()
self.post_loop_store: IndentedBuffer = IndentedBuffer()

@@ -1966,7 +1958,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.tma_min_block_sizes = dict[str, int]()
self.hint_override = hint_override
self._load_counts: collections.Counter[str] = collections.Counter()
self._load_index = 0

# A set of autotuning hints to pass as part of triton_meta
self.autotune_hints = OrderedSet[AutotuneHint]()

@@ -1983,44 +1974,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if self.cooperative_reduction:
self.init_cooperative_reduction_mask()

self.has_load_with_contiguous_rdim = False
# We track the store name since a store can be canceled later
self.stores_with_contiguous_rdim: list[str] = []

@staticmethod
def _has_stride1_on_rdim(index) -> bool:
# This analysis is only needed in deterministic mode so far
# to filter triton configs. Return false immediately to avoid
# increasing compilation time when the mode is off.
if not (
config.deterministic or config.test_configs.force_filter_reduction_configs
):
return False
support_vars = index.free_symbols
reduce_vars = [
var
for var in support_vars
if symbol_is_type(var, TritonSymbols.reduction_types)
]

if len(reduce_vars) == 0:
return False

# for an expression like "x0 + 150528*((x1//(s27*s38))) + 3*(ModularIndexing(x1, 1, s38)) + 672*(ModularIndexing(x1, s38, s27))"
# stride_vars will result in a DivisionByZero error
try:
stride_vars = V.graph.sizevars.stride_vars(index, reduce_vars, support_vars)
except ZeroDivisionError:
return False

return any(stride == 1 for stride in stride_vars)

@property
def has_store_with_contiguous_rdim(self) -> bool:
return not all(
is_buffer_removed(name) for name in self.stores_with_contiguous_rdim
)

def dtype_to_str(self, dtype: torch.dtype) -> str:
return triton_type(dtype)
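
The stride analysis above relies on inductor's sizevars.stride_vars; as a toy standalone illustration of "does this index expression have stride 1 along a reduction symbol", using plain sympy (not the real helper, which also handles ModularIndexing and friends):

    import sympy

    x0, r0 = sympy.symbols("x0 r0", integer=True)
    index = 512 * x0 + r0            # contiguous along the reduction dim r0
    stride_on_r0 = index.coeff(r0)   # -> 1
    assert stride_on_r0 == 1
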
@@ -2088,6 +2041,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
)

def codegen_range_tree(self):
# Skip if pointwise inner loop
#return

for tree in self.range_trees:
# reduction indexing goes inside a loop
if not tree.is_loop:

@@ -2532,95 +2489,86 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
and self.range_trees[-1].is_loop
and indexing.has_rindex()
) or indexing.can_lift:
if indexing.can_lift and var in self.prologue_cache:
# Check for epilogue subtiling to reuse the same
# tensor descriptor.
block_descriptor = self.prologue_cache[var]
block_descriptor_id = next(self.block_ptr_id)
if isinstance(indexing, BlockPtrOptions):
block_descriptor = f"block_ptr{block_descriptor_id}"
else:
block_descriptor_id = next(self.block_ptr_id)
if isinstance(indexing, BlockPtrOptions):
block_descriptor = f"block_ptr{block_descriptor_id}"
else:
block_descriptor = f"tma_descriptor{block_descriptor_id}"
line_body = DeferredLine(
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
)
if indexing.can_lift:
self.prologue.writeline(line_body)
# Cache the descriptor for epilogue subtiling
self.prologue_cache[var] = block_descriptor
else:
self.body.writeline(line_body)
block_descriptor = f"tma_descriptor{block_descriptor_id}"
line_body = DeferredLine(
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
)
if indexing.can_lift:
self.prologue.writeline(line_body)
else:
self.body.writeline(line_body)

if isinstance(indexing, BlockPtrOptions):
# Store for later use. If the buffer is removed the below advancements
# are no longer necessary
self.block_ptr_to_buffer[block_descriptor] = name
if isinstance(indexing, BlockPtrOptions):
# Store for later use. If the buffer is removed the below advancements
# are no longer necessary
self.block_ptr_to_buffer[block_descriptor] = name

# Generate block pointer advancements, for later use.
for symt in TritonSymbols.reduction_types:
advance_offsets = indexing.advance_roffset(symt)
# Generate block pointer advancements, for later use.
for symt in TritonSymbols.reduction_types:
advance_offsets = indexing.advance_roffset(symt)

# Ignore identity advancements.
if all(
V.graph.sizevars.statically_known_equals(
offset, sympy.Integer(0)
)
for offset in advance_offsets
):
continue

advancements = self.pointer_advancements[symt]
assert block_descriptor not in advancements, (
f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
# Ignore identity advancements.
if all(
V.graph.sizevars.statically_known_equals(
offset, sympy.Integer(0)
)
advancements[block_descriptor] = advance_offsets
for offset in advance_offsets
):
continue

advancements = self.pointer_advancements[symt]
assert block_descriptor not in advancements, (
f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
)
advancements[block_descriptor] = advance_offsets
else:
block_descriptor = indexing.format(var)
return block_descriptor, other

def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
def stringify_shape(shape):
return tuple(
symt.name if isinstance(symt, sympy.Symbol) else str(symt)
for symt in shape
)

if value.shape:
value_forward_shape = stringify_shape(value.shape)
value_reverse_shape = stringify_shape(value.shape[::-1])
# TMA stores may require transposing the data to ensure we are contiguous along
# the final dimension. We do this by checking the shape information on value.
# It can either
# 1. Match the final shape. In this case no broadcast/reshape
#    is necessary.
# 2. Exist as the Transpose of the final shape, which means we had to transpose
#    the store_descriptor relative to the accumulator indexing/value. If this
#    happens we will generate a tl.trans().
# 3. A mismatched provided shape. When this occurs we will error.
# 4. No shape is provided. This will proceed with the default explicit broadcast
#    described below.
#
# To prevent unintended side effects we will gate options 1-3 behind isinstance(indexing, TensorDescriptorOptions).
if isinstance(indexing, TensorDescriptorOptions) and value.shape:
str_final_shape = tuple([symt.name for symt in indexing.final_shape])
if value.shape[::-1] == str_final_shape:
value = f"tl.trans({value})"
elif value.shape != str_final_shape:
raise AssertionError(
"TMA store requires no broadcasting when a shape is provided"
)
else:
value_forward_shape = None
value_reverse_shape = None
final_shape = stringify_shape(indexing.final_shape)
# TODO: Generalize to N Dimensions
if (
value_forward_shape != final_shape
and value_reverse_shape == final_shape
and len(final_shape) == 2
):
# TMA stores may require transposing the data to ensure we are contiguous along
# the final dimension. This applies to Block-pointers generally, but should only practically
# be reached with TMA.
value = f"tl.trans({value})"
# Stores require an explicit broadcast. We do this in two phases:
#  1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
#     YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
#  2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
#     result to the shape of the pointer.
value = f"tl.broadcast_to({value}, {indexing.final_shape})"

# Stores require an explicit broadcast. We do this in two phases:
#  1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
#     YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
#  2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
#     result to the shape of the pointer.
value = f"tl.broadcast_to({value}, {indexing.final_shape})"
# These dims no longer need broadcasting.
for idx, (dim, broadcast_dim) in enumerate(
zip(indexing.final_shape, indexing.broadcast_shape)
):
if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
indexing.broadcasting_dims[idx] = False

# These dims no longer need broadcasting.
for idx, (dim, broadcast_dim) in enumerate(
zip(indexing.final_shape, indexing.broadcast_shape)
):
if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
indexing.broadcasting_dims[idx] = False

value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)

# workaround https://github.com/triton-lang/triton/issues/2814
value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
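
The four-way shape decision spelled out in the comments above can be summarized with a small standalone sketch (hypothetical helper operating on string shapes only):

    def tma_store_value_expr(value_expr, value_shape, final_shape):
        # Mirrors options 1-4 from the comment: match, transpose, mismatch, no shape.
        if not value_shape:
            return value_expr                      # 4. fall through to explicit broadcast
        if value_shape == final_shape:
            return value_expr                      # 1. shapes already match
        if tuple(reversed(value_shape)) == final_shape:
            return f"tl.trans({value_expr})"       # 2. transposed accumulator
        raise AssertionError(                      # 3. mismatched shape
            "TMA store requires no broadcasting when a shape is provided"
        )

    assert tma_store_value_expr("acc", ("XBLOCK", "YBLOCK"), ("YBLOCK", "XBLOCK")) == "tl.trans(acc)"
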
@@ -2669,27 +2617,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else:
return self.loads

def _handle_pdl_before_load(self, wait_buffer):
GDC_WAIT = "tl.extra.cuda.gdc_wait()"
self._load_index += 1
if self.inside_reduction:
wait_buffer = self.body
if enable_pdl_codegen():
if self._load_index == 1:
wait_buffer.writeline(GDC_WAIT)

def _handle_pdl_after_load(self, launch_buffer, result_var):
GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
if self.inside_reduction:
launch_buffer = self.post_loop_combine
if enable_pdl_codegen():
current_load_index = self._load_index
launch_if_last_load = DelayMaybeLine(
lambda: current_load_index == self._load_index,
f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
)
self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)
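
The PDL helpers shown in this hunk (dropped in favor of a direct config.triton.enable_pdl check at the load sites) injected a gdc_wait before the first load and a gdc_launch_dependents after the last one. Roughly, the generated kernels take this shape (illustrative sketch, assuming a Hopper GPU and a Triton build that exposes tl.extra.cuda.gdc_wait / gdc_launch_dependents):

    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel_pdl(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)
        xmask = xindex < xnumel
        tl.extra.cuda.gdc_wait()                # wait for the producer kernel's data
        tmp0 = tl.load(in_ptr0 + xindex, xmask)
        tmp1 = tl.load(in_ptr1 + xindex, xmask)
        tl.extra.cuda.gdc_launch_dependents()   # let the dependent kernel start early
        tl.store(out_ptr0 + xindex, tmp0 + tmp1, xmask)
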
def load(self, name: str, index: sympy.Expr):
"""
Load from the memory location 'name', offset by some indexing expression 'index'.

@@ -2711,12 +2638,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
force=False,
),
)

if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
indexing.index
):
self.has_load_with_contiguous_rdim = True

has_rindex = indexing.has_rindex()
has_tmpmask = indexing.has_tmpmask()

@@ -2826,11 +2747,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype = torch.bool

load_buffer = self.get_load_buffer(indexing)
self._handle_pdl_before_load(load_buffer)
if config.triton.enable_pdl:
load_buffer.writeline("tl.extra.cuda.gdc_wait()")
result_var = self.cse.generate(
load_buffer, make_line(line), dtype=dtype, shape=shape
)
self._handle_pdl_after_load(load_buffer, result_var)
if result_var.use_count > 1:
load_counts[name] -= 1  # don't double count cache hit
assert isinstance(result_var, TritonCSEVariable)

@@ -2884,11 +2805,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
tma_compatibility_checker=tma_compatibility_checker,
)

if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
indexing.index
):
self.stores_with_contiguous_rdim.append(name)

# Guard against write-after-read corruption in triton.
# See # https://github.com/triton-lang/triton/issues/1615
# This triton bug means that a load which is broadcasted over multiple

@@ -2924,9 +2840,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

exit_stack.close()

def device_assert_async(self, cond, msg) -> None:
self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")

def guard_cooperative_store(self, name, buffer):
"""
For cooperative reductions only one thread block should write out the result.

@@ -2984,7 +2897,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"Bucketize only supports indexing with int32 and int64"
)

self._handle_pdl_before_load(self.compute)
result = self.cse.generate(
self.compute,
f"triton_helpers.bucketize_binary_search({values}, "

@@ -2998,7 +2910,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype=indexing_dtype,  # type: ignore[attr-defined]
shape=values.shape,
)
self._handle_pdl_after_load(self.compute, result)

masks = self._combine_masks(values, boundary_indices, sorter_indices)
result.mask_vars = masks  # type: ignore[attr-defined]

@@ -3175,34 +3086,14 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

if self.persistent_reduction:
default = ir.Reduction.default_value(reduction_type, src_dtype)

def update_constant_dtype(constant, src_dtype, dst_dtype):
"update reduction constant mask value to match dst_dtype"

# int is the only mask which may not fit within lower bitwidth,
# because float uses inf/-inf
if src_dtype.is_floating_point or src_dtype == torch.bool:
return constant

if src_dtype == dst_dtype or constant == 0:
return constant

if constant == torch.iinfo(src_dtype).max:
return torch.iinfo(dst_dtype).max
elif constant == torch.iinfo(src_dtype).min:
return torch.iinfo(dst_dtype).min
else:
return constant
default = self._map_tuple_or_scalar(constant_repr, default)

def _mask_value(value, default) -> CSEVariable:
default = update_constant_dtype(default, src_dtype, value.dtype)
default_str = self._map_tuple_or_scalar(constant_repr, default)

return self.cse.generate(
self.compute,
where_cond(value, default_str),
where_cond(value, default),
dtype=value.dtype,
shape=value.shape,
shape=value.shape if value.shape is not None else default.shape,
)

masked_value: Union[CSEVariable, Sequence[CSEVariable]]
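
A quick worked example of the integer-constant remapping performed by update_constant_dtype above: an int64 identity value for a min/max-style reduction must be re-expressed in the accumulator's narrower dtype or it would overflow:

    import torch

    src_dtype, dst_dtype = torch.int64, torch.int32
    constant = torch.iinfo(src_dtype).max      # 9223372036854775807, does not fit in int32
    remapped = torch.iinfo(dst_dtype).max      # 2147483647
    assert remapped == 2**31 - 1
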
@@ -3211,14 +3102,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
# will fallback below
pass
elif isinstance(value, tuple):
masked_value = [_mask_value(v, d) for v, d in zip(value, default)]  # type: ignore[arg-type]
masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
else:
masked_value = _mask_value(value, default)

if reduction_type in ("argmax", "argmin"):
assert isinstance(masked_value, CSEVariable)
accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype()

accumulator_index = str(
self.cse.generate(
self.compute,

@@ -3993,7 +3883,268 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

code.splice(self.prologue)
self.prologue.clear()
self.prologue_cache.clear()

def _should_use_pointwise_inner_loop(self):
"""Check if we should use inner loop optimization for this pointwise kernel."""

# Safety check: never apply to reduction kernels
if self.inside_reduction:
return False

# Only apply to non-template kernels
if self.fixed_config or self.cooperative_reduction or self.persistent_reduction:
return False

# Check for operations incompatible with inner loop
code_str = ""
if hasattr(self, 'indexing_code') and self.indexing_code:
code_str += str(self.indexing_code)
if hasattr(self, 'loads') and self.loads:
code_str += str(self.loads)
if hasattr(self, 'compute') and self.compute:
code_str += str(self.compute)
if hasattr(self, 'stores') and self.stores:
code_str += str(self.stores)

# Skip if using block pointers - they have complex shape calculations
if "tl.make_block_ptr" in code_str:
return False

# Skip if using tensor descriptors (TMA)
if "tl.make_tensor_descriptor" in code_str:
return False

# Separate trees by type
valid_prefixes = {'x', 'y', 'z'}
valid_trees = []
reduction_trees = []

for tree in self.range_trees:
if tree.prefix in valid_prefixes:
valid_trees.append(tree)
elif tree.prefix.startswith('r') and (len(tree.prefix) == 1 or tree.prefix[1:].replace('_', '').isdigit()):
reduction_trees.append(tree)

# Check if all reduction trees are trivial (numel=1)
for tree in reduction_trees:
try:
size_hint = V.graph.sizevars.size_hint(tree.numel, fallback=999999)
if size_hint > 1:
return False
except:
return False

# Need at least one valid dimension
if len(valid_trees) == 0 or len(valid_trees) > 2:
return False

# Check dimension sizes
dimension_sizes = []
for tree in valid_trees:
try:
size_hint = V.graph.sizevars.size_hint(
tree.numel,
fallback=config.unbacked_symint_fallback
)
dimension_sizes.append((tree.prefix, size_hint))
except:
return False

#if dimension_sizes:
#    largest_size = max(size for _, size in dimension_sizes)
#    # Need sufficient work to hide memory latency
#    if largest_size < 512:
#        return False
#else:
#    return False

return True
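
The prefix test above distinguishes reduction range trees ("r", "r0_", "r1_", ...) from pointwise ones; isolated, it behaves like this (standalone restatement for illustration):

    def is_reduction_prefix(prefix: str) -> bool:
        # Same predicate as in _should_use_pointwise_inner_loop above.
        return prefix.startswith('r') and (
            len(prefix) == 1 or prefix[1:].replace('_', '').isdigit()
        )

    assert [is_reduction_prefix(p) for p in ("x", "y", "r", "r0_", "r1_")] == [
        False, False, True, True, True,
    ]
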
def _get_split_dimension_info(self):
"""Get information about which dimension to split for inner loop."""
# Only consider x, y, z dimensions (skip reduction dimensions)
valid_prefixes = {'x', 'y', 'z'}
non_reduction_trees = [
t for t in self.range_trees
if t.prefix in valid_prefixes
]

if not non_reduction_trees:
return None, []

# Find dimension with largest numel
split_tree = max(
non_reduction_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)

return split_tree, non_reduction_trees

def _get_largest_dimension_for_inner_loop(self):
"""Determine which dimension to split based on largest numel."""
if len(self.range_trees) == 1:
return self.range_trees[0]

# Find dimension with largest numel
largest_tree = max(
self.range_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)
return largest_tree

def _codegen_pointwise_with_inner_loop(self):
"""Generate pointwise kernel body with inner R0_BLOCK loop."""

# Save the current generated code
saved_indexing = self.indexing_code.getvalue()
saved_loads = self.loads.getvalue()
saved_compute = self.compute.getvalue()
saved_stores = self.stores.getvalue()

# Clear buffers for our custom generation
self.indexing_code.clear()
self.body.clear()
self.loads.clear()
self.compute.clear()
self.stores.clear()

# Determine which dimension to split
non_reduction_trees = [t for t in self.range_trees if t.prefix in {'x', 'y', 'z'}]
if not non_reduction_trees:
# Restore and fallback
self.indexing_code.writelines(saved_indexing.split('\n'))
self.loads.writelines(saved_loads.split('\n'))
self.compute.writelines(saved_compute.split('\n'))
self.stores.writelines(saved_stores.split('\n'))
return

split_tree = max(
non_reduction_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)
split_prefix = split_tree.prefix

# Add static assertion
self.body.writeline(f"tl.static_assert({split_prefix.upper()}BLOCK % R0_BLOCK == 0)")

# Map prefix to program_id
prefix_to_pid = {
'x': "tl.program_id(0)",
'y': "(tl.program_id(1) + tl.program_id(2) * tl.num_programs(1))",
'z': "tl.program_id(2)"
}

# Build prefix list for broadcast calculations
all_prefixes = [t.prefix for t in non_reduction_trees]

# Generate offset+index+mask for NON-SPLIT dimensions
for tree in non_reduction_trees:
if tree.prefix != split_prefix:
prefix = tree.prefix
pid = prefix_to_pid[prefix]

idx = all_prefixes.index(prefix)
if len(all_prefixes) == 1:
broadcast = ""
elif len(all_prefixes) == 2:
broadcast = "[:, None]" if idx == 0 else "[None, :]"
else:
parts = ["None"] * len(all_prefixes)
parts[idx] = ":"
broadcast = f"[{', '.join(parts)}]"

self.body.writeline(f"{prefix}offset = {pid} * {prefix.upper()}BLOCK")
self.body.writeline(
f"{prefix}index = {prefix}offset + tl.arange(0, {prefix.upper()}BLOCK){broadcast}"
)
self.body.writeline(f"{prefix}mask = {prefix}index < {prefix}numel")

# For split dimension, generate offset only
pid = prefix_to_pid[split_prefix]
self.body.writeline(f"{split_prefix}offset = {pid} * {split_prefix.upper()}BLOCK")
self.body.writeline(f"tile_start = {split_prefix}offset")

# Save original name
original_name = split_tree.name

# Generate pipelined inner loop
self.body.writeline(f"for r in tl.range(0, {split_prefix.upper()}BLOCK, R0_BLOCK, num_stages=2):")

with self.body.indent():
self.body.writeline("lanes = tl.arange(0, R0_BLOCK)")

# Calculate broadcast for split dimension
split_idx = all_prefixes.index(split_prefix)
if len(all_prefixes) == 1:
broadcast = ""
elif len(all_prefixes) == 2:
broadcast = "[:, None]" if split_idx == 0 else "[None, :]"
else:
parts = ["None"] * len(all_prefixes)
parts[split_idx] = ":"
broadcast = f"[{', '.join(parts)}]"

self.body.writeline(
f"{split_prefix}index = (tile_start + r + lanes){broadcast}"
)
self.body.writeline(f"{split_prefix}mask = {split_prefix}index < {split_prefix}numel")

# Override the range tree name temporarily
split_tree.name = f"{split_prefix}index"

# Generate derived indices
for node_name, entry in self.range_tree_nodes.items():
if hasattr(entry, 'expr') and entry.name != f"{split_prefix}index":
line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
self.body.writeline(line)

# Fix shapes in the saved code
block_name = f"{split_prefix.upper()}BLOCK"

# Replace in all buffers
import re

def fix_shapes(code):
if not code:
return code

# Replace [XBLOCK] with [R0_BLOCK]
code = code.replace(f"[{block_name}]", "[R0_BLOCK]")

# Replace tl.full([XBLOCK], ...) with tl.full([R0_BLOCK], ...)
code = re.sub(
rf'tl\.full\(\[{block_name}\]',
'tl.full([R0_BLOCK]',
code
)

# Replace broadcast_to(expr, [XBLOCK]) with broadcast_to(expr, [R0_BLOCK])
code = re.sub(
rf'broadcast_to\((.*?),\s*\[{block_name}\]\)',
r'broadcast_to(\1, [R0_BLOCK])',
code
)

return code

# Apply fixes and write the code
if saved_indexing:
self.body.splice(fix_shapes(saved_indexing))
if saved_loads:
self.body.splice(fix_shapes(saved_loads))
if saved_compute:
self.body.splice(fix_shapes(saved_compute))
if saved_stores:
self.body.splice(fix_shapes(saved_stores))

# Restore original name
split_tree.name = original_name
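
A small standalone demonstration of the fix_shapes rewriting above, applied to a representative generated line (note the plain string replace already covers both patterns, so the regexes act as a fallback):

    import re

    block_name = "XBLOCK"
    line = "tmp0 = tl.broadcast_to(tmp_acc, [XBLOCK]) + tl.full([XBLOCK], 1, tl.int32)"
    line = line.replace(f"[{block_name}]", "[R0_BLOCK]")
    line = re.sub(rf'tl\.full\(\[{block_name}\]', 'tl.full([R0_BLOCK]', line)
    line = re.sub(rf'broadcast_to\((.*?),\s*\[{block_name}\]\)', r'broadcast_to(\1, [R0_BLOCK])', line)
    assert line == "tmp0 = tl.broadcast_to(tmp_acc, [R0_BLOCK]) + tl.full([R0_BLOCK], 1, tl.int32)"
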
def codegen_body(self):
"""

@@ -4015,6 +4166,15 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
):
return

if self._should_use_pointwise_inner_loop():
self._codegen_pointwise_with_inner_loop()
# Clear the buffers since we've used them
self.indexing_code.clear()
self.loads.clear()
self.compute.clear()
self.stores.clear()
return

loop_trees = [tree for tree in self.range_trees if tree.is_loop]
if self.inside_reduction and len(loop_trees) > 0:
# Write the loop headers.

@@ -4105,16 +4265,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
args.append(str(arg))
elif isinstance(arg, SymbolicCallArg):
hint = V.graph.sizevars.size_hint(
arg.inner_expr,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
arg.inner_expr, fallback=config.unbacked_symint_fallback
)
args.append(str(hint))
elif isinstance(arg, sympy.Expr):
hint = V.graph.sizevars.size_hint(
arg,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
arg, fallback=config.unbacked_symint_fallback
)
args.append(str(hint))
else:

@@ -4173,11 +4329,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
)
elif isinstance(arg_sig, SizeArg):
symval_hint = V.graph.sizevars.size_hint(
arg_sig.expr,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
)
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

# Force the seed_offset to be 0 so calls to the same kernel
# using different seed offset will have the same benchmark harness.

@@ -4187,9 +4339,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
result.writeline(f"{var_name} = {symval_hint}")
elif isinstance(arg_sig, WorkspaceArg):
device = V.graph.get_current_device_or_throw()
count = V.graph.sizevars.size_hint(
arg_sig.count, hint_override=self.hint_override
)
count = V.graph.sizevars.size_hint(arg_sig.count)
result.writeline(
f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
)

@@ -4284,7 +4434,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"min_split_scan_rblock": config.triton.min_split_scan_rblock,
"spill_threshold": config.triton.spill_threshold,
"store_cubin": config.triton.store_cubin,
"deterministic": config.deterministic,
}
if torch.version.hip is not None:
inductor_meta["is_hip"] = True

@@ -4315,6 +4464,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
metadata, and benchmarking infra.
"""

from ..utils import prefix_is_reduction

code = IndentedBuffer()

size_hints = {}

@@ -4426,6 +4577,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

add_constexpr_arg(f"{tree.prefix.upper()}BLOCK")

# In codegen_kernel(), replace the R0_BLOCK section with:
if self._should_use_pointwise_inner_loop():
add_constexpr_arg("R0_BLOCK")

if self.cooperative_reduction:
add_constexpr_arg("RSPLIT")

@@ -4457,12 +4612,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
**self.inductor_meta_common(),
}

if config.deterministic or config.test_configs.force_filter_reduction_configs:
inductor_meta["has_loadstore_with_contiguous_rdim"] = (
self.has_load_with_contiguous_rdim
or self.has_store_with_contiguous_rdim
)

# Bail on 3d tiling, which has more complicated coalesce patterns
looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction
tiling_scores = self.tiling_scores

@@ -4538,7 +4687,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

triton_meta["configs"] = [config_of(signature)]

if enable_pdl_codegen():
if config.triton.enable_pdl:
triton_meta["launch_pdl"] = True

# Triton compiler includes equal_to_1 args into constants even

@@ -4990,7 +5139,7 @@ class TritonScheduling(SIMDScheduling):
)
return cls.backend_features

def codegen_comment(self, node_schedule, kernel_name=None):
def codegen_comment(self, node_schedule):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
if origins:

@@ -5016,13 +5165,6 @@ class TritonScheduling(SIMDScheduling):
f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
)

if kernel_name:
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule,  # type: ignore[arg-type]
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)

def define_kernel(self, src_code, node_schedule, kernel):
wrapper = V.graph.wrapper_code
if src_code in wrapper.src_to_kernel:
@@ -17,7 +17,7 @@ import re
import sys
import threading
import time
from collections import defaultdict, namedtuple
from collections import namedtuple
from typing import (
Any,
Callable,

@@ -30,9 +30,8 @@ from typing import (
)

import torch
from torch._dynamo.utils import counters, set_feature_use
from torch._dynamo.utils import set_feature_use
from torch._environment import is_fbcode
from torch._inductor import metrics
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet

@@ -237,7 +236,7 @@ def check_autotune_cache(
not disabled
and filename is not None
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
and os.environ.get("TRITON_INTERPRET", "0") != "1"
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
):
configs_hash = hash_configs(configs)

@@ -312,8 +311,6 @@ class CachingAutotuner(KernelInterface):
"device_type": self.device_props.type,
}
self.inductor_meta = {} if inductor_meta is None else inductor_meta
self.deterministic_mode = self.inductor_meta.get("deterministic", False)

self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.reset_to_zero_arg_names = (

@@ -378,7 +375,7 @@ class CachingAutotuner(KernelInterface):
self.is_backward = False

# Mode for launch grid calculation
self.grid_mode: Literal["python", "cpp"] = "python"
self.grid_mode: Literal["python", "python_slow", "cpp"] = "python"

def is_statically_launchable(self):
"""

@@ -485,8 +482,7 @@ class CachingAutotuner(KernelInterface):
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
device_prop = self.device_props
if (
not self.deterministic_mode
and self.inductor_meta.get("dynamic_scale_rblock", True)
self.inductor_meta.get("dynamic_scale_rblock", True)
and not self.inductor_meta.get("persistent_reduction")
and self.heuristic_type == HeuristicType.REDUCTION
and self.size_hints is not None

@@ -690,17 +686,9 @@ class CachingAutotuner(KernelInterface):

return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))

def _create_compile_meta(self, cfg: Config) -> dict[str, Any]:
"""
Create compilation metadata for a given autotuner config. This involves
processing the Config kwargs so that the kwargs that are not part
of the triton signature are passed in as options to triton.compile
instead
"""
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
"""Ahead of time compile a given autotuner config."""
compile_meta = copy.deepcopy(self.triton_meta)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages

cfg_kwargs = cfg.kwargs
if self.device_props.type == "hip":
cfg_kwargs = {**cfg_kwargs}

@@ -708,13 +696,14 @@ class CachingAutotuner(KernelInterface):
if k in cfg_kwargs:
compile_meta[k] = cfg_kwargs.pop(k)
compile_meta["constants"].update(cfg_kwargs)

for i in self.fn.constexprs:
arg_name = self.fn.arg_names[i]
if arg_name not in compile_meta["constants"] and (
arg_name == "num_warps" or arg_name == "num_stages"
):
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
if HAS_WARP_SPEC:
compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
compile_meta["num_buffers_warp_spec"] = getattr(

@@ -728,53 +717,6 @@ class CachingAutotuner(KernelInterface):
compile_meta["device_type"] = self.device_props.type
compile_meta["cc"] = self.device_props.cc

return compile_meta

def _create_compile_options(
self, cfg: Config, compile_meta: dict[str, Any]
) -> dict[str, Any]:
"""
Create options to pass to triton.compile based on the compile metadata
and the given config.
"""
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
"sanitize_overflow": False,  # turn off additional asserts added for overflow checks
}
if "enable_fp_fusion" in compile_meta:
options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
if HAS_WARP_SPEC:
options.update(
{
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
"num_buffers_warp_spec": compile_meta.get(
"num_buffers_warp_spec", 0
),
}
)
if self.device_props.type == "cuda":
options.update(
{
"launch_cooperative_grid": compile_meta.get(
"launch_cooperative_grid", False
),
"launch_pdl": compile_meta.get("launch_pdl", False),  # True
}
)
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
if "matrix_instr_nonkdim" in compile_meta:
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]

return options

def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
"""Ahead of time compile a given autotuner config."""
compile_meta = self._create_compile_meta(cfg)

if self.device_props.type == "cpu":
triton_helpers.set_driver_to_cpu()
else:

@@ -807,8 +749,37 @@ class CachingAutotuner(KernelInterface):
cc_warp_size(compile_meta["cc"]),
)

options = self._create_compile_options(cfg, compile_meta)

options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
"sanitize_overflow": False,  # turn off additional asserts added for overflow checks
}
if "enable_fp_fusion" in compile_meta:
options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
if HAS_WARP_SPEC:
options.update(
{
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
"num_buffers_warp_spec": compile_meta.get(
"num_buffers_warp_spec", 0
),
}
)
if self.device_props.type == "cuda":
options.update(
{
"launch_cooperative_grid": compile_meta.get(
"launch_cooperative_grid", False
),
"launch_pdl": compile_meta.get("launch_pdl", False),  # True
}
)
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
if "matrix_instr_nonkdim" in compile_meta:
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
compile_kwargs = {
"target": target,
"options": options,

@@ -898,34 +869,25 @@ class CachingAutotuner(KernelInterface):
)
# reset to zero before evaluating any config
self.reset_to_zero_args(*args, **kwargs)
kernel_name = self.inductor_meta.get("kernel_name", "triton kernel")
if autograd_profiler._is_profiler_enabled:
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
with torch._C._profiler._RecordFunctionFast(
kernel_name,
self.inductor_meta.get("kernel_name", "triton kernel"),
cloned_args,
profiler_kwargs,
):
try:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
except Exception:
log.error("Failed during launch %s: ", kernel_name)
raise

else:
try:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
except Exception:
log.error("Failed during launch %s: ", kernel_name)
raise

else:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
self.restore_args_from_cpu(cpu_copies)

# only use profiler when not already in a profiler instance

@@ -1092,18 +1054,6 @@ class CachingAutotuner(KernelInterface):
k.shared,
)

if metrics.is_metric_table_enabled("kernel_autotune"):
if self.fn.fn is None:
self.fn = self._reload_kernel().fn

kernel_path = self.fn.fn.__code__.co_filename
kernel_name = self.fn.__name__

for k, v in timings.items():
metrics.log_kernel_autotune_result(
kernel_path, kernel_name, k.config, v
)

self.reset_to_zero_args(*args, **kwargs)
return timings

@@ -1198,26 +1148,13 @@ class CachingAutotuner(KernelInterface):
Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
"""
if self.heuristic_type in (
HeuristicType.TEMPLATE,
HeuristicType.USER_AUTOTUNE,
HeuristicType.FIXED,
if (
self.heuristic_type == HeuristicType.TEMPLATE
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
):
# skip triton template
return launcher

if self.deterministic_mode and self.heuristic_type in (
HeuristicType.REDUCTION,
HeuristicType.PERSISTENT_REDUCTION,
HeuristicType.SPLIT_SCAN,
):
# Not only RBLOCK size matters for the numerics of a reduction.
# num_warps also matters since that affects how much data
# is handled by each thread, how many warp-reductions we do
# in parallel and how much data is there for block
# reduction.
return launcher

with dynamo_timed(
"CachingAutotuner.coordinate_descent_tuning",
# These generate too many pt2_compile_event logs:

@@ -1251,7 +1188,6 @@ class CachingAutotuner(KernelInterface):
config2launcher[config] = launcher

out = self.bench(launcher, *args, **kwargs)
counters["inductor"]["coordesc_tuning_bench"] += 1
log.debug(
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
launcher.config,

@@ -2574,7 +2510,27 @@ def pointwise(

configs = None
if len(size_hints) == 1:
if disable_pointwise_autotuning(inductor_meta) and not (

use_looped_pointwise = True
has_r0_block = any(
arg == "R0_BLOCK"
for arg in triton_meta.get("signature", {}).keys()
)
if has_r0_block:
configs = [
Config({"XBLOCK": 512, "R0_BLOCK": 256}, num_warps=4, num_stages=1),
Config({"XBLOCK": 1024, "R0_BLOCK": 512}, num_warps=4, num_stages=1),
Config({"XBLOCK": 2048, "R0_BLOCK": 1024}, num_warps=8, num_stages=1),
Config({"XBLOCK": 4096, "R0_BLOCK": 2048}, num_warps=8, num_stages=1),
]
#configs = [
#    triton_config_with_settings(size_hints, 512),
#    triton_config_with_settings(size_hints, 1024),
#    triton_config_with_settings(size_hints, 2048),
#    triton_config_with_settings(size_hints, 4096),
#]

elif disable_pointwise_autotuning(inductor_meta) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
):
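
Note that every config added above keeps XBLOCK an exact multiple of R0_BLOCK, which is what the tl.static_assert emitted by the inner-loop pointwise codegen requires. A quick check:

    looped_pointwise_configs = [
        {"XBLOCK": 512, "R0_BLOCK": 256},
        {"XBLOCK": 1024, "R0_BLOCK": 512},
        {"XBLOCK": 2048, "R0_BLOCK": 1024},
        {"XBLOCK": 4096, "R0_BLOCK": 2048},
    ]
    assert all(c["XBLOCK"] % c["R0_BLOCK"] == 0 for c in looped_pointwise_configs)
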
@@ -2588,7 +2544,36 @@ def pointwise(
*hinted_configs,
]
if len(size_hints) == 2:
if (

use_looped_pointwise = True
has_r0_block = any(
arg == "R0_BLOCK"
for arg in triton_meta.get("signature", {}).keys()
)
if has_r0_block:
# Find which dimension is larger
dim_sizes = [(name, size) for name, size in size_hints.items()]
larger_dim, larger_size = max(dim_sizes, key=lambda x: x[1])
smaller_dim, smaller_size = min(dim_sizes, key=lambda x: x[1])

# Split the larger dimension with R0_BLOCK
larger_block = larger_dim.upper() + "BLOCK"
smaller_block = smaller_dim.upper() + "BLOCK"

configs = [
Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 32}, num_warps=4, num_stages=1),
Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 32}, num_warps=4, num_stages=1),
Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 64}, num_warps=4, num_stages=1),
Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 64}, num_warps=8, num_stages=1),
]
#configs = [
#    triton_config_with_settings(size_hints, 32, 32),
#    triton_config_with_settings(size_hints, 64, 64),
#    triton_config_with_settings(size_hints, 256, 16),
#    triton_config_with_settings(size_hints, 16, 256),
#]

elif (
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
) and not (
inductor_meta.get("max_autotune")

@@ -2878,144 +2863,6 @@ def adapt_config_for_tiling(
)


class ReductionConfigKey:
"""
The part of reduction configs that affects determinism.
"""

def __init__(self, config: Config):
# persistent reduction does not have an RBLOCK, use -1 as a flag
self.r0_block = config.kwargs.get("R0_BLOCK", -1)
self.r1_block = config.kwargs.get("R1_BLOCK", -1)
self.num_warps = config.num_warps
self.num_ctas = config.num_ctas

def __hash__(self) -> int:
return hash((self.r0_block, self.r1_block, self.num_warps, self.num_ctas))

def __eq__(self, other: object) -> bool:
return (
isinstance(other, ReductionConfigKey)
and self.r0_block == other.r0_block
and self.r1_block == other.r1_block
and self.num_warps == other.num_warps
and self.num_ctas == other.num_ctas
)


def filter_reduction_configs_for_determinism(
inductor_meta: dict[str, Any], configs: list[Config]
) -> list[Config]:
"""
Filter configs for reduction so the numerics can be deterministic.

This function groups configs by the fields that affect determinism
- rblock size
- num warps
- num ctas
and returns the most promising group based on heuristics.

Heuristics:
- skip reduction configs with too small an RBLOCK
- skip reduction configs with XBLOCK==1 if we are confident it will not perform well
- pick the group with the largest size: autotuning more configs has a better chance of giving better perf
- if there is a tie, pick the group with the second largest RBLOCK
- if there is still a tie, pick the group with the second largest num_warps
"""
configs = unique_configs(configs)
assert len(configs) > 0

def _do_filter_due_to_inductor_config():
return (
inductor_meta.get("deterministic", False)
or torch._inductor.config.test_configs.force_filter_reduction_configs
)

if not _do_filter_due_to_inductor_config() or len(configs) == 1:
# no filtering happening if NOT in deterministic mode
return configs

if log.isEnabledFor(logging.DEBUG):
log.debug("reduction configs before filtering:")
for c in configs:
log.debug("%s", c)
log.debug("")

def _has_too_small_rblock(config):
rblock = config.kwargs.get("R0_BLOCK")
# too small RBLOCK is likely to be bad
return rblock is not None and rblock <= 4

def _nonpromising_xblock_1(config):
# kernel like https://gist.github.com/shunting314/0b3281c087e79bc915fe45985ff9d7d5
# without a load/store having contiguous rdim is unlikely to perform well with XBLOCK==1
return config.kwargs["XBLOCK"] == 1 and not inductor_meta.get(
"has_loadstore_with_contiguous_rdim", True
)

newconfigs = [*filter(lambda x: not _has_too_small_rblock(x), configs)]
# accept the filtering only if there are configs left
if len(newconfigs) > 0:
configs = newconfigs

newconfigs = [*filter(lambda x: not _nonpromising_xblock_1(x), configs)]
if len(newconfigs) > 0:
configs = newconfigs

groups: defaultdict[ReductionConfigKey, list[Config]] = defaultdict(
list
)  # group configs by RBLOCK, num_warps, num_ctas

for c in configs:
key = ReductionConfigKey(c)
groups[key].append(c)

assert len(groups) > 0

def _pick_group():
grouplist = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
max_group_size = len(grouplist[0][1])
grouplist = [*filter(lambda g: len(g[1]) == max_group_size, grouplist)]

assert len(grouplist) > 0
if len(grouplist) == 1:
return grouplist[0][1]

# break tie by R0_BLOCK
grouplist = sorted(grouplist, key=lambda x: x[0].r0_block)
if grouplist[0][0].r0_block != grouplist[-1][0].r0_block:
max_r0_block = grouplist[-1][0].r0_block
grouplist = [*filter(lambda x: x[0].r0_block != max_r0_block, grouplist)]
second_max_r0_block = grouplist[-1][0].r0_block
grouplist = [
*filter(lambda x: x[0].r0_block == second_max_r0_block, grouplist)
]
if len(grouplist) == 1:
return grouplist[0][1]

# break tie by num_warps
grouplist = sorted(grouplist, key=lambda x: x[0].num_warps)
if grouplist[0][0].num_warps != grouplist[-1][0].num_warps:
max_num_warps = grouplist[-1][0].num_warps
grouplist = [*filter(lambda x: x[0].num_warps != max_num_warps, grouplist)]
second_max_num_warps = grouplist[-1][0].num_warps
grouplist = [
*filter(lambda x: x[0].num_warps == second_max_num_warps, grouplist)
]

# there is still a tie, pick the first one
return grouplist[0][1]

configs = _pick_group()

if log.isEnabledFor(logging.DEBUG):
log.debug("reduction configs after filtering:")
for c in configs:
log.debug("%s", c)
log.debug("")
return configs
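
To make the grouping step of the heuristics above concrete, here is a toy sketch using plain dicts in place of triton Config objects (illustrative only):

    from collections import defaultdict

    toy_configs = [
        {"R0_BLOCK": 8, "num_warps": 4, "num_ctas": 1, "XBLOCK": 1},
        {"R0_BLOCK": 8, "num_warps": 4, "num_ctas": 1, "XBLOCK": 4},
        {"R0_BLOCK": 16, "num_warps": 8, "num_ctas": 1, "XBLOCK": 2},
    ]
    groups = defaultdict(list)
    for cfg in toy_configs:
        groups[(cfg["R0_BLOCK"], cfg["num_warps"], cfg["num_ctas"])].append(cfg)
    # The (8, 4, 1) group wins on size; every surviving config then shares the
    # determinism-relevant fields, so autotuning among them cannot change numerics.
    chosen = max(groups.values(), key=len)
    assert len(chosen) == 2
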
def reduction(
size_hints,
reduction_hint=False,

@@ -3041,7 +2888,6 @@ def reduction(
)

configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,

@@ -3089,7 +2935,6 @@ def cooperative_reduction(
# TODO(jansel): add more configs in max_autotune

configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,

@@ -3141,12 +2986,15 @@ def _persistent_reduction_configs(
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
elif reduction_hint == ReductionHint.INNER:
if rnumel > 1024:
configs = configs[:1]
else:
x_block = 8
if xnumel // x_block < 128 or loads_and_stores >= 5:
if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
# If loads/stores greater than 5, a lot of register pressure
# rnumel < 256 means no vectorized loads if we split up r dim
# so xblock still needs to be larger
x_block = 1

configs = [

@@ -3202,7 +3050,6 @@ def persistent_reduction(
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
inductor_meta.pop(persistent_reduction_key)

configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs,

@@ -3240,7 +3087,6 @@ def split_scan(
cfg.kwargs[var] = min_rblock

configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,

@@ -3380,14 +3226,14 @@ class GridExpr:
"""Generate code for grid size expressions in launcher"""

inductor_meta: dict[str, Any]
mode: Literal["python", "cpp"] = "python"
mode: Literal["python", "cpp", "python_slow"] = "python"
prefix: list[str] = dataclasses.field(default_factory=list)
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1

def __post_init__(self) -> None:
assert self.mode in ("python", "cpp")
assert self.mode in ("python", "cpp", "python_slow")

def generate(self, meta: dict[str, int]) -> None:
raise NotImplementedError

@@ -3403,6 +3249,10 @@ class GridExpr:
# negative integer division is floored
if self.mode == "python":
return f"-(({numel}) // -({block}))"
# This is more generic than above, and works in languages where
# positive integer division is floored/truncated
elif self.mode == "python_slow":
return f"(({numel} + {block} - 1) // ({block}))"
# For cpp code gen
return f"(({numel} + ({block} - 1)) / ({block}))"
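
Both Python spellings above compute ceil(numel / block): the "python" form relies on Python flooring the division of negative numbers, while the new "python_slow" form only assumes floored/truncated division of positives, which is why it is the more portable fallback. A quick check:

    numel, block = 1000, 256
    assert -((numel) // -(block)) == 4             # "python" mode expression
    assert ((numel + block - 1) // (block)) == 4   # "python_slow" mode expression
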
@@ -3411,7 +3261,7 @@ class GridExpr:
items = self._constant_fold(max, seq)
if len(items) <= 1:
return items[0]
if self.mode == "python":
if self.mode in ("python", "python_slow"):
return f"max({', '.join(map(str, items))})"
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)

@@ -3434,7 +3284,7 @@ class GridExpr:

def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
# Grid functions are one per kernel, so name collisions are fine
if self.mode == "python":
if self.mode in ("python", "python_slow"):
return f"{name} = {expr}"
if self.mode == "cpp":
return f"uint32_t {name} = {expr};"

@@ -3444,7 +3294,7 @@ class GridExpr:
def from_meta(
inductor_meta: dict[str, Any],
cfg: Union[Config, dict[str, int]],
mode: Literal["python", "cpp"] = "python",
mode: Literal["python", "cpp", "python_slow"] = "python",
) -> GridExpr:
grid_cls = globals()[inductor_meta["grid_type"]]
assert issubclass(grid_cls, GridExpr)