Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Triton] [Inductor] Enable TMA store for TMA mm templates (#160480)
Summary: Adds support for TMA stores in all TMA matmul templates (notably persistent_tma, including addmm and scaled_mm). This works by requiring that a template be registered with `tma_store=True`; when that is set, the template constructs the indices/range_trees needed to hook into the existing codebase's TMA store support. Notable changes:

- Adds support in the TMA template infrastructure for checking the output layout.
- Adds support for "hoisting" the tensor descriptor to the top of the kernel. This is currently used only by template code, but in principle it can be generalized to other implementations.
- Supports treating multiple indices as the "contiguous" index. This is handled by transposing the input data when the alignment is no longer consistent. Since the TMA support is derived from the index, it does not seem reasonable for the 1D index math to force a particular alignment based on index ordering, as long as the layout matches.

Test Plan: Tested with the test_max_autotune.py unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160480
Approved by: https://github.com/NikhilAPatel
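For context, a minimal usage sketch of the feature described above, assuming a CUDA device and a Triton build with TMA support; the shapes, dtypes, and explicit flag settings are illustrative and are not taken from this PR:

import torch
from torch._inductor import config

# Opt in to the persistent TMA matmul templates; with this change, stores can
# also go through TMA when the output layout qualifies.
config.triton.enable_persistent_tma_matmul = True
config.triton.enable_template_tma_store = True

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)

# max-autotune lets Inductor include the Triton TMA templates among its choices.
compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)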
Committed by: PyTorch MergeBot
Parent: d2f6daf6a7
Commit: 74a35c6344
@@ -1664,7 +1664,9 @@ def use_triton_template(
     )
 
 
-def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
+def can_use_tma(
+    *matrices: IRNode, output_layout: Optional[Layout] = None, add_guards: bool = False
+) -> bool:
     """
     Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
     that Triton relies on today.
@@ -1686,11 +1688,37 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
     def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
         return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
 
-    def _is_tma_compatible_default(x: IRNode) -> bool:
-        sizes = x.get_size()
-        strides = x.get_stride()
+    def _is_tma_compatible_layout(layout: Optional[Layout]) -> bool:
+        if layout is None:
+            return True
+        sizes = layout.size
+        strides = layout.stride
+        dtype = layout.dtype
+
+        # Verify the output is 16-byte aligned
+        if not _aligned(layout.offset):
+            return False
+
+        return _is_tma_compatible(sizes, strides, dtype, allow_float32=True)
+
+    def _is_tma_compatible_matrix(m: IRNode) -> bool:
+        sizes = m.get_size()
+        strides = m.get_stride()
+        dtype = m.get_dtype()
+
+        # Base pointer 16-byte aligned
+        if m.get_name() in V.graph.unaligned_buffers:
+            return False
+
+        return _is_tma_compatible(sizes, strides, dtype, allow_float32=False)
+
+    def _is_tma_compatible(
+        sizes: Sequence[sympy.Expr],
+        strides: Sequence[_IntLike],
+        dtype: torch.dtype,
+        allow_float32: bool,
+    ) -> bool:
         rank = len(sizes)
-        dtype = x.get_dtype()
         itemsize = dtype.itemsize
 
         # 2 ≤ rank ≤ 5
@@ -1698,11 +1726,9 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
             return False
 
         # dtype ∈ {FP16, BF16, FP8-E4M3FN}
-        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
-            return False
-
-        # Base pointer 16-byte aligned
-        if x.get_name() in V.graph.unaligned_buffers:
+        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn) and (
+            not allow_float32 or dtype != torch.float32
+        ):
             return False
 
         if add_guards:
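The rewritten dtype gate above is easier to read as membership in a per-caller allow-list; a minimal equivalent sketch (the helper name is illustrative and not part of the PR):

import torch

def dtype_ok(dtype: torch.dtype, allow_float32: bool) -> bool:
    # Equivalent to the condition above: reject iff dtype is outside
    # {fp16, bf16, fp8e4m3fn} and (float32 is not allowed or dtype != float32).
    allowed = {torch.float16, torch.bfloat16, torch.float8_e4m3fn}
    if allow_float32:
        allowed.add(torch.float32)
    return dtype in allowed

assert dtype_ok(torch.float32, allow_float32=True)       # output-layout path
assert not dtype_ok(torch.float32, allow_float32=False)  # input-matrix path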
@@ -1746,31 +1772,20 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
 
         return True
 
-    def _is_tma_compatible_xpu(x: IRNode) -> bool:
-        strides = x.get_stride()
-        strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
-        # Find the single contiguous (“inner”) dim
-        inner = [
-            i
-            for i, st in enumerate(strides_i)
-            if V.graph.sizevars.statically_known_equals(st, 1)
-        ]
-        if len(inner) != 1:
-            return False
-        return True
-
-    return has_triton_tma_device() and all(
-        _is_tma_compatible_default(m)
-        if (m_device := m.get_device()) is None or m_device.type != "xpu"
-        else _is_tma_compatible_xpu(m)
-        for m in matrices
+    return (
+        has_triton_tma_device()
+        and all(_is_tma_compatible_matrix(m) for m in matrices)
+        and _is_tma_compatible_layout(output_layout)
     )
 
 
-def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
+def use_triton_tma_template(
+    *matrices: IRNode, output_layout: Layout, add_guards: bool = False
+) -> bool:
+    layout = output_layout if config.triton.enable_template_tma_store else None
     return (
         all(len(m.get_size()) == 2 for m in matrices)
-        and can_use_tma(*matrices, add_guards=add_guards)
+        and can_use_tma(*matrices, output_layout=layout, add_guards=add_guards)
         and config.triton.enable_persistent_tma_matmul
     )
 
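A test-style sketch in the spirit of the test_max_autotune.py coverage mentioned in the test plan (the exact test names and parametrizations in that file are not reproduced here; the config attributes used below appear in torch._inductor.config and in the diff above):

import torch
from torch._inductor import config

# Restrict autotuning to the Triton GEMM backend and enable the TMA templates
# with TMA stores.
config.max_autotune_gemm_backends = "TRITON"
config.triton.enable_persistent_tma_matmul = True
config.triton.enable_template_tma_store = True

def addmm(bias, a, b):
    return torch.addmm(bias, a, b)

a = torch.randn(256, 128, device="cuda", dtype=torch.float16)
b = torch.randn(128, 256, device="cuda", dtype=torch.float16)
bias = torch.randn(256, device="cuda", dtype=torch.float16)

# Compare the autotuned (TMA template) result against eager for correctness.
actual = torch.compile(addmm, mode="max-autotune")(bias, a, b)
torch.testing.assert_close(actual, addmm(bias, a, b), rtol=1e-2, atol=1e-2)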