[Inductor][TMA] Split config-gated and pure compatibility logic for TMA template eligibility checks (#159123)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159123
Approved by: https://github.com/drisspg
This commit is contained in:
NikhilAPatel
2025-07-25 17:45:32 +00:00
committed by PyTorch MergeBot
parent d90ce83027
commit 806d9e3fe7

View File

@ -1630,7 +1630,7 @@ def use_triton_template(
)
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
"""
Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
that Triton relies on today.
@ -1712,10 +1712,13 @@ def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool
return True
return has_triton_tma_device() and all(_is_tma_compatible(m) for m in matrices)
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
return (
config.triton.enable_persistent_tma_matmul
and has_triton_tma_device()
and all(_is_tma_compatible(m) for m in matrices)
can_use_tma(*matrices, add_guards=add_guards)
and config.triton.enable_persistent_tma_matmul
)