Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: This change adds a new environment variable (`TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION`) and a corresponding configuration in `torch._inductor.config`, which can be set to `"1"` to allow a user to disable triton's device detection logic in [torch/utils/_triton.py:has_triton()](c9e57d7e9f/torch/utils/_triton.py (L128)). This function is used at import scope in several places, but it has a side effect of initializing the mtia device if it is available, which is causing some of our autotuning workflows to crash. Worth noting that when enabled, this configuration disables all device detection, not just mtia, because the logic in has_triton will initialize the mtia device as a side effect even when checking for a cuda or other device via the [get_interface_for_device()](c9e57d7e9f/torch/_dynamo/device_interface.py (L570)) function. I've tagged it `topic: not user facing` since I don't anticipate anyone outside of meta making use of this; however, this is my first PR here, so please indicate if it should be handled differently.

Test Plan: This has been tested in the context of internal workflows.

Differential Revision: D82347853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162974
Approved by: https://github.com/xmfan
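For context, below is a minimal sketch of how the new switch might be used. It assumes only what the summary and the module below state: the `TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION` environment variable and the `triton_disable_device_detection` flag in `torch._inductor.config` that `has_triton()` imports; how the environment variable is surfaced as that config attribute is not shown in this file, so treat the in-process assignment as an assumption of the sketch.

```python
import os

# Set the environment variable before torch is imported; per the summary, a
# value of "1" asks inductor to skip triton device detection entirely.
os.environ["TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION"] = "1"

import torch._inductor.config as inductor_config
from torch.utils._triton import has_triton

# Alternatively, flip the config knob in-process. The attribute name matches
# the import inside has_triton() below; assuming it is directly assignable.
inductor_config.triton_disable_device_detection = True

# With detection disabled, has_triton() returns False before any device
# interface is queried, so the mtia device is never initialized as a side
# effect of the check.
assert not has_triton()
```

As noted above, this disables detection for every supported backend (cuda, xpu, cpu, mtia), not just mtia.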
200 lines
5.1 KiB
Python
import functools
import hashlib
from typing import Any


@functools.cache
def has_triton_package() -> bool:
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


@functools.cache
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
    try:
        import triton  # noqa: F401

        major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
        return (major, minor)
    except ImportError:
        return fallback


@functools.cache
def _device_supports_tma() -> bool:
    import torch

    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and not torch.version.hip
    )


@functools.cache
def has_triton_experimental_host_tma() -> bool:
    if has_triton_package():
        if _device_supports_tma():
            try:
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_triton_tensor_descriptor_host_tma() -> bool:
    if has_triton_package():
        if _device_supports_tma():
            try:
                from triton.tools.tensor_descriptor import (  # noqa: F401
                    TensorDescriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_triton_tma() -> bool:
    return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()


@functools.cache
def has_triton_tma_device() -> bool:
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ) or torch.xpu.is_available():
            # old API
            try:
                from triton.language.extra.cuda import (  # noqa: F401
                    experimental_device_tensormap_create1d,
                    experimental_device_tensormap_create2d,
                )

                return True
            except ImportError:
                pass

            # new API
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_datacenter_blackwell_tma_device() -> bool:
    import torch

    if (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (10, 0)
        and torch.cuda.get_device_capability() < (11, 0)
        and not torch.version.hip
    ):
        return has_triton_tma_device() and has_triton_tensor_descriptor_host_tma()

    return False


@functools.lru_cache(None)
def has_triton_stable_tma_api() -> bool:
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ) or torch.xpu.is_available():
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass
    return False


@functools.cache
def has_triton() -> bool:
    if not has_triton_package():
        return False

    from torch._inductor.config import triton_disable_device_detection

    if triton_disable_device_detection:
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface: Any) -> bool:
        return device_interface.Worker.get_device_properties().major >= 7

    def cpu_extra_check(device_interface: Any) -> bool:
        import triton.backends

        return "cpu" in triton.backends.backends

    def _return_true(device_interface: Any) -> bool:
        return True

    triton_supported_devices = {
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
        "mtia": _return_true,
    }

    def is_device_compatible_with_triton() -> bool:
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.cache
def triton_backend() -> Any:
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    return make_backend(target)


@functools.cache
def triton_hash_with_backend() -> str:
    from torch._inductor.runtime.triton_compat import triton_key

    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"

    # Hash is upper case so that it can't contain any Python keywords.
    return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()