mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][TMA] Split config-gated and pure compatibility logic for TMA template eligibility checks (#159123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159123 Approved by: https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
d90ce83027
commit
806d9e3fe7
@ -1630,7 +1630,7 @@ def use_triton_template(
|
||||
)
|
||||
|
||||
|
||||
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
|
||||
def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
|
||||
"""
|
||||
Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
|
||||
that Triton relies on today.
|
||||
@ -1712,10 +1712,13 @@ def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool
|
||||
|
||||
return True
|
||||
|
||||
return has_triton_tma_device() and all(_is_tma_compatible(m) for m in matrices)
|
||||
|
||||
|
||||
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
|
||||
return (
|
||||
config.triton.enable_persistent_tma_matmul
|
||||
and has_triton_tma_device()
|
||||
and all(_is_tma_compatible(m) for m in matrices)
|
||||
can_use_tma(*matrices, add_guards=add_guards)
|
||||
and config.triton.enable_persistent_tma_matmul
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user