Compare commits

...

25 Commits

SHA1 Message Date
cdb20620d7 Update signature 2025-11-03 09:56:06 +00:00
354ff68a7e Update number of signature arguments in test 2025-11-03 09:56:06 +00:00
6b6c8a72d9 Revert change to pyrefly comment 2025-11-03 09:56:06 +00:00
b53aade896 Lint 2025-11-03 09:56:06 +00:00
f626a97382 Use argsort_sym, small refactor of remove_dims 2025-11-03 09:56:06 +00:00
e0a4fd9c4f Add transpose tensor descriptor parameter override for templates and fix is_reduction check 2025-11-03 09:56:05 +00:00
43aa5b8483 Add reverse parameter to argsort functions 2025-11-03 09:56:05 +00:00
1128e38bad Fix tests after changing the block_ptr order parameter 2025-11-03 09:56:05 +00:00
bdb250fd06 Add todo on benchmarking transposes applied to tensor descriptors that have compliant strides 2025-11-03 09:56:05 +00:00
c3c67f099d Add override for tma templates, simplify range tree reorder algorithm and generalise stride sorter class 2025-11-03 09:56:05 +00:00
5febc621c3 Lint 2025-11-03 09:56:05 +00:00
0f454ea394 Simplify argsort 2025-11-03 09:56:05 +00:00
efb66325eb Enable transpose by default 2025-11-03 09:56:05 +00:00
f38377e39f Count number of transposes 2025-11-03 09:56:05 +00:00
a4f98825a0 Remove unneeded code 2025-11-03 09:56:05 +00:00
f3bbe21dd3 Remove xfails 2025-11-03 09:56:05 +00:00
022ca38933 Fix docstring 2025-11-03 09:56:05 +00:00
2b8449e314 Fix docstring 2025-11-03 09:56:05 +00:00
00b2e4231e Lint 2025-11-03 09:56:05 +00:00
54423311de Check if min block size is larger than max block 2025-11-03 09:56:05 +00:00
51bfa1685c Lint 2025-11-03 09:56:05 +00:00
75b2514651 Add tests & lint 2025-11-03 09:56:05 +00:00
06293c59ec Reorder range trees rather than transpose 2025-11-03 09:56:05 +00:00
4d18feaef0 Remove breakpoint and fix type 2025-11-03 09:56:05 +00:00
3a2af4e8d1 Transpose discontiguous tensor descriptors 2025-11-03 09:56:05 +00:00
8 changed files with 417 additions and 137 deletions

View File

@ -1928,7 +1928,7 @@ class TestMaxAutotune(TestCase):
# Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching)
# update this function each time new arg added to generate_and_load and make sure arg is added to make_key
self.assertEqual(generate_and_load_args - 1, make_key_args)
self.assertEqual(generate_and_load_args, 18)
self.assertEqual(generate_and_load_args, 19)
@fresh_cache()
@config.patch(
@ -2017,6 +2017,7 @@ class TestMaxAutotune(TestCase):
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30],
'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]",
'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False,
'transpose_discontiguous_tensor_descriptors_override':None,
'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}"""
@ -2056,8 +2057,10 @@ class TestMaxAutotune(TestCase):
"[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"],
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False,'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,
'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}"""
'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False,
'transpose_discontiguous_tensor_descriptors_override':None,
'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,
'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}"""
expected = expected.replace("cuda", GPU_TYPE)
self.assertExpectedInline(
remove_white_space(cache_key),

View File

@ -81,7 +81,6 @@ TMA_TEST_XFAIL = dict.fromkeys(
"test_broadcast_prefer_nd_tiling_False_x_size2_y_size2",
"test_broadcast_prefer_nd_tiling_True_x_size0_y_size0",
"test_broadcast_prefer_nd_tiling_True_x_size2_y_size2",
"test_broadcast_with_singleton_dims",
),
TMA_XFAIL,
)
@ -168,8 +167,6 @@ class BlockDescriptorTestBase(InductorTestCase):
self.assertEqual(len(code), expected_num_programs)
count_code("@triton.jit", expected_num_triton_kernels)
count_code(self.block_descriptor_constructor_str, expected_num_block_pointers)
# Verify that 1D shapes aren't being transposed for the TMA store.
count_code("tl.trans", 0)
return result, code
@ -478,12 +475,12 @@ class CommonTemplate:
self.assertExpectedInline(
load_lines,
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[0, 1], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""", # noqa: B950
)
self.assertExpectedInline(
store_lines,
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[0, 1], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
)
else:
self.assertExpectedInline(
@ -906,7 +903,6 @@ class CommonTemplate:
# Check for 2 reduction dimensions.
self._assert_reduction_ndims(code, 2)
@xfail_if_use_tensor_descriptor # Cannot use TMA API for store with no x dimension.
@test_torchinductor.skip_if_triton_cpu # Illegal instruction File; cannot xfail because it crashes process
def test_2d_reduction_multi_kernel(self):
"""
@ -1017,7 +1013,6 @@ class CommonTemplate:
# Check the code for multiple Rn_BLOCK's
self._assert_reduction_ndims(code, 2 if tile_reductions else 1)
@xfail_if_use_tensor_descriptor
def test_complex_reshape_block_ptr(self):
def func(x, y):
add_ = x + y
@ -1154,13 +1149,13 @@ class CommonTemplate:
self.assertExpectedInline(
load_lines,
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", # noqa: B950
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[0, 1, 2], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[0, 1, 2], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", # noqa: B950
)
self.assertExpectedInline(
store_lines,
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", # noqa: B950
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[0, 1, 2], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", # noqa: B950
)
# Check the indices. These are used for non-block pointers.
@ -1236,7 +1231,6 @@ class CommonTemplate:
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that wild symbols only match integers
@xfail_if_use_tensor_descriptor
@skipIfXpu(
msg="Triton issue exposed by new driver, will be resolved after next triton update."
)
@ -1406,10 +1400,51 @@ test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestGPU, GPU_TYP
"Requires Triton CUDA backend and CUDA compute capability >= 9.0",
)
@config.patch({"triton.use_tensor_descriptor": True, "assume_aligned_inputs": True})
@instantiate_parametrized_tests
class TritonTensorDescriptorTestCUDA(BlockDescriptorTestBase):
block_descriptor_constructor_str = "tl.make_tensor_descriptor"
device = GPU_TYPE
@config.patch({"triton.transpose_discontiguous_tensor_descriptor": True})
@parametrize(
"view_size,permute_order,num_tensor_descriptors,expect_transpose",
[
((128,), (0,), 3, False),
((128, 128), (0, 1), 3, False),
((128, 64), (1, 0), 3, True),
((256, 32, 16), (2, 0, 1), 3, True),
((16, 32, 256), (2, 0, 1), 3, True),
],
)
def test_match_with_transpose(
self,
view_size: tuple[int],
permute_order: tuple[int],
num_tensor_descriptors: int,
expect_transpose: bool,
):
a = self._discontiguous_tensor(view_size, self.device)
pre_permute_size = [1] * len(view_size)
for i, value in zip(permute_order, view_size):
pre_permute_size[i] = value
b = self._discontiguous_tensor(pre_permute_size, self.device)
b = b.permute(permute_order)
def fn(a, b):
return a * b
result, (code,) = self._run_and_compare(
fn,
a,
b,
expected_num_block_pointers=num_tensor_descriptors,
expected_num_triton_kernels=1,
config_patches=tiled_reduction_config,
)
transpose_count = code.count("tl.trans")
self.assertEqual(transpose_count, 1 if expect_transpose else 0)
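The size bookkeeping in this test is easy to misread: `pre_permute_size` is constructed so that permuting `b` by `permute_order` yields the same logical shape as `a`. A minimal standalone sketch of that relationship, using one of the parametrized sizes (toy code, not part of the diff):

```python
import torch

# Mirror the bookkeeping in test_match_with_transpose for one parametrization.
view_size = (256, 32, 16)
permute_order = (2, 0, 1)

pre_permute_size = [1] * len(view_size)
for i, value in zip(permute_order, view_size):
    pre_permute_size[i] = value  # place each target dim at its pre-permute position

b = torch.empty(pre_permute_size)  # shape (32, 16, 256)
assert b.permute(permute_order).shape == view_size
```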
test_torchinductor.copy_tests(
CommonTemplate,

View File

@ -11,9 +11,10 @@ import math
import operator
import os
import textwrap
from abc import abstractmethod
from collections.abc import Iterable, Sequence
from functools import lru_cache
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
import sympy
from sympy.printing.precedence import PRECEDENCE
@ -30,7 +31,7 @@ from torch.utils._triton import has_triton_package, has_triton_stable_tma_api
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from .. import config, ir, metrics, utils
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
from ..debug import set_kernel_post_grad_provenance_tracing
@ -105,9 +106,9 @@ from .wrapper import SymbolicCallArg
if TYPE_CHECKING:
from types import ModuleType
from typing import TypeVar
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from ..ir import IRNode
from .common import BlockShapeType
@ -273,14 +274,6 @@ class TritonSymbols:
assert expr_shape is not None
# Below logic handles when index symbols does not match with convention range tree order.
# Mainly, it is for TMA template where TMA indices are expected to be in (x,y), not (y,x).
# so in such case, the get_block_shape(yindex) should be (1,YBLOCK), not (YBLOCK,1).
if isinstance(V.kernel, torch._inductor.select_algorithm.TritonTemplateKernel):
out_shape = V.kernel.template_out_shape
if out_shape == ("XBLOCK", "YBLOCK") and V.kernel.tma_store:
expr_shape = (expr_shape[1], expr_shape[0], *expr_shape[2:])
return expr_shape
@classmethod
@ -341,6 +334,10 @@ class BlockDescriptorOptions:
broadcast_shape: Sequence[sympy.Expr]
broadcasting_dims: list[bool]
final_shape: Sequence[sympy.Expr]
# If the BlockParameters have been sorted using a particular stride order,
# transpose load / store blocks at runtime using the information in
# stride_sorter.
stride_sorter: BlockParameters.StrideSorter
_boundary_check: Optional[list[int]] = None
# Can we safely lift the constructor
# to the top of the kernel?
@ -371,8 +368,8 @@ class BlockDescriptorOptions:
range_trees: list[IterationRangesRoot],
mask_vars: OrderedSet[str],
get_max_block: Callable[[str], int],
can_lift=False,
transpose_contiguous=False,
stride_sorter_cls: type[BlockParameters.StrideSorter],
can_lift: bool = False,
) -> BlockDescriptorOptions:
"""Helper to create a BlockDescriptorOptions instance"""
@ -385,14 +382,10 @@ class BlockDescriptorOptions:
params.shape = lookup_size(params.shape)
params.strides = lookup_size(params.strides)
# Strip out dimensions of stride 0.
# These will be restored with tl.broadcast_to.
broadcasting_dims = [
sizevars.statically_known_equals(stride, 0) for stride in params.strides
]
# Strip out dimensions of size 1.
# These will be restored by tl.reshape.
# Size 1 dimensions are redundant since the triton kernel shape
# will be e.g. [YBLOCK, XBLOCK], so tl.reshape would just remove these
# dimensions anyway
singleton_dims = [
sizevars.statically_known_equals(dim, 1) for dim in params.block_shape
]
@ -400,44 +393,28 @@ class BlockDescriptorOptions:
# Handle pure singletons, e.g. [1, 1]
singleton_dims[-1] = False
# Drop singleton dimensions from the block descriptor.
params = params.remove_dims(singleton_dims)
# Maybe reorder dimensions based on strides
# with tl.trans applied at load / store time
params, stride_sorter = params.maybe_sort_with_stride_order(
stride_sorter_cls=stride_sorter_cls, shape_env=V.graph._shape_env
)
# Strip out dimensions of stride 0.
# These will be restored with tl.broadcast_to.
broadcasting_dims = [
sizevars.statically_known_equals(stride, 0) for stride in params.strides
]
# Record the post-broadcast shape before broadcasting dims are removed.
# The pre-broadcast shape is identical to this, except broadcasting dims are
# replaced with 1.
broadcast_shape = [
dim
for dim, is_singleton in zip(params.block_shape, singleton_dims)
if not is_singleton
]
broadcast_shape = params.block_shape
# Combine all removable dims.
removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)]
# Remove singleton_dims from broadcasting_dims so that
# broadcast_shape and broadcasting_dims have the same length
broadcasting_dims = [
dim
for dim, is_singleton in zip(broadcasting_dims, singleton_dims)
if not is_singleton
]
def remove_dims(it):
"""Removes any broadcasting or singleton dims from a given sequence"""
return [
item
for item, is_removable in zip(it, removable_dims)
if not is_removable
]
# Drop removable dimensions from the input.
params = BlockParameters(
**{
key: remove_dims(val) for key, val in dataclasses.asdict(params).items()
},
)
# TODO: Generalize to ND tensors.
transpose = transpose_contiguous and params.strides[-1] != 1
if transpose:
params = params.transpose()
# Drop broadcasting dims from the block descriptor.
params = params.remove_dims(broadcasting_dims)
# Compute the final shape, adjusting for special kernel types.
final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees]
@ -445,12 +422,6 @@ class BlockDescriptorOptions:
assert range_trees[0].prefix == "x"
final_shape.pop(0)
# Check for when BlockParams have been transposed.
order = list(reversed(range(len(params.shape))))
if transpose:
final_shape.reverse()
order.reverse()
reduction_ndim = V.kernel.num_reduction_dims
if (
not V.kernel.inside_reduction
@ -460,6 +431,14 @@ class BlockDescriptorOptions:
# Need to expand rank to match the rank used inside the reduction loop
final_shape += [sympy.S.One] * reduction_ndim
try:
# Get permutation to sort strides in descending order.
# This is used as the order argument in tl.make_block_ptr
order = utils.argsort_sym(V.graph._shape_env, params.strides, reverse=True)
except AssertionError:
# Symbolic shapes, failed to evaluate comparison expression
order = list(reversed(range(len(params.strides))))
result = cls(
params=params,
constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
@ -468,6 +447,7 @@ class BlockDescriptorOptions:
final_shape=final_shape,
broadcast_shape=broadcast_shape,
broadcasting_dims=broadcasting_dims,
stride_sorter=stride_sorter,
can_lift=can_lift,
)
result.compute_boundary_check(get_max_block, range_trees)
@ -567,21 +547,55 @@ class BlockDescriptorOptions:
initial_shape: Sequence[sympy.Expr],
final_shape: Sequence[sympy.Expr],
allow_implicit: bool,
for_store: bool,
) -> str:
"""
Generate a broadcast and a reshape for the block descriptor.
This restores stride-0 dimensions which were removed from the block descriptor.
Transposes are also applied to the input using self.stride_sorter:
if for_store is True:
- First broadcast the value. Since self.broadcast_shape is stored in
descending stride order, it must be reverted to the original order,
because the value being stored does not have its dims in descending
stride order.
- Then transpose the broadcasted value so that its dimensions are in
descending stride order.
- Finally reshape to the block shape.
else (for load):
- First broadcast the value to self.broadcast_shape (strides are descending)
- Then transpose the value so that dimensions no longer have descending strides
- Finally reshape the block to the final kernel tile shape
"""
broadcast_shape = self.broadcast_shape
broadcasting_dims = self.broadcasting_dims
# If the block parameters have been sorted by descending strides,
# permute the broadcasting parameters so that they are compatible
# with the value being stored. This is because the dimensions
# of the value being stored are not sorted in descending stride order,
# but the broadcasting parameters are based on the dims in sorted order
if for_store:
broadcast_shape = self.stride_sorter.revert(self.broadcast_shape)
broadcasting_dims = self.stride_sorter.revert(self.broadcasting_dims)
# Reshape to add singletons.
pre_broadcast_shape = [
sympy.S.One if is_broadcasting else dim
for dim, is_broadcasting in zip(
self.broadcast_shape, self.broadcasting_dims
)
for dim, is_broadcasting in zip(broadcast_shape, broadcasting_dims)
]
value = triton_reshape(value, initial_shape, pre_broadcast_shape)
if (
not self.stride_sorter.is_identity
and not for_store
and len(pre_broadcast_shape) == len(final_shape)
):
# If all we need to do is transpose to match the final shape
# with implicit broadcasting, then we don't need an explicit broadcast
# unless the caller requests it. So just test implicit broadcast support
# with the transposed pre-broadcast shape.
pre_broadcast_shape = self.stride_sorter.revert(pre_broadcast_shape)
# Broadcast singletons.
# For loads, we can often implicitly broadcast singleton dimensions.
# We need an explicit broadcast for stores, or if the final reshape does more
@ -597,10 +611,32 @@ class BlockDescriptorOptions:
)
if any(self.broadcasting_dims) and not supports_implicit_broadcast:
value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})"
value = (
f"tl.broadcast_to({value}, {V.kernel.index_to_str(broadcast_shape)})"
)
old_shape = self.broadcast_shape
if not self.stride_sorter.is_identity:
# if for_store, the transform is
# (non-descending strides) broadcasted kernel tile shape
# -> (descending strides) block descriptor shape
# otherwise (load), the transform is
# (descending strides) (maybe implicitly) broadcasted block shape
# -> (non-descending strides) (maybe implicitly) broadcasted kernel tile shape
permute_dims = (
self.stride_sorter.sort_idx
if for_store
else self.stride_sorter.revert_sort_idx
)
value = f"tl.trans({value}, {permute_dims})"
old_shape = (
self.broadcast_shape
if for_store
else self.stride_sorter.revert(self.broadcast_shape)
)
# Reshape to the final shape.
value = triton_reshape(value, self.broadcast_shape, final_shape)
value = triton_reshape(value, old_shape, final_shape)
return value
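To make the ordering described in the docstring above concrete, here is a minimal sketch of the load-side string rewriting: broadcast while dims are in descending-stride order, transpose back to kernel order, then reshape to the final tile shape. The names and shapes are hypothetical and this is not the actual Inductor code path:

```python
# Hypothetical illustration of the load-side transform order described above.
def load_transform(value, broadcast_shape, revert_sort_idx, final_shape):
    value = f"tl.broadcast_to({value}, {broadcast_shape})"  # dims in descending-stride order
    value = f"tl.trans({value}, {revert_sort_idx})"         # undo the stride sort
    value = f"tl.reshape({value}, {final_shape})"           # final kernel tile shape
    return value

print(load_transform("desc.load([y, x])", "[XBLOCK, YBLOCK]", [1, 0], "[YBLOCK, XBLOCK]"))
```

For stores the same three steps run with the shapes reverted first and the transpose using sort_idx instead, as the for_store branch of the docstring describes.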
@ -1984,6 +2020,99 @@ class BlockParameters:
strides: list[sympy.Expr] = dataclasses.field(default_factory=list)
offsets: list[sympy.Expr] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class StrideSorter:
original_strides: list[int]
sort_idx: list[int]
revert_sort_idx: list[int] = dataclasses.field(init=False)
def __post_init__(self):
assert len(self.original_strides) > 0
assert len(self.sort_idx) == len(self.original_strides)
identity_sort_idx = list(range(len(self.original_strides)))
self._is_identity = self.sort_idx == identity_sort_idx
# Set revert_sort_idx
sorted_dims_by_strides_map = {k: i for i, k in enumerate(self.sort_idx)}
self.revert_sort_idx = [
sorted_dims_by_strides_map[i]
for i in range(len(sorted_dims_by_strides_map))
]
@property
def is_identity(self):
return self._is_identity
@classmethod
@abstractmethod
def create(
cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
) -> BlockParameters.StrideSorter:
"""Create a `StrideSorter` that can be used to sort block parameters."""
def sort(self, attr):
if not self.is_identity:
return [attr[i] for i in self.sort_idx]
return attr
def revert(self, attr):
if not self.is_identity:
return [attr[i] for i in self.revert_sort_idx]
return attr
@dataclasses.dataclass
class IdentityStrideSorter(StrideSorter):
def __post_init__(self):
super().__post_init__()
@classmethod
def create(
cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
) -> BlockParameters.StrideSorter:
return cls(
original_strides=original_strides,
sort_idx=list(range(len(original_strides))),
)
@dataclasses.dataclass
class TensorDecriptorStrideSorter(StrideSorter):
"""
Sorts BlockParameters dimensions with strides in descending order.
"""
def __post_init__(self):
super().__post_init__()
@classmethod
def create(
cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv
) -> BlockParameters.StrideSorter:
"""
If the strides are not all known constants or if the strides are already
sorted in descending order, return identity sort.
For example, if block_shape @ strides is [ZBLOCK, XBLOCK, YBLOCK] @ [8, 1, 16],
the indices to sort the strides in descending order will be [2, 0, 1].
The indices to revert back to the original order will be [1, 2, 0].
"""
identity_sort = list(range(len(original_strides)))
try:
# TODO: even if the strides are not in descending order the strides
# may be tensor descriptor compliant
# i.e. innermost stride == 1 and outer strides 16 byte aligned
# We should benchmark the effect of applying a transpose to these
# cases vs leaving them unsorted.
sort_idx = utils.argsort_sym(shape_env, original_strides, reverse=True)
except AssertionError:
# Symbolic shapes, failed to evaluate comparison expression
sort_idx = identity_sort
return cls(
original_strides=original_strides,
sort_idx=sort_idx,
)
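The sort/revert index pair from the docstring above can be verified with a few lines of plain Python (constant strides only, no sympy; a standalone check, not part of the patch):

```python
# Check the docstring example: strides [8, 1, 16].
strides = [8, 1, 16]
sort_idx = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
revert_sort_idx = [sort_idx.index(i) for i in range(len(sort_idx))]

print(sort_idx)         # [2, 0, 1] -> strides become [16, 8, 1]
print(revert_sort_idx)  # [1, 2, 0]

sorted_strides = [strides[i] for i in sort_idx]
assert [sorted_strides[i] for i in revert_sort_idx] == strides
```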
def __add__(self, other: BlockParameters) -> BlockParameters:
"""
Concatenates block parameters.
@ -1992,12 +2121,37 @@ class BlockParameters:
a, b = tuple(dataclasses.asdict(x) for x in (self, other))
return cls(**{key: a[key] + b[key] for key in a})
def transpose(self) -> BlockParameters:
def maybe_sort_with_stride_order(
self, stride_sorter_cls: type[StrideSorter], shape_env: ShapeEnv
) -> tuple[BlockParameters, BlockParameters.StrideSorter]:
"""
Sort `BlockParameter` with stride_sorter_cls. Returns block parameters
as well as a `StrideSorter` which contains information on how the sort
can be reverted.
"""
stride_sorter = stride_sorter_cls.create(self.strides, shape_env=shape_env)
params = BlockParameters(
**{
key: stride_sorter.sort(val)
for key, val in dataclasses.asdict(self).items()
}
)
return params, stride_sorter
def remove_dims(self, removable_dims: list[bool]) -> BlockParameters:
"""
Remove dimensions where removable_dims is True.
"""
def filter_dims(it):
return [
item
for item, is_removable in zip(it, removable_dims)
if not is_removable
]
return BlockParameters(
self.shape[::-1],
self.block_shape[::-1],
self.strides[::-1],
self.offsets[::-1],
**{key: filter_dims(val) for key, val in dataclasses.asdict(self).items()},
)
@ -2131,8 +2285,9 @@ class TMACompatibilityChecker:
# and that the outer strides are 16 byte aligned
if not V.graph.sizevars.statically_known_equals(strides[-1], sympy.Integer(1)):
log.debug(
"%s TMA API requires innermost stride to be 1.",
"%s TMA API requires innermost stride to be 1. Strides are: %s",
self.failed_debug_prefix,
strides,
)
return False
@ -2143,8 +2298,10 @@ class TMACompatibilityChecker:
sympy.Integer(0),
):
log.debug(
"%s TMA API requires outer strides to be 16 byte aligned.",
"%s TMA API requires outer strides to be 16 byte aligned. Dtype bytes: %d, strides: %s",
self.failed_debug_prefix,
element_size,
strides,
)
return False
@ -2153,6 +2310,18 @@ class TMACompatibilityChecker:
# can be loaded / stored.
# Start with finding the innermost block type
innermost_block_shape = block_params.block_shape[-1]
# Pure singleton case
if V.graph.sizevars.statically_known_equals(
innermost_block_shape, sympy.Integer(1)
):
log.debug(
"%s innermost block shape cannot load 16 bytes. Block shape: %s",
self.failed_debug_prefix,
block_params.block_shape,
)
return False
innermost_block_type = None
innermost_block_symt = None
for block_type_str in innermost_block_shape.free_symbols:
@ -2161,6 +2330,7 @@ class TMACompatibilityChecker:
innermost_block_type = block_type_str
innermost_block_symt = block_symt
break
assert innermost_block_type and innermost_block_symt, (
f"{innermost_block_shape} expr must contain a single block type from {TritonSymbols.block_types}"
)
@ -2189,8 +2359,10 @@ class TMACompatibilityChecker:
innermost_block_bytes, sympy.Integer(16)
):
log.debug(
"%s persistent reduction innermost block shape cannot load 16 bytes.",
"%s persistent reduction innermost block shape cannot load 16 bytes. Block shape: %s, persistent RBLOCK: %d",
self.failed_debug_prefix,
block_params.block_shape,
persistent_rblock,
)
return False
@ -2199,17 +2371,45 @@ class TMACompatibilityChecker:
# then the TMA API can only be used if the dtype has an 8 byte element
# size so that 16 bytes of data can be loaded in the innermost dimension
try:
def indexing_div_rep(
x: sympy.Expr,
y: sympy.Expr,
z: Optional[sympy.Expr] = None,
) -> sympy.Expr:
div = x / y
if z:
div = div % z
return div
solve_expr = innermost_block_shape * element_size - 16
# Sympy cannot handle FloorDiv and ModularIndexing well, so simplify
solve_expr_simplified = solve_expr.replace(
FloorDiv, indexing_div_rep
).replace(ModularIndexing, indexing_div_rep)
min_block_size = next_power_of_2(
int(
sympy.nsolve(
innermost_block_shape * element_size - 16,
solve_expr_simplified,
innermost_block_type,
1,
)
)
)
block_type_str = V.kernel.index_to_str(innermost_block_type)
# TODO: min block size may be too large / introduce redundancy
if min_block_size > self.kernel.max_block(
prefix_str[innermost_block_symt]
):
log.debug(
"%s the minimum block size to satisfy expression %s is too large: %d",
self.failed_debug_prefix,
solve_expr_simplified,
min_block_size,
)
return False
block_type_str = self.kernel.index_to_str(innermost_block_type)
# Check block sizes if the user has provided a fixed triton config
if self.kernel.fixed_config:
if min_block_size > self.kernel.fixed_config[block_type_str]:
@ -2232,8 +2432,9 @@ class TMACompatibilityChecker:
except ValueError:
log.debug(
"%s innermost block shape cannot load 16 bytes.",
"%s innermost block shape cannot load 16 bytes. Block params: %s",
self.failed_debug_prefix,
block_params.block_shape,
)
return False
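The minimum-block-size logic above boils down to a small sympy computation: solve innermost_block_shape * element_size - 16 = 0 for the block symbol, then round up to a power of two. A self-contained sketch, assuming a hypothetical float32 element size of 4 bytes:

```python
import sympy

XBLOCK = sympy.Symbol("XBLOCK", positive=True)
element_size = 4  # hypothetical: float32
solve_expr = XBLOCK * element_size - 16

root = sympy.nsolve(solve_expr, XBLOCK, 1)           # -> 4.0
min_block_size = 1 << (int(root) - 1).bit_length()   # next power of two -> 4
print(min_block_size)
```

The kernel then rejects the TMA path when this minimum exceeds the maximum (or user-fixed) block size for that dimension, which is the new check added in this hunk.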
@ -2266,6 +2467,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
tensor_descriptor_options_cls: type[TensorDescriptorOptions] = (
TensorDescriptorOptions
)
transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None
def __init__(
self,
@ -2736,17 +2938,39 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else self.tensor_descriptor_options_cls
)
nonlocal tma_compatibility_checker
stride_sorter_cls: type[BlockParameters.StrideSorter]
if config.triton.use_block_ptr:
can_lift = False
transpose_contiguous = False
stride_sorter_cls = BlockParameters.IdentityStrideSorter
else:
tma_compatibility_checker = cast(
TMACompatibilityChecker, tma_compatibility_checker
)
can_lift = tma_compatibility_checker.can_lift()
if (
self.transpose_discontiguous_tensor_descriptors_override
is not None
):
transpose_contiguous = (
self.transpose_discontiguous_tensor_descriptors_override
)
else:
transpose_contiguous = (
config.triton.transpose_discontiguous_tensor_descriptor
)
# For templates:
# Only try transpose if we know the output shape
# in case we need to transpose the data.
transpose_contiguous = copy_shape is not None
if hasattr(self, "template_out_shape"):
transpose_contiguous &= copy_shape is not None
stride_sorter_cls = (
BlockParameters.TensorDecriptorStrideSorter
if transpose_contiguous
else BlockParameters.IdentityStrideSorter
)
options = options_class.create(
params=block_params,
@ -2755,9 +2979,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
mask_vars=mask_vars,
get_max_block=self.max_block,
can_lift=can_lift,
transpose_contiguous=transpose_contiguous,
stride_sorter_cls=stride_sorter_cls,
)
if isinstance(options_class, TensorDescriptorOptions):
if isinstance(options, TensorDescriptorOptions):
tma_compatibility_checker = cast(
TMACompatibilityChecker, tma_compatibility_checker
)
@ -3005,30 +3229,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
return block_descriptor, other
def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
def stringify_shape(shape):
return tuple(
symt.name if isinstance(symt, sympy.Symbol) else str(symt)
for symt in shape
)
if value.shape:
value_forward_shape = stringify_shape(value.shape)
value_reverse_shape = stringify_shape(value.shape[::-1])
else:
value_forward_shape = None
value_reverse_shape = None
final_shape = stringify_shape(indexing.final_shape)
# TODO: Generalize to N Dimensions
if (
value_forward_shape != final_shape
and value_reverse_shape == final_shape
and len(final_shape) == 2
):
# TMA stores may require transposing the data to ensure we are contiguous along
# the final dimension. This applies to Block-pointers generally, but should only practically
# be reached with TMA.
value = f"tl.trans({value})"
# Stores require an explicit broadcast. We do this in two phases:
# 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
# YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
@ -3044,7 +3244,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
indexing.broadcasting_dims[idx] = False
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
value,
indexing.final_shape,
indexing.block_shape,
allow_implicit=False,
for_store=True,
)
# workaround https://github.com/triton-lang/triton/issues/2814
@ -3236,7 +3440,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else:
line = f"{block_descriptor}.load({V.kernel.index_to_str(indexing.offsets)})"
line = indexing.codegen_broadcast_and_reshape(
line, indexing.block_shape, indexing.final_shape, True
line,
indexing.block_shape,
indexing.final_shape,
allow_implicit=True,
for_store=False,
)
shape = indexing.final_shape
elif is_sympy_integer_like(original_index):

View File

@ -1508,6 +1508,11 @@ class triton:
# can be satisfied, along with any existing requirements for index expressions
use_tensor_descriptor = False
# (Experimental)
# Whether to allow reordering the dimensions of matched tensor descriptors into
# descending-stride order, at the expense of transposing values after load / before store.
transpose_discontiguous_tensor_descriptor = True
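For reference, the test added earlier in this diff toggles the flag with config.patch; the same pattern can be used to disable the rewrite, for example when debugging (flag name taken from the hunk above; usage is a sketch):

```python
from torch._inductor import config

with config.patch({"triton.transpose_discontiguous_tensor_descriptor": False}):
    ...  # compile / run with the transpose rewrite disabled
```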
# Inject a bug into our relu implementation; useful for testing our repro
# extraction and minification functionality.
# Valid values: "compile_error", "runtime_error", "accuracy"

View File

@ -2522,7 +2522,7 @@ def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Conf
if inductor_meta.get("persistent_reduction"):
tma_min_block_sizes = {
block_type: block_size
for block_type, block_size in tma_min_block_sizes
for block_type, block_size in tma_min_block_sizes.items()
if not prefix_is_reduction(block_type.lower())
}

View File

@ -390,6 +390,7 @@ class TritonTemplateKernel(TritonKernel):
num_buffers_warp_spec=0,
use_jit=False,
tma_store=False,
transpose_discontiguous_tensor_descriptors_override=None,
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
@ -420,6 +421,28 @@ class TritonTemplateKernel(TritonKernel):
features=SIMDKernelFeatures([], numel),
hint_override=hint_override,
)
if tma_store:
# By default `construct_range_trees` will return the range_trees in the order
# ["z", "y", "x", "r0_", "r1_"] (see simd.py:all_prefixes)
# and this order defines what the kernel block shape will be. So if the template
# input / output has requested e.g. ["x", "y"], `construct_range_trees` will still return the
# trees in the order ["y", "x"]. This would mean that the template would need to transpose
# the loaded value.
# The code below sorts the range trees into the order required by the caller.
prefix_to_range_tree = {rt.prefix: rt for rt in self.range_trees}
pw_sorted_range_trees = []
reduction_idx = None
for i, prefix in enumerate(tiling):
rt = prefix_to_range_tree[prefix]
if rt.is_reduction:
reduction_idx = i
break
rt.index = i
rt.grid_dim = i
rt.tensor_dim = i
pw_sorted_range_trees.append(rt)
self.range_trees = pw_sorted_range_trees + self.range_trees[reduction_idx:]
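A toy sketch of the effect of this reordering on the prefixes, assuming a hypothetical default order of ["y", "x", "r0_"] and a template that requests ["x", "y"]:

```python
range_trees = ["y", "x", "r0_"]   # default order from construct_range_trees
tiling = ["x", "y", "r0_"]        # order requested by the template caller

pointwise = []
for prefix in tiling:
    if prefix.startswith("r"):
        break                     # stop at the first reduction prefix
    pointwise.append(prefix)

reordered = pointwise + [p for p in range_trees if p.startswith("r")]
print(reordered)                  # ['x', 'y', 'r0_']
```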
self.input_nodes = input_nodes
self.output_node = output_node
self.named_input_nodes = {} # type: ignore[var-annotated]
@ -427,6 +450,9 @@ class TritonTemplateKernel(TritonKernel):
self.kernel_name = kernel_name
self.use_jit = use_jit
self.tma_store = tma_store
self.transpose_discontiguous_tensor_descriptors_override = (
transpose_discontiguous_tensor_descriptors_override
)
self.num_stages = num_stages
self.num_warps = num_warps
self.num_consumer_groups = num_consumer_groups
@ -1170,13 +1196,8 @@ class TritonTemplateKernel(TritonKernel):
intermediate_lines: list[str] = []
epilogue_index_symbols: list[sympy.Symbol] = []
if self.tma_store:
# Generate the expected indexing symbols.
# Note: TMA indices are expected to be in the
# format (x, y), but the range_tree is always
# (yindex, xindex).
index_order = [1, 0]
val_shape_copy = list(val_shape)
for i, range_tree in zip(index_order, self.range_trees[:-1]):
for i, range_tree in enumerate(self.range_trees[:-1]):
name = range_tree.name
symbol = range_tree.symbol()
epilogue_index_symbols.append(symbol)
@ -1197,7 +1218,7 @@ class TritonTemplateKernel(TritonKernel):
index_symbols[i],
val_shape[i],
i,
len(index_order),
len(val_shape),
# pyrefly: ignore [missing-argument]
block_name=range_tree.symt.name,
)
@ -1214,10 +1235,6 @@ class TritonTemplateKernel(TritonKernel):
# after the remapping.
# pyrefly: ignore [missing-argument]
val_shape_copy[i] = range_tree.symt.name
# Reverse the index symbols because TMA is indexed
# as (x, y) whereas the variables will naturally be indexed
# as (y, x)
epilogue_index_symbols.reverse()
val_shape = tuple(val_shape_copy)
else:
mask_vars: list[str] = []
@ -1561,6 +1578,7 @@ class GeneratedCodeCache:
epilogue_fn: Optional[Callable[..., Any]],
epilogue_fn_hash: Optional[str],
tma_store: bool,
transpose_discontiguous_tensor_descriptors_override: Optional[bool],
subgraphs: Optional[list[ir.Buffer]], # has to be none to cache
workspace_arg: Optional[WorkspaceArg], # has to be none to cache
layout: ir.Layout,
@ -1618,6 +1636,7 @@ class GeneratedCodeCache:
"num_buffers_warp_spec": num_buffers_warp_spec,
"epilogue_fn_hash": epilogue_fn_hash,
"tma_store": tma_store,
"transpose_discontiguous_tensor_descriptors_override": transpose_discontiguous_tensor_descriptors_override,
"kwargs": kwargs,
"hint_override": hint_override,
}
@ -1733,6 +1752,7 @@ class TritonTemplate(KernelTemplate):
generate_with_caching,
hint_override: Optional[int] = None,
tma_store: bool = False,
transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None,
) -> Optional[GenerateAndLoadResult]:
"""Generate the python code and load it into the current process"""
caching_enabled = (
@ -1752,6 +1772,7 @@ class TritonTemplate(KernelTemplate):
epilogue_fn,
epilogue_fn_hash,
tma_store,
transpose_discontiguous_tensor_descriptors_override,
subgraphs,
workspace_arg,
layout,
@ -1812,6 +1833,7 @@ class TritonTemplate(KernelTemplate):
use_jit=False,
hint_override=hint_override,
tma_store=tma_store,
transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override,
**kernel_options,
)
@ -1933,6 +1955,7 @@ class TritonTemplate(KernelTemplate):
generate_with_caching=False,
hint_override: Optional[int] = None,
tma_store: bool = False,
transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None,
**kwargs,
):
"""This function generates a TritonTemplateCaller
@ -1979,6 +2002,7 @@ class TritonTemplate(KernelTemplate):
generate_with_caching and self._cache_codegen_enabled_for_template,
hint_override=hint_override,
tma_store=tma_store,
transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override,
)
# May happen as result of dev by 0.
@ -2042,6 +2066,7 @@ class TritonTemplate(KernelTemplate):
use_jit=False,
hint_override=hint_override,
tma_store=tma_store,
transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override,
**options,
)
render = functools.partial(

View File

@ -1773,6 +1773,7 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
"tma_store": config.triton.enable_template_tma_store,
"transpose_discontiguous_tensor_descriptors_override": True,
}
# Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl(

View File

@ -1344,15 +1344,18 @@ clear_inductor_caches = clear_caches
fresh_inductor_cache = fresh_cache
def argsort(seq: Sequence[Any]) -> list[int]:
# preserve original order for equal strides
def argsort(seq: Sequence[Any], *, reverse: bool = False) -> list[int]:
getter = seq.__getitem__
a_r = range(len(seq))
return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
# sorted preserves order when elements compare equal
return list(sorted(a_r, key=getter, reverse=reverse)) # noqa: C413
def argsort_sym(
shape_env: ShapeEnv, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
shape_env: ShapeEnv,
seq: Sequence[Union[int, torch.SymInt, sympy.Expr]],
*,
reverse: bool = False,
) -> list[int]:
def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int:
a_idx, a_val = a
@ -1382,7 +1385,7 @@ def argsort_sym(
(idx, s.node.expr if isinstance(s, torch.SymInt) else s)
for idx, s in enumerate(seq)
]
exprs = sorted(exprs, key=functools.cmp_to_key(cmp))
exprs = sorted(exprs, key=functools.cmp_to_key(cmp), reverse=reverse)
result = [idx for idx, _ in exprs]
return result
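The reverse=True path that the block-pointer codegen now relies on can be sanity-checked with plain integer strides; because sorted is stable, equal strides keep their original relative order even when reversed. A standalone re-implementation mirroring the new argsort:

```python
def argsort(seq, *, reverse=False):
    # sorted preserves the original order of elements that compare equal
    return sorted(range(len(seq)), key=seq.__getitem__, reverse=reverse)

print(argsort([8, 1, 16], reverse=True))  # [2, 0, 1]
print(argsort([4, 4, 1], reverse=True))   # [0, 1, 2] -> ties keep their order
```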