[Intel GPU] Enable tensor memory descriptor in triton template for XPU. (#161600)

As Intel Triton now supports tensor descriptors, this PR updates the pinned Intel Triton version and adds support for the Triton MM template with tensor descriptors on XPU.
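
For context, a minimal sketch of how this template path would be exercised from user code, assuming the same inductor knobs the CUDA persistent-TMA path uses; the flag names for XPU are an assumption here, not something this PR documents:

```python
# Hedged sketch: exercising the persistent-TMA matmul template via
# torch.compile max-autotune. enable_persistent_tma_matmul mirrors the
# CUDA path and is assumed to apply to XPU as well.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.triton.enable_persistent_tma_matmul = True  # assumed knob

@torch.compile(mode="max-autotune")
def mm(a, b):
    return a @ b

if torch.xpu.is_available():
    a = torch.randn(2048, 2048, device="xpu", dtype=torch.float16)
    b = torch.randn(2048, 2048, device="xpu", dtype=torch.float16)
    c = mm(a, b)
```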

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161600
Approved by: https://github.com/EikanWang, https://github.com/jansel
Author:    xinan.lin
Date:      2025-08-28 12:39:58 +00:00
Committed: PyTorch MergeBot
Parent:    5790b00975
Commit:    3519969e4f
5 changed files with 21 additions and 3 deletions


@@ -1 +1 @@
-a6572fb0be5b9b0a19b0641a0ce05810fa04e44c
+d0e80f39c562c70986fc548fa6e5852ad86e16e7


@@ -424,6 +424,7 @@ class TestMaxAutotune(TestCase):
         torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

     @fresh_cache()
+    @skipIfXpu(msg="XPU doesn't support sm carveout")
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
     @unittest.skipIf(


@@ -1592,6 +1592,19 @@ class CUDAPersistentTMATemplateConfigHeuristic(TMAConfigMixin, CUDAConfigHeuristic):
         self.mm_configs = self.persistent_mm_configs


+@register_template_heuristic(
+    "mm_persistent_tma",
+    "xpu",
+)
+class XPUPersistentTMATemplateConfigHeuristic(TMAConfigMixin, XPUConfigHeuristic):
+    """Persistent TMA template heuristic for XPU"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Override mm_configs to use persistent_mm_configs
+        self.mm_configs = self.persistent_mm_configs
+
+
 # TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
     "mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"


@@ -1561,12 +1561,16 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:


 @functools.lru_cache
 def get_max_num_sms() -> int:
+    if torch.xpu.is_available():
+        return torch.xpu.get_device_properties().gpu_subslice_count
     return torch.cuda.get_device_properties("cuda").multi_processor_count


 def get_num_sms() -> int:
     """Handle experimental carveout if set otherwise return hardware SM count"""
     # TODO we need to properly guard on this global
+    if torch.xpu.is_available():
+        return get_max_num_sms()
     carveout = torch._C._get_sm_carveout_experimental()
     return get_max_num_sms() - (carveout if carveout is not None else 0)
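
On XPU, get_num_sms() now returns the raw subslice count, since there is no SM-carveout equivalent to subtract. As a rough illustration of why this count matters (an assumption about typical usage, not code from this PR), persistent matmul templates usually cap their launch grid at the SM/subslice count:

```python
# Illustrative only: how an SM/subslice count typically bounds a persistent
# kernel's launch grid. Helper name and tiling parameters are hypothetical.
def persistent_mm_grid(M: int, N: int, BLOCK_M: int, BLOCK_N: int, num_sms: int) -> tuple[int, ...]:
    num_tiles = ((M + BLOCK_M - 1) // BLOCK_M) * ((N + BLOCK_N - 1) // BLOCK_N)
    # Launch at most one persistent program per SM/subslice; each program
    # then loops over its share of the output tiles.
    return (min(num_sms, num_tiles),)
```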


@@ -71,7 +71,7 @@ def has_triton_tma_device() -> bool:
             torch.cuda.is_available()
             and torch.cuda.get_device_capability() >= (9, 0)
             and not torch.version.hip
-        ):
+        ) or torch.xpu.is_available():
             # old API
             try:
                 from triton.language.extra.cuda import (  # noqa: F401
@@ -103,7 +103,7 @@ def has_triton_stable_tma_api() -> bool:
             torch.cuda.is_available()
             and torch.cuda.get_device_capability() >= (9, 0)
             and not torch.version.hip
-        ):
+        ) or torch.xpu.is_available():
             try:
                 from triton.language import make_tensor_descriptor  # noqa: F401
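
The check above probes for the stable tensor-descriptor API (triton.language.make_tensor_descriptor). As a rough illustration of what that API looks like, here is a sketch of a tile-copy kernel; whether it runs on a given XPU/CUDA toolchain depends on the pinned Triton build, and some backends also need a host-side triton.set_allocator call for descriptor storage, so treat this as an API-shape sketch rather than a verified test.

```python
# Sketch of the stable tensor-descriptor API probed by has_triton_stable_tma_api().
# Assumes a Triton build that ships triton.language.make_tensor_descriptor.
import torch
import triton
import triton.language as tl


@triton.jit
def tile_copy_kernel(src, dst, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # A descriptor captures the global-memory layout once; loads/stores then
    # move whole BLOCK_M x BLOCK_N tiles addressed by element offsets.
    src_desc = tl.make_tensor_descriptor(
        src, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
    )
    dst_desc = tl.make_tensor_descriptor(
        dst, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
    )
    tile = src_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    dst_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


device = "xpu" if torch.xpu.is_available() else "cuda"
M = N = 1024
BLOCK_M = BLOCK_N = 64
src = torch.randn(M, N, device=device)
dst = torch.empty_like(src)
# Note: on some backends, on-device descriptor creation requires registering a
# host-side allocator via triton.set_allocator(...); omitted here for brevity.
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
tile_copy_kernel[grid](src, dst, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
```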