[inductor][triton pin] add support for new TMA API for mm.py templates (#155723)

Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).
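
To make the fallback concrete, here is a minimal sketch (not the PR's actual code) of how a template can pick the code path using the `has_triton_stable_tma_api` / `has_triton_tma_device` helpers imported in the diff below; the function name `select_tma_api` and its return values are illustrative assumptions.

```python
# Hedged sketch, not the code from this PR: choosing a TMA code path at
# codegen time. The two helpers come from torch.utils._triton (see the diff
# below); select_tma_api and its string return values are hypothetical.
from typing import Optional

from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device


def select_tma_api() -> Optional[str]:
    """Pick which TMA API a generated kernel should target, if any."""
    if not has_triton_tma_device():
        return None  # device or Triton build has no TMA support at all
    if has_triton_stable_tma_api():
        return "stable"  # Triton 3.4+: new stable TMA API
    return "experimental"  # Triton 3.2 / 3.3: experimental TMA API
```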

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of Triton that don't have the new API).

For mm_scaled_grouped.py, https://github.com/pytorch/pytorch/pull/150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with https://github.com/pytorch/pytorch/pull/154858, but that change broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155723
Approved by: https://github.com/NikhilAPatel
Author: David Berard
Date: 2025-06-11 15:04:11 -07:00
Committed by: PyTorch MergeBot
Parent: 2b9d638e33
Commit: c3ecabf059
4 changed files with 98 additions and 4 deletions


@@ -1506,7 +1506,7 @@ def use_triton_template(
 def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma_device
+    from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
     from .virtualized import V
@@ -1535,6 +1535,10 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
         inner_bytes = inner_dim * dtype.itemsize
         return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
+    if has_triton_stable_tma_api() and config.cpp_wrapper:
+        # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047)
+        return False
     return (
         config.triton.enable_persistent_tma_matmul
         and has_triton_tma_device()
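
For reference, a minimal usage sketch, assuming a CUDA device and Triton build where `has_triton_tma_device()` returns True (e.g. H100): it flips the inductor config flag checked in the diff above and compiles a matmul with max-autotune so the TMA-based mm templates can be considered during autotuning. The shapes and dtype are arbitrary examples.

```python
# Usage sketch (assumptions: CUDA device with TMA support and a Triton build
# where has_triton_tma_device() is True). The config flag below is the one
# checked in use_triton_tma_template(); max-autotune lets inductor consider
# the persistent-TMA mm templates among its candidates.
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.enable_persistent_tma_matmul = True


@torch.compile(mode="max-autotune")
def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    out = mm(a, b)
```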