[inductor][triton pin] add support for new TMA API for mm.py templates (#155723)
Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs). For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of Triton that don't have the new API). For mm_scaled_grouped.py, https://github.com/pytorch/pytorch/pull/150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with https://github.com/pytorch/pytorch/pull/154858, but that change broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155723
Approved by: https://github.com/NikhilAPatel
Committed by: PyTorch MergeBot
Parent: 2b9d638e33
Commit: c3ecabf059
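The fallback behavior described in the commit message amounts to a three-way availability check: no TMA support at all, the new stable TMA API, or the older experimental one. Below is a minimal sketch of that dispatch; only the `has_triton_tma_device` / `has_triton_stable_tma_api` helpers from `torch.utils._triton` are real (the diff below imports them), while the returned labels are hypothetical placeholders standing in for the mm.py template machinery.

```python
# Minimal sketch of the "prefer the new TMA API, otherwise fall back" dispatch
# described in the commit message. Only the two imported helpers are real (they
# appear in the diff below); the returned labels are hypothetical placeholders.
from typing import Optional

from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device


def pick_tma_descriptor_api() -> Optional[str]:
    """Return which TMA descriptor style a template should emit, or None."""
    if not has_triton_tma_device():
        # No TMA-capable device / Triton build: don't emit a TMA template at all.
        return None
    if has_triton_stable_tma_api():
        # Triton 3.4+: the new, stable tensor-descriptor API is available.
        return "stable_tma"  # hypothetical label
    # Triton 3.2 / 3.3: fall back to the experimental TMA descriptor API.
    return "experimental_tma"  # hypothetical label
```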
@@ -1506,7 +1506,7 @@ def use_triton_template(
 
 
 def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma_device
+    from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
 
     from .virtualized import V
 
@@ -1535,6 +1535,10 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
         inner_bytes = inner_dim * dtype.itemsize
         return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
 
+    if has_triton_stable_tma_api() and config.cpp_wrapper:
+        # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047)
+        return False
+
     return (
         config.triton.enable_persistent_tma_matmul
        and has_triton_tma_device()
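For callers, the only new behavior in the hunk above is the `config.cpp_wrapper` guard: when the stable TMA API is present and cpp_wrapper (AOTI) is enabled, `use_triton_tma_template` now returns False until #155047 lands. A call-site sketch follows; it assumes the function lives in Inductor's utils module (as the hunk context suggests), and `mat1`, `mat2`, and `add_tma_choice` are hypothetical stand-ins rather than real template code.

```python
# Call-site sketch only: the real call sites are the Inductor mm.py templates.
# The import path is inferred from the hunk context (Inductor's utils module);
# mat1/mat2 and add_tma_choice are hypothetical stand-ins.
from torch._inductor.utils import use_triton_tma_template


def maybe_add_tma_matmul_choice(choices, mat1, mat2, add_tma_choice):
    # use_triton_tma_template() already folds in every gate shown above:
    # enable_persistent_tma_matmul, device/Triton TMA support, alignment,
    # and the temporary cpp_wrapper + stable-TMA-API exclusion (#155047),
    # so callers need no extra version checks of their own.
    if use_triton_tma_template(mat1, mat2):
        add_tma_choice(choices, mat1, mat2)
```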