[Triton] [Inductor] Enable TMA store for TMA mm templates (#160480)

Summary:
Adds support for TMA stores in all TMA matmul templates (notably `persistent_tma`, including addmm and scaled_mm). This works by requiring that a template be registered with `tma_store=True`; when that flag is set, the template constructs indices/range_trees that hook into the existing codebase's TMA store support.
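A rough, standalone sketch of how the flag gates the epilogue store; the `TemplateConfig`/`OutputLayout`/`choose_store` names below are hypothetical stand-ins for illustration, not the actual Inductor classes:

```python
from dataclasses import dataclass

TMA_ALIGNMENT = 16  # bytes


@dataclass
class TemplateConfig:
    name: str
    tma_store: bool = False  # templates must opt in explicitly


@dataclass
class OutputLayout:
    offset_bytes: int   # byte offset of the output buffer
    inner_stride: int   # stride (in elements) of the would-be contiguous dim


def choose_store(template: TemplateConfig, out: OutputLayout) -> str:
    """Pick the epilogue store for a TMA matmul template (illustrative only)."""
    output_ok = out.inner_stride == 1 and out.offset_bytes % TMA_ALIGNMENT == 0
    if template.tma_store and output_ok:
        # Construct indices/range_trees so codegen hits the existing TMA store path.
        return "tma_store"
    return "regular masked store"


print(choose_store(TemplateConfig("persistent_tma", tma_store=True),
                   OutputLayout(offset_bytes=0, inner_stride=1)))  # -> "tma_store"
```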

This also includes a few notable changes:
- Adds output-layout checking to the TMA template support.
- Adds support for "hoisting" the tensor descriptor to the top of the kernel. This is currently only used by template code, but in principle it can be generalized to other implementations.
- Supports treating more than one index as the "contiguous" index. This is handled by transposing the input data when the alignment is no longer consistent. Since the TMA support is derived from the index, it does not seem reasonable for the 1D index math to force a particular alignment based on index ordering, as long as the layout matches (see the sketch after this list).
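
For intuition on that last point, a minimal standalone sketch (the `descriptor_dims` helper is hypothetical, not Inductor code): for a 2D output either dimension may carry the unit stride, and the template can describe the transposed view so the descriptor's innermost dimension stays contiguous.

```python
def descriptor_dims(shape, strides):
    """Put the unit-stride ("contiguous") dim last, transposing the 2D view
    when it is dim 0 (illustrative only; the real templates rewrite
    indices/range_trees rather than calling a helper like this)."""
    if strides[-1] == 1:   # row-major output: already contiguous in the last dim
        return shape, strides, False
    if strides[0] == 1:    # column-major output: describe the transposed view
        return tuple(reversed(shape)), tuple(reversed(strides)), True
    raise ValueError("no unit-stride dim: not TMA compatible")


print(descriptor_dims((128, 64), (64, 1)))   # ((128, 64), (64, 1), False)
print(descriptor_dims((128, 64), (1, 128)))  # ((64, 128), (128, 1), True)
```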

Test Plan:
Tested with test_max_autotune.py unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160480
Approved by: https://github.com/NikhilAPatel
Author: Nick Riasanovsky
Date: 2025-09-14 04:56:49 +00:00
Committed by: PyTorch MergeBot
Commit: 74a35c6344 (parent d2f6daf6a7)
8 changed files with 538 additions and 100 deletions


@@ -1664,7 +1664,9 @@ def use_triton_template(
     )
 
 
-def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
+def can_use_tma(
+    *matrices: IRNode, output_layout: Optional[Layout] = None, add_guards: bool = False
+) -> bool:
     """
     Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
     that Triton relies on today.
@@ -1686,11 +1688,37 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
     def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
         return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
 
-    def _is_tma_compatible_default(x: IRNode) -> bool:
-        sizes = x.get_size()
-        strides = x.get_stride()
+    def _is_tma_compatible_layout(layout: Optional[Layout]) -> bool:
+        if layout is None:
+            return True
+        sizes = layout.size
+        strides = layout.stride
+        dtype = layout.dtype
+
+        # Verify the output is 16-byte aligned
+        if not _aligned(layout.offset):
+            return False
+
+        return _is_tma_compatible(sizes, strides, dtype, allow_float32=True)
+
+    def _is_tma_compatible_matrix(m: IRNode) -> bool:
+        sizes = m.get_size()
+        strides = m.get_stride()
+        dtype = m.get_dtype()
+
+        # Base pointer 16-byte aligned
+        if m.get_name() in V.graph.unaligned_buffers:
+            return False
+
+        return _is_tma_compatible(sizes, strides, dtype, allow_float32=False)
+
+    def _is_tma_compatible(
+        sizes: Sequence[sympy.Expr],
+        strides: Sequence[_IntLike],
+        dtype: torch.dtype,
+        allow_float32: bool,
+    ) -> bool:
         rank = len(sizes)
-        dtype = x.get_dtype()
         itemsize = dtype.itemsize
 
         # 2 ≤ rank ≤ 5
@@ -1698,11 +1726,9 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
             return False
 
         # dtype ∈ {FP16, BF16, FP8-E4M3FN}
-        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
-            return False
-
-        # Base pointer 16-byte aligned
-        if x.get_name() in V.graph.unaligned_buffers:
+        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn) and (
+            not allow_float32 or dtype != torch.float32
+        ):
             return False
 
         if add_guards:
@@ -1746,31 +1772,20 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
         return True
 
-    def _is_tma_compatible_xpu(x: IRNode) -> bool:
-        strides = x.get_stride()
-        strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
-
-        # Find the single contiguous (“inner”) dim
-        inner = [
-            i
-            for i, st in enumerate(strides_i)
-            if V.graph.sizevars.statically_known_equals(st, 1)
-        ]
-        if len(inner) != 1:
-            return False
-
-        return True
-
-    return has_triton_tma_device() and all(
-        _is_tma_compatible_default(m)
-        if (m_device := m.get_device()) is None or m_device.type != "xpu"
-        else _is_tma_compatible_xpu(m)
-        for m in matrices
+    return (
+        has_triton_tma_device()
+        and all(_is_tma_compatible_matrix(m) for m in matrices)
+        and _is_tma_compatible_layout(output_layout)
     )
 
 
-def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
+def use_triton_tma_template(
+    *matrices: IRNode, output_layout: Layout, add_guards: bool = False
+) -> bool:
+    layout = output_layout if config.triton.enable_template_tma_store else None
     return (
         all(len(m.get_size()) == 2 for m in matrices)
-        and can_use_tma(*matrices, add_guards=add_guards)
+        and can_use_tma(*matrices, output_layout=layout, add_guards=add_guards)
         and config.triton.enable_persistent_tma_matmul
     )
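
For reference, a caller-side sketch against the new `use_triton_tma_template` signature; the `wants_persistent_tma` wrapper is illustrative, not an excerpt from the mm lowerings:

```python
from torch._inductor.utils import use_triton_tma_template


def wants_persistent_tma(mat1, mat2, output_layout) -> bool:
    # The output layout is only checked when
    # config.triton.enable_template_tma_store is on; otherwise
    # use_triton_tma_template forwards output_layout=None to can_use_tma,
    # which skips the layout check entirely.
    return use_triton_tma_template(mat1, mat2, output_layout=output_layout)
```

Note that `_is_tma_compatible` only admits `torch.float32` when called with `allow_float32=True`, i.e. for the output layout; the input matrices keep the FP16/BF16/FP8-E4M3FN restriction.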