Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

* Invoke AITER assembly kernels on gfx942/gfx950 when inputs meet requirements. The AITER ASM kernels deliver over 500 TFLOPS of training performance; see the [AOTriton 0.11b release page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more details.
* Now returns a natural-log-based `logsumexp` tensor, matching CUDA's behavior. PR #156903 is reverted in this PR as well, since it is no longer needed.
* Enables `CausalVariant.LOWER_RIGHT` (see the sketch after this list).

The build system changes drastically along with the new packaging scheme of AOTriton 0.11:

* AOTriton 0.11 packs GPU images separately from the AOTriton runtime.
* `aotriton.cmake` now selectively downloads image packs according to `PYTORCH_ROCM_ARCH`.
* `aotriton.cmake` now only uses a pre-compiled runtime library that exactly matches the ROCm version in the build environment. For PyTorch builds with ROCm versions not listed in the file, the build process builds the AOTriton runtime from source, without GPU images. This avoids any further ABI breaks like ROCm 6.4 -> 7.0.
* Recursive git clone is disabled, since building the AOTriton runtime does not require submodules.

Bug fixes:

* Fix a kernel bug introduced when implementing SWA.

Known problems:

* The gfx1100 target (Radeon RX 7000 series) is moved back to experimental status due to accuracy issues; Triton compiler fixes are needed to restore its support status.
* Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCm 7.0. This issue is under investigation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161754
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
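For illustration, a minimal sketch of exercising the newly enabled lower-right causal variant through the public SDPA API. The shapes, dtype, and device below are illustrative assumptions, not taken from this PR; `causal_lower_right` constructs a `CausalBias` with `CausalVariant.LOWER_RIGHT`:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

# Illustrative shapes: query and key/value lengths differ, which is exactly
# the case where the bottom-right-aligned causal mask matters.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.bfloat16)

attn_mask = causal_lower_right(q.size(-2), k.size(-2))  # CausalVariant.LOWER_RIGHT
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```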
180 lines
4.6 KiB
Python
import functools
import hashlib

from typing import Any


@functools.cache
def has_triton_package() -> bool:
    """Return True if the ``triton`` package is importable."""
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


@functools.cache
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
    """Return Triton's (major, minor) version, or ``fallback`` if Triton is absent."""
    try:
        import triton  # noqa: F401

        major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
        return (major, minor)
    except ImportError:
        return fallback


@functools.cache
def _device_supports_tma() -> bool:
    # Host-side TMA (Tensor Memory Accelerator) requires an NVIDIA GPU with
    # compute capability >= 9.0 (Hopper or newer); it is unavailable on ROCm/HIP.
    import torch

    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and not torch.version.hip
    )


@functools.cache
def has_triton_experimental_host_tma() -> bool:
    """Return True if Triton exposes the experimental host-side TMA descriptor API."""
    if has_triton_package() and _device_supports_tma():
        try:
            from triton.tools.experimental_descriptor import (  # noqa: F401
                create_1d_tma_descriptor,
                create_2d_tma_descriptor,
            )

            return True
        except ImportError:
            pass

    return False


@functools.cache
def has_triton_tensor_descriptor_host_tma() -> bool:
    """Return True if Triton exposes the newer TensorDescriptor host-side TMA API."""
    if has_triton_package() and _device_supports_tma():
        try:
            from triton.tools.tensor_descriptor import (  # noqa: F401
                TensorDescriptor,
            )

            return True
        except ImportError:
            pass

    return False


@functools.cache
def has_triton_tma() -> bool:
    """Return True if either host-side TMA API is available."""
    return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()


@functools.cache
def has_triton_tma_device() -> bool:
    """Return True if Triton exposes a device-side TMA API for the current device."""
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ) or torch.xpu.is_available():
            # Old API.
            try:
                from triton.language.extra.cuda import (  # noqa: F401
                    experimental_device_tensormap_create1d,
                    experimental_device_tensormap_create2d,
                )

                return True
            except ImportError:
                pass

            # New API.
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_triton_stable_tma_api() -> bool:
    """Return True if Triton exposes the stable ``make_tensor_descriptor`` TMA API."""
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ) or torch.xpu.is_available():
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass
    return False


@functools.cache
def has_triton() -> bool:
    """Return True if Triton is installed and some available device can use it."""
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface: Any) -> bool:
        # Triton requires CUDA compute capability >= 7.0 (Volta or newer).
        return device_interface.Worker.get_device_properties().major >= 7

    def cpu_extra_check(device_interface: Any) -> bool:
        import triton.backends

        return "cpu" in triton.backends.backends

    def _return_true(device_interface: Any) -> bool:
        return True

    triton_supported_devices = {
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
        "mtia": _return_true,
    }

    def is_device_compatible_with_triton() -> bool:
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.cache
def triton_backend() -> Any:
    # Build a Triton compiler backend for the target currently selected by the
    # active Triton driver.
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    return make_backend(target)


@functools.cache
def triton_hash_with_backend() -> str:
    # Combined hash of the Triton installation and the active backend; suitable
    # for use in compilation cache keys.
    from torch._inductor.runtime.triton_compat import triton_key

    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"

    # Hash is upper case so that it can't contain any Python keywords.
    return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
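For context, a short sketch of how calling code typically consults these capability helpers. This is a hypothetical caller, assuming the file above is importable as `torch.utils._triton`:

```python
from torch.utils._triton import has_triton, has_triton_tma_device, triton_hash_with_backend

if has_triton():
    # Safe to define and compile Triton kernels for an available device.
    if has_triton_tma_device():
        # May emit TMA-based kernels (Hopper-class NVIDIA GPUs or XPU).
        pass
    # Cache key that changes whenever the Triton install or backend changes.
    cache_key = triton_hash_with_backend()
```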