[Inductor][Triton] Update TMA Compatibility Requirements (#157881)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157881
Approved by: https://github.com/Skylion007, https://github.com/drisspg
Author: NikhilAPatel
Date: 2025-07-16 06:30:42 +00:00
Committed by: PyTorch MergeBot
Parent: e71bb021b9
Commit: ea74fdd24a


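Summary of the change, for context: the old check only accepted rank-2 tensors whose layout was contiguous or transposed; the updated check accepts 2-5D tensors with exactly one stride-1 ("inner") dim, subject to the alignment, minimum-size, and FP8 rules spelled out in the new docstring, and it gains an add_guards keyword that guards concrete sizes/strides (guard_int_seq) instead of only consulting symbolic hints. A rough illustration of the new signature follows; mat1_node and mat2_node are placeholder names, not identifiers from this PR.

    # Hypothetical call site for the updated helper. mat1_node / mat2_node stand
    # in for Inductor IRNode inputs; they are not names from this PR.
    if use_triton_tma_template(mat1_node, mat2_node, add_guards=True):
        # All inputs satisfy the TMA constraints; a persistent-TMA matmul
        # template can be considered for this kernel.
        ...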
@@ -1543,35 +1543,87 @@ def use_triton_template(
     )
 
 
-def use_triton_tma_template(*matrices: IRNode) -> bool:
+def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
+    """
+    Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
+    that Triton relies on today.
+
+    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
+
+    A tensor is accepted when:
+      * 2 ≤ rank ≤ 5
+      * dtype ∈ {FP16, BF16, FP8-E4M3FN}
+      * Every logical size ≥ 2
+      * Base pointer 16-byte aligned
+      * All "outer" dims have 16-byte aligned strides
+      * The "inner" dim has stride 1 (contiguous)
+      * For FP8 tensors, inner dim ≥ 32
+    """
     from torch.utils._triton import has_triton_tma_device
 
     from .virtualized import V
 
+    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
+        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
+
     def _is_tma_compatible(x: IRNode) -> bool:
-        if len(x.get_size()) != 2:
-            return False
-        dtype = x.get_dtype()
-        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
-            return False
-        layout = x.get_layout()
-        transposed = layout.is_transposed()
-        if not (layout.is_contiguous() or transposed):
-            return False
-        inner_dim = layout.size[1]
-        if transposed:
-            inner_dim = layout.size[0]
-        if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt(
-            inner_dim, 32
-        ):
-            return False
-        inner_bytes = inner_dim * dtype.itemsize
-        return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
+        sizes = x.get_size()
+        strides = x.get_stride()
+        rank = len(sizes)
+        dtype = x.get_dtype()
+        itemsize = dtype.itemsize
+
+        # 2 ≤ rank ≤ 5
+        if rank < 2 or rank > 5:
+            return False
+
+        # dtype ∈ {FP16, BF16, FP8-E4M3FN}
+        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
+            return False
+
+        # Base pointer 16-byte aligned
+        if x.get_name() in V.graph.unaligned_buffers:
+            return False
+
+        if add_guards:
+            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
+            strides_i = V.graph.sizevars.guard_int_seq(strides)
+        else:
+            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
+            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
+
+        # Every logical size ≥ 2
+        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
+            return False
+
+        # Find the single contiguous ("inner") dim
+        inner = [
+            i
+            for i, st in enumerate(strides_i)
+            if V.graph.sizevars.statically_known_equals(st, 1)
+        ]
+        if len(inner) != 1:
+            return False
+        inner_idx = inner[0]
+
+        # All "outer" dims must have 16-byte aligned strides
+        for i, st in enumerate(strides_i):
+            if i == inner_idx:
+                continue
+            if not _aligned(st * itemsize):
+                return False
+
+        # Inner dim byte width must still be a multiple of 16 B
+        inner_dim = sizes_i[inner_idx]
+        if not _aligned(inner_dim * itemsize):
+            return False
+
+        # FP8 special case: inner ≥ 32
+        if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
+            inner_dim, 32
+        ):
+            return False
+
+        return True
 
     return (
         config.triton.enable_persistent_tma_matmul
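
For intuition about which inputs the documented constraints accept, here is a minimal standalone sketch that applies the same rules to eager torch.Tensor values with concrete sizes and strides. It is illustrative only: the function name, the TMA_ALIGNMENT value of 16 bytes, and the substitution of concrete arithmetic for Inductor's symbolic sizevars queries are assumptions, not code from this PR.

    # Standalone sketch, not part of this PR: the documented acceptance rules
    # applied to eager tensors, with concrete sizes/strides in place of
    # Inductor's symbolic reasoning.
    import torch

    TMA_ALIGNMENT = 16  # bytes; assumed to match the constant used above


    def is_tma_compatible_eager(t: torch.Tensor) -> bool:
        itemsize = t.element_size()
        # 2 <= rank <= 5
        if not (2 <= t.dim() <= 5):
            return False
        # dtype in {FP16, BF16, FP8-E4M3FN}
        if t.dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
            return False
        # Every logical size >= 2
        if any(s < 2 for s in t.shape):
            return False
        # Base pointer 16-byte aligned
        if t.data_ptr() % TMA_ALIGNMENT != 0:
            return False
        # Exactly one stride-1 ("inner") dim
        inner = [i for i, st in enumerate(t.stride()) if st == 1]
        if len(inner) != 1:
            return False
        inner_idx = inner[0]
        # All "outer" strides must be 16-byte aligned
        for i, st in enumerate(t.stride()):
            if i != inner_idx and (st * itemsize) % TMA_ALIGNMENT != 0:
                return False
        # Inner dim byte width must be a multiple of 16 B
        inner_dim = t.shape[inner_idx]
        if (inner_dim * itemsize) % TMA_ALIGNMENT != 0:
            return False
        # FP8 special case: inner >= 32
        if t.dtype == torch.float8_e4m3fn and inner_dim < 32:
            return False
        return True

For example, a row-major (128, 64) fp16 tensor passes, while a row-major (128, 3) fp16 tensor fails both the outer-stride and inner-width alignment checks (3 elements * 2 bytes = 6 B).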