[Intel GPU] Enable tensor memory descriptor in triton template for XPU. (#161600)
As Intel Triton now supports tensor descriptors, this PR updates the pinned Intel Triton version and adds support for the Triton MM template with tensor descriptors on XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161600
Approved by: https://github.com/EikanWang, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 5790b00975
Commit: 3519969e4f
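For context, a minimal sketch of how the new path can be exercised end to end. The triton.enable_persistent_tma_matmul config name mirrors the existing CUDA persistent-TMA flow and is an assumption here, as is the choice of mode="max-autotune"; neither is spelled out in this diff.

# Hedged sketch: autotune a matmul on an XPU device so Inductor's mm templates
# (including the persistent TMA variant added below) become candidates.
import torch
import torch._inductor.config as inductor_config

# Assumption: the same knob that gates the CUDA persistent TMA matmul template
# also gates the XPU path.
with inductor_config.patch({"triton.enable_persistent_tma_matmul": True}):
    a = torch.randn(1024, 1024, device="xpu", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="xpu", dtype=torch.float16)
    compiled_mm = torch.compile(torch.mm, mode="max-autotune")
    out = compiled_mm(a, b)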
@@ -1 +1 @@
-a6572fb0be5b9b0a19b0641a0ce05810fa04e44c
+d0e80f39c562c70986fc548fa6e5852ad86e16e7
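A quick, hedged way to verify that the newly pinned Intel Triton build actually exposes the tensor descriptor API; the import mirrors the stable-API probe further down in this diff.

# Probe the installed Triton for the stable tensor-descriptor entry point.
try:
    from triton.language import make_tensor_descriptor  # noqa: F401
    print("tensor descriptor API available")
except ImportError:
    print("tensor descriptor API not available in this Triton build")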
@@ -424,6 +424,7 @@ class TestMaxAutotune(TestCase):
         torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

     @fresh_cache()
+    @skipIfXpu(msg="XPU doesn't support sm carveout")
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
     @unittest.skipIf(
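For context on the new skip: the test exercises SM carveout, which the helpers changed below read via torch._C._get_sm_carveout_experimental(); XPU has no equivalent, hence the decorator. A small hedged probe:

# Hedged probe: the experimental SM carveout is None unless explicitly set;
# when set, get_num_sms() subtracts it from the hardware SM count (see below).
import torch

carveout = torch._C._get_sm_carveout_experimental()
print(f"sm carveout: {carveout}")  # None when no carveout is configured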
@@ -1592,6 +1592,19 @@ class CUDAPersistentTMATemplateConfigHeuristic(TMAConfigMixin, CUDAConfigHeuristic):
         self.mm_configs = self.persistent_mm_configs


+@register_template_heuristic(
+    "mm_persistent_tma",
+    "xpu",
+)
+class XPUPersistentTMATemplateConfigHeuristic(TMAConfigMixin, XPUConfigHeuristic):
+    """Persistent TMA template heuristic for XPU"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Override mm_configs to use persistent_mm_configs
+        self.mm_configs = self.persistent_mm_configs
+
+
 # TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
     "mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"
@@ -1561,12 +1561,16 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:

 @functools.lru_cache
 def get_max_num_sms() -> int:
+    if torch.xpu.is_available():
+        return torch.xpu.get_device_properties().gpu_subslice_count
     return torch.cuda.get_device_properties("cuda").multi_processor_count


 def get_num_sms() -> int:
     """Handle experimental carveout if set otherwise return hardware SM count"""
     # TODO we need to properly guard on this global
+    if torch.xpu.is_available():
+        return get_max_num_sms()
     carveout = torch._C._get_sm_carveout_experimental()
     return get_max_num_sms() - (carveout if carveout is not None else 0)
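A hedged illustration of why the device-specific count matters: persistent matmul templates launch roughly one program per SM (subslice on XPU) and let each program loop over output tiles, so the launch grid is capped by this number. The helper below is a standalone sketch, not Inductor's actual grid computation.

# Hedged sketch of a persistent-kernel launch grid: cap the number of programs
# at the SM/subslice count and let each program iterate over multiple tiles.
import torch


def persistent_grid(m: int, n: int, block_m: int = 128, block_n: int = 128) -> int:
    if torch.xpu.is_available():
        num_sms = torch.xpu.get_device_properties().gpu_subslice_count
    else:
        num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    total_tiles = -(-m // block_m) * -(-n // block_n)  # ceil-div per dimension
    return min(num_sms, total_tiles)


# Example: a 4096x4096 output with 128x128 tiles has 1024 tiles, so the grid is
# min(1024, SM/subslice count).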
@@ -71,7 +71,7 @@ def has_triton_tma_device() -> bool:
             torch.cuda.is_available()
             and torch.cuda.get_device_capability() >= (9, 0)
             and not torch.version.hip
-        ):
+        ) or torch.xpu.is_available():
             # old API
             try:
                 from triton.language.extra.cuda import (  # noqa: F401
@@ -103,7 +103,7 @@ def has_triton_stable_tma_api() -> bool:
             torch.cuda.is_available()
             and torch.cuda.get_device_capability() >= (9, 0)
             and not torch.version.hip
-        ):
+        ) or torch.xpu.is_available():
             try:
                 from triton.language import make_tensor_descriptor  # noqa: F401
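With both branches in place, the two gates report TMA availability uniformly for Hopper-class CUDA devices and for XPU. A hedged usage sketch; the torch.utils._triton module path is assumed from where these helpers usually live, not stated in this diff.

# Query the gates that decide whether TMA-based templates can be used here.
from torch.utils._triton import (  # module path assumed
    has_triton_stable_tma_api,
    has_triton_tma_device,
)

if has_triton_tma_device():
    print("device-side TMA usable (CUDA SM90+ without ROCm, or XPU)")
if has_triton_stable_tma_api():
    print("stable make_tensor_descriptor API available")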