[inductor][triton pin] add support for new TMA API for mm.py templates (#155723)

Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).
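For context, a minimal standalone sketch of what the new API looks like (illustrative code, not from this PR; assumes a Triton build that ships the stable TMA API and an SM90+ GPU). Unlike the experimental API, the stable one expects a host-side allocator for descriptor storage, registered via triton.set_allocator:

import torch
import triton
import triton.language as tl

# The stable API allocates TMA descriptor storage through a host-side
# allocator, registered once before launching kernels that call
# tl.make_tensor_descriptor.
def alloc_fn(size: int, alignment: int, stream):
    return torch.empty(size, device="cuda", dtype=torch.int8)

triton.set_allocator(alloc_fn)

@triton.jit
def tile_sum_kernel(x_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Shape, strides, block shape, and dtype are baked into the descriptor once...
    desc = tl.make_tensor_descriptor(
        x_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
    )
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # ...so each TMA load takes only the tile offsets.
    tile = tl.load_tensor_descriptor(desc, [pid_m * BLOCK_M, pid_n * BLOCK_N])
    tl.store(out_ptr + pid_m * tl.num_programs(1) + pid_n, tl.sum(tile))

x = torch.randn(1024, 512, device="cuda")
out = torch.empty(32, device="cuda")  # one partial sum per tile in the (8 x 4) grid
tile_sum_kernel[(8, 4)](x, out, 1024, 512, BLOCK_M=128, BLOCK_N=128)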

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of Triton that don't have the new API).

For mm_scaled_grouped.py, https://github.com/pytorch/pytorch/pull/150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with https://github.com/pytorch/pytorch/pull/154858, but this broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155723
Approved by: https://github.com/NikhilAPatel
Author: David Berard
Date: 2025-06-11 15:04:11 -07:00
Committed by: PyTorch MergeBot
Parent: 2b9d638e33
Commit: c3ecabf059
4 changed files with 98 additions and 4 deletions

@@ -265,6 +265,7 @@ persistent_tma_mm_template = TritonTemplate(
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
+{%- if TMA_EXPERIMENTAL_API %}
triton.language.extra.cuda.experimental_device_tensormap_create2d(
    desc_ptr=a_desc_ptr,
    global_address=A,
@@ -283,6 +284,23 @@ persistent_tma_mm_template = TritonTemplate(
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+a_desc = a_desc_ptr
+b_desc = b_desc_ptr
+{%- else %}
+a_desc = triton.language.make_tensor_descriptor(
+    base=A,
+    shape=[M, K] if A_ROW_MAJOR else [K, M],
+    strides=[K, 1] if A_ROW_MAJOR else [M, 1],
+    block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+)
+b_desc = triton.language.make_tensor_descriptor(
+    base=B,
+    shape=[K, N] if B_ROW_MAJOR else [N, K],
+    strides=[N, 1] if B_ROW_MAJOR else [K, 1],
+    block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+)
+{%- endif %}
pid_m = 0
pid_n = 0
rm = 0
@@ -303,18 +321,29 @@ persistent_tma_mm_template = TritonTemplate(
rk = ki * BLOCK_K
+{%- if TMA_EXPERIMENTAL_API %}
a = tl._experimental_descriptor_load(
-    a_desc_ptr,
+    a_desc,
    [rm, rk] if A_ROW_MAJOR else [rk, rm],
    [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    A.dtype.element_ty,
)
b = tl._experimental_descriptor_load(
-    b_desc_ptr,
+    b_desc,
    [rk, rn] if B_ROW_MAJOR else [rn, rk],
    [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
    B.dtype.element_ty,
)
+{%- else %}
+a = tl.load_tensor_descriptor(
+    a_desc,
+    [rm, rk] if A_ROW_MAJOR else [rk, rm],
+)
+b = tl.load_tensor_descriptor(
+    b_desc,
+    [rk, rn] if B_ROW_MAJOR else [rn, rk],
+)
+{%- endif %}
acc += tl.dot(
    a if A_ROW_MAJOR else a.T,
    b if B_ROW_MAJOR else b.T,
@@ -416,6 +445,7 @@ device_tma = r"""
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
+{%- if TMA_EXPERIMENTAL_API %}
triton.language.extra.cuda.experimental_device_tensormap_create2d(
    desc_ptr=a_desc_ptr,
    global_address=A,
@@ -434,6 +464,23 @@ device_tma = r"""
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+a_desc = a_desc_ptr
+b_desc = b_desc_ptr
+{%- else %}
+a_desc = triton.language.make_tensor_descriptor(
+    base=A,
+    shape=[M, K],
+    strides=[K, 1],
+    block_shape=[BLOCK_M, BLOCK_K],
+)
+b_desc = triton.language.make_tensor_descriptor(
+    base=B,
+    shape=[N, K],
+    strides=[K, 1],
+    block_shape=[BLOCK_N, BLOCK_K],
+)
+{%- endif %}
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
    tiles_per_SM += 1
@@ -465,12 +512,17 @@ device_tma = r"""
offs_k = ki * BLOCK_K
+{%- if TMA_EXPERIMENTAL_API %}
a = tl._experimental_descriptor_load(
    a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)
b = tl._experimental_descriptor_load(
    b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
)
+{%- else %}
+a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
+b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])
+{%- endif %}
if USE_FAST_ACCUM:
    accumulator = tl.dot(a, b.T, accumulator)
else:
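The load-site difference between the two branches above, side by side (an illustrative recap of the hunks, not new code): the experimental API re-specifies block shape and dtype at every load, while the stable API carries them in the descriptor, so only the tile offsets remain:

# experimental: block shape and dtype respecified at each load
a = tl._experimental_descriptor_load(
    a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)
# stable: the descriptor already knows shape, strides, block shape, and dtype
a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])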

@@ -79,13 +79,21 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
    return options_dict

+def tma_options() -> dict[str, Any]:
+    from torch.utils._triton import has_triton_stable_tma_api
+
+    return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()}

def persistent_mm_options(mat1, mat2):
-    return dict(
+    res = dict(
        A_ROW_MAJOR=not mat1.layout.is_transposed(),
        B_ROW_MAJOR=not mat2.layout.is_transposed(),
        NUM_SMS=get_num_sms(),
        TMA_SIZE=TMA_DESCRIPTOR_SIZE,
    )
+    res.update(tma_options())
+    return res

def scaled_mm_options(  # type: ignore[no-untyped-def]
@@ -126,6 +134,8 @@ def scaled_mm_options(  # type: ignore[no-untyped-def]
    mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
    mm_template_options["NUM_SMS"] = get_num_sms()
+    mm_template_options.update(tma_options())

    return mm_template_options
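For illustration, the options dict that persistent_mm_options() now produces looks like this (values are hypothetical: NUM_SMS depends on the GPU, and TMA_EXPERIMENTAL_API flips to True on older Triton):

# Hypothetical output on an H100 with a stable-TMA Triton build:
{
    "A_ROW_MAJOR": True,
    "B_ROW_MAJOR": True,
    "NUM_SMS": 132,
    "TMA_SIZE": 128,
    "TMA_EXPERIMENTAL_API": False,
}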

@@ -1506,7 +1506,7 @@ def use_triton_template(
def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma_device
+    from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device

    from .virtualized import V
@@ -1535,6 +1535,10 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
    inner_bytes = inner_dim * dtype.itemsize
    return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)

+if has_triton_stable_tma_api() and config.cpp_wrapper:
+    # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047)
+    return False

return (
    config.triton.enable_persistent_tma_matmul
    and has_triton_tma_device()

@@ -74,6 +74,7 @@ def has_triton_tma_device() -> bool:
    and torch.cuda.get_device_capability() >= (9, 0)
    and not torch.version.hip
):
+    # old API
    try:
        from triton.language.extra.cuda import (  # noqa: F401
            experimental_device_tensormap_create1d,
@@ -84,6 +85,33 @@ def has_triton_tma_device() -> bool:
    except ImportError:
        pass

+    # new API
+    try:
+        from triton.language import make_tensor_descriptor  # noqa: F401
+
+        return True
+    except ImportError:
+        pass

return False

+@functools.lru_cache(None)
+def has_triton_stable_tma_api() -> bool:
+    if has_triton_package():
+        import torch
+
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.get_device_capability() >= (9, 0)
+            and not torch.version.hip
+        ):
+            try:
+                from triton.language import make_tensor_descriptor  # noqa: F401
+
+                return True
+            except ImportError:
+                pass
+    return False
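Taken together, the two helpers separate "any TMA support" from "stable TMA API specifically", which is what tma_options() above keys off. A minimal sketch of the intended branching for a caller (assuming a CUDA build with some version of Triton installed):

from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device

if has_triton_tma_device():
    if has_triton_stable_tma_api():
        # render templates with TMA_EXPERIMENTAL_API=False
        # (tl.make_tensor_descriptor / tl.load_tensor_descriptor)
        ...
    else:
        # fall back to the experimental_device_tensormap_* path
        ...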