Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[inductor][triton pin] add support for new TMA API for mm.py templates (#155723)
Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs). For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of Triton that don't have the new API). For mm_scaled_grouped.py, https://github.com/pytorch/pytorch/pull/150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with https://github.com/pytorch/pytorch/pull/154858, but that broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155723
Approved by: https://github.com/NikhilAPatel
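The gating is plain feature detection. A minimal sketch of the pattern (the helper name `_stable_tma_available` is hypothetical, for illustration; the PR's actual helper is `has_triton_stable_tma_api` in torch.utils._triton, shown in the diff below):

    # The stable TMA API is detected by whether
    # triton.language.make_tensor_descriptor imports; when it does not,
    # templates render the experimental code path instead.
    def _stable_tma_available() -> bool:  # hypothetical illustration helper
        try:
            from triton.language import make_tensor_descriptor  # noqa: F401

            return True
        except ImportError:
            return False

    # Templates receive one flag and pick a single code path at render time:
    template_kwargs = {"TMA_EXPERIMENTAL_API": not _stable_tma_available()}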
Committed by: PyTorch MergeBot
Parent: 2b9d638e33
Commit: c3ecabf059
@@ -265,6 +265,7 @@ persistent_tma_mm_template = TritonTemplate(
         a_desc_ptr = workspace_base
         b_desc_ptr = workspace_base + TMA_SIZE
 
+{%- if TMA_EXPERIMENTAL_API %}
         triton.language.extra.cuda.experimental_device_tensormap_create2d(
             desc_ptr=a_desc_ptr,
             global_address=A,
@@ -283,6 +284,23 @@ persistent_tma_mm_template = TritonTemplate(
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
 
+        a_desc = a_desc_ptr
+        b_desc = b_desc_ptr
+{%- else %}
+        a_desc = triton.language.make_tensor_descriptor(
+            base=A,
+            shape=[M, K] if A_ROW_MAJOR else [K, M],
+            strides=[K, 1] if A_ROW_MAJOR else [M, 1],
+            block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+        )
+        b_desc = triton.language.make_tensor_descriptor(
+            base=B,
+            shape=[K, N] if B_ROW_MAJOR else [N, K],
+            strides=[N, 1] if B_ROW_MAJOR else [K, 1],
+            block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+        )
+{%- endif %}
+
         pid_m = 0
         pid_n = 0
         rm = 0
@@ -303,18 +321,29 @@ persistent_tma_mm_template = TritonTemplate(
 
             rk = ki * BLOCK_K
 
+{%- if TMA_EXPERIMENTAL_API %}
             a = tl._experimental_descriptor_load(
-                a_desc_ptr,
+                a_desc,
                 [rm, rk] if A_ROW_MAJOR else [rk, rm],
                 [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
                 A.dtype.element_ty,
             )
             b = tl._experimental_descriptor_load(
-                b_desc_ptr,
+                b_desc,
                 [rk, rn] if B_ROW_MAJOR else [rn, rk],
                 [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
                 B.dtype.element_ty,
             )
+{%- else %}
+            a = tl.load_tensor_descriptor(
+                a_desc,
+                [rm, rk] if A_ROW_MAJOR else [rk, rm],
+            )
+            b = tl.load_tensor_descriptor(
+                b_desc,
+                [rk, rn] if B_ROW_MAJOR else [rn, rk],
+            )
+{%- endif %}
             acc += tl.dot(
                 a if A_ROW_MAJOR else a.T,
                 b if B_ROW_MAJOR else b.T,
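The transposes in the tl.dot call above exist because the descriptors are built in storage order: a column-major operand is loaded as a [BLOCK_K, BLOCK_M] tile and transposed back before the dot. A small sketch of the same arithmetic, with numpy standing in for Triton tiles (an illustration, not inductor code):

    import numpy as np

    BLOCK_M, BLOCK_K, BLOCK_N = 4, 8, 2
    # An A stored column-major arrives as a [BLOCK_K, BLOCK_M] tile ...
    a_tile = np.arange(BLOCK_K * BLOCK_M, dtype=np.float32).reshape(BLOCK_K, BLOCK_M)
    b_tile = np.arange(BLOCK_K * BLOCK_N, dtype=np.float32).reshape(BLOCK_K, BLOCK_N)
    # ... so .T restores [BLOCK_M, BLOCK_K], mirroring `a if A_ROW_MAJOR else a.T`
    acc = a_tile.T @ b_tile
    assert acc.shape == (BLOCK_M, BLOCK_N)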
@@ -416,6 +445,7 @@ device_tma = r"""
         a_desc_ptr = workspace_base
         b_desc_ptr = workspace_base + TMA_SIZE
 
+{%- if TMA_EXPERIMENTAL_API %}
         triton.language.extra.cuda.experimental_device_tensormap_create2d(
             desc_ptr=a_desc_ptr,
             global_address=A,
@@ -434,6 +464,23 @@ device_tma = r"""
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
 
+        a_desc = a_desc_ptr
+        b_desc = b_desc_ptr
+{%- else %}
+        a_desc = triton.language.make_tensor_descriptor(
+            base=A,
+            shape=[M, K],
+            strides=[K, 1],
+            block_shape=[BLOCK_M, BLOCK_K],
+        )
+        b_desc = triton.language.make_tensor_descriptor(
+            base=B,
+            shape=[N, K],
+            strides=[K, 1],
+            block_shape=[BLOCK_N, BLOCK_K],
+        )
+{%- endif %}
+
         tiles_per_SM = num_tiles // NUM_SMS
         if start_pid < num_tiles % NUM_SMS:
             tiles_per_SM += 1
@@ -465,12 +512,17 @@ device_tma = r"""
 
             offs_k = ki * BLOCK_K
 
+{%- if TMA_EXPERIMENTAL_API %}
             a = tl._experimental_descriptor_load(
                 a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
             )
             b = tl._experimental_descriptor_load(
                 b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
             )
+{%- else %}
+            a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
+            b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k])
+{%- endif %}
             if USE_FAST_ACCUM:
                 accumulator = tl.dot(a, b.T, accumulator)
             else:
@@ -79,13 +79,21 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
     return options_dict
 
 
+def tma_options() -> dict[str, Any]:
+    from torch.utils._triton import has_triton_stable_tma_api
+
+    return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()}
+
+
 def persistent_mm_options(mat1, mat2):
-    return dict(
+    res = dict(
         A_ROW_MAJOR=not mat1.layout.is_transposed(),
         B_ROW_MAJOR=not mat2.layout.is_transposed(),
         NUM_SMS=get_num_sms(),
         TMA_SIZE=TMA_DESCRIPTOR_SIZE,
     )
+    res.update(tma_options())
+    return res
 
 
 def scaled_mm_options(  # type: ignore[no-untyped-def]
@@ -126,6 +134,8 @@ def scaled_mm_options(  # type: ignore[no-untyped-def]
     mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
     mm_template_options["NUM_SMS"] = get_num_sms()
 
+    mm_template_options.update(tma_options())
+
     return mm_template_options
 
 
@@ -1506,7 +1506,7 @@ def use_triton_template(
 
 
 def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma_device
+    from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
 
     from .virtualized import V
 
@@ -1535,6 +1535,10 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
         inner_bytes = inner_dim * dtype.itemsize
         return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
 
+    if has_triton_stable_tma_api() and config.cpp_wrapper:
+        # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047)
+        return False
+
     return (
         config.triton.enable_persistent_tma_matmul
         and has_triton_tma_device()
@@ -74,6 +74,7 @@ def has_triton_tma_device() -> bool:
             and torch.cuda.get_device_capability() >= (9, 0)
             and not torch.version.hip
         ):
+            # old API
             try:
                 from triton.language.extra.cuda import (  # noqa: F401
                     experimental_device_tensormap_create1d,
@@ -84,6 +85,33 @@ def has_triton_tma_device() -> bool:
             except ImportError:
                 pass
 
+            # new API
+            try:
+                from triton.language import make_tensor_descriptor  # noqa: F401
+
+                return True
+            except ImportError:
+                pass
+
     return False
 
 
+@functools.lru_cache(None)
+def has_triton_stable_tma_api() -> bool:
+    if has_triton_package():
+        import torch
+
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.get_device_capability() >= (9, 0)
+            and not torch.version.hip
+        ):
+            try:
+                from triton.language import make_tensor_descriptor  # noqa: F401
+
+                return True
+            except ImportError:
+                pass
+    return False
+
+