[Intel GPU][FlexAttention] Enable TMA path on Intel GPU (#162138)

The existing `can_use_tma` has some conditions that are unnecessary for Intel GPUs.
We have removed these useless conditions on the Intel GPU path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162138
Approved by: https://github.com/liangan1, https://github.com/EikanWang, https://github.com/jansel, https://github.com/etaf
This commit is contained in:
Xingyuan Li
2025-09-05 16:54:46 +00:00
committed by PyTorch MergeBot
parent f3cebec39e
commit b2c7b9ad2d
2 changed files with 22 additions and 2 deletions

View File

@ -231,6 +231,7 @@ class TestMaxAutotune(TestCase):
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@skipIfXpu(msg="TMA path on Intel GPU does not require this check")
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic):
def mm(a, b):
@ -359,6 +360,7 @@ class TestMaxAutotune(TestCase):
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@skipIfXpu(msg="TMA path on Intel GPU does not require this check")
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic):
def addmm(x, a, b):

View File

@ -1679,7 +1679,7 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
def _is_tma_compatible(x: IRNode) -> bool:
def _is_tma_compatible_default(x: IRNode) -> bool:
sizes = x.get_size()
strides = x.get_stride()
rank = len(sizes)
@ -1739,7 +1739,25 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
return True
return has_triton_tma_device() and all(_is_tma_compatible(m) for m in matrices)
def _is_tma_compatible_xpu(x: IRNode) -> bool:
    """Return True if *x* is TMA-compatible on Intel GPU.

    XPU only requires that the tensor has exactly one unit-stride
    ("inner") dimension; the stricter default checks are unnecessary.
    """
    stride_hints = [V.graph.sizevars.symbolic_hint(s) for s in x.get_stride()]
    unit_stride_count = sum(
        1
        for hint in stride_hints
        if V.graph.sizevars.statically_known_equals(hint, 1)
    )
    # Exactly one contiguous dim is required — zero or multiple disqualify.
    return unit_stride_count == 1
return has_triton_tma_device() and all(
_is_tma_compatible_default(m)
if (m_device := m.get_device()) is None or m_device.type != "xpu"
else _is_tma_compatible_xpu(m)
for m in matrices
)
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool: