Revert "Use the device interface for detecting Triton availability (#139171)"

This reverts commit 940b60db974f08a31c746eec2f9c399fc8a861ee.

Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @jansel can you please help get these changes working? See D70946254 for more details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2715392451))
PyTorch MergeBot
2025-03-11 18:49:21 +00:00
parent 57ee821a41
commit c916a8efc5
18 changed files with 128 additions and 243 deletions
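For orientation: this revert swaps Inductor's Triton probing from the per-device interface hooks (is_triton_capable / raise_if_triton_unavailable, removed in the hunks below) back to the cached boolean helpers in torch.utils._triton. A minimal sketch of the restored call pattern, using only names that appear in this diff (illustrative, not an exact call site):

    from torch.utils._triton import has_triton, has_triton_package

    # has_triton_package(): the `triton` wheel imports and exposes a compiler key.
    # has_triton(): additionally, a registered device type (cuda/xpu/cpu) is available
    # and passes its extra check, e.g. CUDA compute capability >= 7.0.
    if has_triton():
        print("Triton kernels can be compiled on at least one registered device")
    elif has_triton_package():
        print("triton is installed, but no available device passes its capability check")
    else:
        print("the triton package is not importable")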

View File

@@ -142,7 +142,7 @@ def check_rocm():
return rocm_ver if torch.version.hip else "None"
def check_dynamo(backend: str, device: str, err_msg: str) -> None:
def check_dynamo(backend, device, err_msg) -> None:
import torch
if device == "cuda" and not torch.cuda.is_available():
@@ -151,15 +151,17 @@ def check_dynamo(backend: str, device: str, err_msg: str) -> None:
try:
import torch._dynamo as dynamo
from torch._dynamo.eval_frame import raise_if_inductor_unavailable
try:
raise_if_inductor_unavailable(device)
except RuntimeError as e:
print(
f"WARNING: Inductor not available for {device} ({e}). Skipping check."
)
return
if device == "cuda":
from torch.utils._triton import has_triton
if not has_triton():
print(
f"WARNING: CUDA available but triton cannot be used. "
f"Your GPU may not be supported. "
f"Skipping CUDA check on {backend} backend\n"
)
return
dynamo.reset()
@@ -203,8 +205,6 @@ _SANITY_CHECK_ARGS = (
def main() -> None:
from torch._dynamo.eval_frame import is_dynamo_supported
python_ver = check_python()
torch_ver = check_torch()
cuda_ver = check_cuda()
@@ -215,10 +215,10 @@ def main() -> None:
f"CUDA version: {cuda_ver}\n"
f"ROCM version: {rocm_ver}\n"
)
if not is_dynamo_supported():
warnings.warn("Dynamo is not supported on this platform. Skipping check.")
return
for args in _SANITY_CHECK_ARGS:
if sys.version_info >= (3, 13):
warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
continue
check_dynamo(*args)
print("All required checks passed")

View File

@@ -17,7 +17,6 @@ The abstraction layer enables device-agnostic code in TorchDynamo while allowing
specialized implementations for each hardware backend's unique features.
"""
import inspect
import time
from collections.abc import Iterable
from dataclasses import dataclass
@@ -32,6 +31,8 @@ if torch.cuda._is_compiled():
else:
get_cuda_stream = None
_device_t = Union[torch.device, str, int, None]
# Recording the device properties in the main process but used in worker process.
caching_worker_device_properties: dict[str, Any] = {}
caching_worker_current_devices: dict[str, int] = {}
@@ -44,7 +45,7 @@ class DeviceInterface:
"""
class device:
def __new__(cls, device: torch.types.Device):
def __new__(cls, device: _device_t):
raise NotImplementedError
class Event:
@@ -76,7 +77,7 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: _device_t = None):
raise NotImplementedError
@staticmethod
@@ -84,7 +85,7 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def set_device(device: torch.types.Device):
def set_device(device: _device_t):
raise NotImplementedError
@staticmethod
@@ -124,15 +125,15 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: _device_t = None):
raise NotImplementedError
@classmethod
def get_device_properties(cls, device: torch.types.Device = None):
def get_device_properties(cls, device: _device_t = None):
return cls.Worker.get_device_properties(device)
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: _device_t = None):
raise NotImplementedError
@staticmethod
@@ -146,30 +147,9 @@ class DeviceInterface:
return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation)
@staticmethod
def memory_allocated(device: torch.types.Device = None) -> int:
def memory_allocated(device: _device_t = None) -> int:
raise NotImplementedError
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
"""
Returns True if the device has Triton support, False otherwise, even if
the appropriate Triton backend is not available.
"""
return False
@classmethod
def raise_if_triton_unavailable(cls, device: torch.types.Device = None) -> None:
"""
Raises a `RuntimeError` with the appropriate human-readable instructions
to resolve the issue if Triton is not available for the given device, or
the default device if `device` is `None`.
The caller should ensure the presence of the 'triton' package before
calling this method.
"""
if not cls.is_triton_capable():
raise RuntimeError("This device is not capable of supporting Triton")
class DeviceGuard:
"""
@@ -218,7 +198,7 @@ class CudaInterface(DeviceInterface):
return torch.cuda.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: _device_t = None):
if device is not None:
if isinstance(device, str):
device = torch.device(device)
@@ -258,36 +238,13 @@ class CudaInterface(DeviceInterface):
return torch.cuda.is_available()
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: _device_t = None):
if torch.version.hip is None:
major, min = torch.cuda.get_device_capability(device)
return major * 10 + min
else:
return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return (
torch.version.hip is not None
or torch.cuda.get_device_properties(device).major >= 7
)
@staticmethod
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
from torch._inductor.exc import GPUTooOldForTriton
if not CudaInterface.is_triton_capable(device):
device_props = torch.cuda.get_device_properties(device)
raise GPUTooOldForTriton(device_props, inspect.currentframe())
import triton.backends
if torch.version.hip is not None:
if "amd" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'amd' backend")
elif "nvidia" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'nvidia' backend")
get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
@@ -313,7 +270,7 @@ class XpuInterface(DeviceInterface):
return torch.xpu.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: _device_t = None):
if device is not None:
if isinstance(device, str):
device = torch.device(device)
@@ -352,7 +309,7 @@ class XpuInterface(DeviceInterface):
return torch.xpu.is_available()
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: _device_t = None):
cc = torch.xpu.get_device_capability(device)
return cc
@@ -360,17 +317,6 @@ class XpuInterface(DeviceInterface):
def is_bf16_supported(including_emulation: bool = False) -> bool:
return torch.xpu.is_bf16_supported()
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
@staticmethod
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
import triton.backends
if "intel" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'intel' backend")
@dataclass
class CpuDeviceProperties:
@@ -388,14 +334,6 @@ class CpuInterface(DeviceInterface):
def record(self, stream=None):
self.time = time.perf_counter()
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None):
import multiprocessing
cpu_count = multiprocessing.cpu_count()
return CpuDeviceProperties(cpu_count)
@staticmethod
def is_available() -> bool:
return True
@@ -405,7 +343,7 @@ class CpuInterface(DeviceInterface):
return True
@staticmethod
def get_compute_capability(device: torch.types.Device = None) -> str:
def get_compute_capability(device: _device_t = None) -> str:
return ""
@staticmethod
@@ -417,19 +355,16 @@ class CpuInterface(DeviceInterface):
return 0
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: _device_t = None):
pass
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
class Worker:
@staticmethod
def get_device_properties(device: _device_t = None):
import multiprocessing
@staticmethod
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
import triton.backends
if "cpu" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'cpu' backend")
cpu_count = multiprocessing.cpu_count()
return CpuDeviceProperties(cpu_count)
class MpsInterface(DeviceInterface):
@@ -454,16 +389,16 @@ class MpsInterface(DeviceInterface):
return 0
@staticmethod
def get_compute_capability(device: torch.types.Device = None) -> str:
def get_compute_capability(device: _device_t = None) -> str:
return ""
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: _device_t = None):
torch.mps.synchronize()
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: _device_t = None):
return {}
@staticmethod
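With the Triton hooks above removed, a DeviceInterface subclass only answers availability and property queries. A minimal sketch against the restored signatures (the class name and the "privateuseone" key are placeholders, not part of this diff):

    from torch._dynamo.device_interface import (
        DeviceInterface,
        register_interface_for_device,
    )

    class MyAcceleratorInterface(DeviceInterface):
        @staticmethod
        def is_available() -> bool:
            return True

        @staticmethod
        def get_compute_capability(device=None) -> str:
            # Free-form string, as in the CPU and MPS interfaces above.
            return ""

    # Registration is how get_interface_for_device() later finds the class.
    register_interface_for_device("privateuseone", MyAcceleratorInterface)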

View File

@@ -893,7 +893,7 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
return fn
def raise_if_dynamo_unavailable() -> None:
def check_if_dynamo_supported():
if sys.version_info >= (3, 14):
raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
@@ -902,40 +902,21 @@ def raise_if_dynamo_unavailable() -> None:
)
def is_dynamo_supported() -> bool:
def is_dynamo_supported():
try:
raise_if_dynamo_unavailable()
check_if_dynamo_supported()
return True
except Exception:
return False
def raise_if_inductor_unavailable(device: torch.device | str | None = None) -> None:
from torch._inductor.codegen.common import (
get_scheduling_for_device,
init_backend_registration,
)
raise_if_dynamo_unavailable()
init_backend_registration()
if device is None:
device = torch.get_default_device()
elif isinstance(device, str):
device = torch.device(device)
scheduling_factory = get_scheduling_for_device(device.type)
if scheduling_factory is None:
raise RuntimeError(
f"No Inductor scheduling factory registered for {device.type}"
)
scheduling_factory(None).raise_if_unavailable(device)
def check_if_inductor_supported():
check_if_dynamo_supported()
def is_inductor_supported(device: torch.device | str | None = None) -> bool:
def is_inductor_supported():
try:
raise_if_inductor_unavailable(device)
check_if_inductor_supported()
return True
except Exception:
return False
@@ -998,7 +979,7 @@ def _optimize(
@torch._dynamo.optimize()
def toy_example(a, b): ...
"""
raise_if_dynamo_unavailable()
check_if_dynamo_supported()
check_for_incompatible_configs()
# Note: The hooks object could be global instead of passed around, *however* that would make
# for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
@@ -1566,7 +1547,7 @@ def export(
f = _f
specialize_float = _specialize_float
assume_static_by_default = _assume_static_by_default
raise_if_dynamo_unavailable()
check_if_dynamo_supported()
torch._C._log_api_usage_once("torch._dynamo.export")
if decomposition_table is not None:
assert aten_graph, (
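After the revert the support probes take no device argument: check_if_dynamo_supported() and check_if_inductor_supported() raise with an explanation, and the is_*_supported() wrappers swallow the exception and return a bool. A usage sketch (the interpreter constraints come from the hunk above):

    from torch._dynamo.eval_frame import is_dynamo_supported, is_inductor_supported

    if not is_dynamo_supported():
        # e.g. Python 3.14+ or a Py_GIL_DISABLED build
        print("torch.compile is not supported on this interpreter")
    elif not is_inductor_supported():
        print("Dynamo works here, but the Inductor backend is unavailable")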

View File

@@ -44,6 +44,7 @@ _step_counter = itertools.count(1)
# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
try:

View File

@@ -89,7 +89,7 @@ from torch._utils_internal import (
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.monitor import _WaitCounter
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton, has_triton_package
from torch.utils.hooks import RemovableHandle
@@ -1489,7 +1489,7 @@ def record_compilation_metrics(
"dynamo_config": _get_dynamo_config_for_logging(),
"inductor_config": _scrubbed_inductor_config_for_logging(),
"cuda_version": torch.version.cuda,
"triton_version": triton.__version__ if has_triton_package() else "",
"triton_version": triton.__version__ if has_triton() else "",
"remote_cache_version": remote_cache_version,
"inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
}
@@ -3744,10 +3744,17 @@ def build_checkpoint_variable(**options):
)
def is_compile_supported(device_type: str) -> bool:
from .eval_frame import is_inductor_supported
def is_compile_supported(device_type):
from .eval_frame import is_dynamo_supported
return is_inductor_supported(device_type)
compile_supported = is_dynamo_supported()
if device_type == "cpu":
pass
elif device_type in ["cuda", "xpu"] and compile_supported:
compile_supported = has_triton()
else:
compile_supported = False
return compile_supported
# The following 3.11 source code functions are adapted from
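The restored is_compile_supported() answers per device type: CPU needs only Dynamo support, CUDA and XPU additionally need has_triton(), and any other device type (including MPS, which is why the hash_storage hunk further down drops its explicit MPS carve-out) reports False. A usage sketch:

    from torch._dynamo.utils import is_compile_supported

    for device_type in ("cpu", "cuda", "mps"):
        print(device_type, is_compile_supported(device_type))  # "mps" is now always False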

View File

@@ -567,9 +567,9 @@ class VariableBuilder:
def _wrap(self, value):
# import here to avoid circular dependencies
from torch.utils._triton import has_triton_package, has_triton_tma
from torch.utils._triton import has_triton, has_triton_tma
if has_triton_package():
if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
else:

View File

@@ -38,14 +38,14 @@ if TYPE_CHECKING:
from torch._dynamo.variables.functions import TritonKernelVariable
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
from torch.fx.proxy import Proxy
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton
TritonMetaParamsType = dict[str, int]
TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
if has_triton_package():
if has_triton():
from triton.runtime.autotuner import Autotuner, Config as TritonConfig
from triton.runtime.jit import JITFunction
else:

View File

@@ -45,15 +45,8 @@ class CUDACombinedScheduling(BaseScheduling):
self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
@classmethod
def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
return TritonScheduling.get_backend_features(device)
@classmethod
def raise_if_unavailable(
cls, device: Union[str, torch.device, None] = None
) -> None:
TritonScheduling.raise_if_unavailable(device)
def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
return self._triton_scheduling.get_backend_features(device)
def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):

View File

@@ -596,7 +596,7 @@ class MetalScheduling(SIMDScheduling):
def __init__(self, scheduler: Optional[Scheduler]) -> None:
super().__init__(scheduler)
wrapper = getattr(V.graph, "wrapper_code", None)
wrapper = V.graph.wrapper_code
if wrapper is not None:
wrapper.header.splice(
"from torch._inductor.runtime.runtime_utils import compile_mps_shader"

View File

@@ -5,7 +5,6 @@ import collections
import contextlib
import dataclasses
import functools
import inspect
import itertools
import logging
import math
@@ -33,7 +32,6 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache
from ..exc import TritonMissing
from ..ops_handler import DefaultHandler
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
@@ -4073,20 +4071,6 @@ class TritonScheduling(SIMDScheduling):
)
return cls.backend_features
@classmethod
def raise_if_unavailable(
cls, device: Union[str, torch.device, None] = None
) -> None:
if not has_triton_package():
raise TritonMissing(inspect.currentframe())
from torch._dynamo.device_interface import get_interface_for_device
if device is None:
device = torch.get_default_device()
get_interface_for_device(device).raise_if_triton_unavailable(device)
def codegen_comment(self, node_schedule):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)

View File

@@ -93,7 +93,7 @@ from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton_package
from ..utils._triton import has_triton
from . import config, metrics
from .debug import DebugContext
from .decomposition import select_decomp_table
@@ -1664,7 +1664,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
"triton.autotune_at_compile_time": (
config.triton.autotune_at_compile_time
if config.triton.autotune_at_compile_time is not None
else has_triton_package()
else has_triton()
),
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False, # TODO: to be removed

View File

@@ -24,9 +24,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch
from ..codegen.common import get_scheduling_for_device
from ..codegen.cuda_combined_scheduling import CUDACombinedScheduling
from ..codegen.triton import TritonScheduling
from ...utils._triton import has_triton
from ..pattern_matcher import (
fwd_only,
gen_register_replacement,
@@ -460,11 +458,7 @@ def _should_pad_bench(
):
return True
scheduling_factory = get_scheduling_for_device(mat1.device.type)
if scheduling_factory is None or not isinstance(
scheduling_factory(None),
(TritonScheduling, CUDACombinedScheduling),
):
if not has_triton():
return False
if not is_mm_compute_bound(m, k, n, mat1.dtype):

View File

@@ -11,7 +11,7 @@ from typing_extensions import override
import torch
from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton
from ..remote_cache import (
create_cache,
@@ -36,7 +36,7 @@ def inductor_meta_from_config() -> _InductorMetaTy:
from torch._inductor import config
backend_hash = None
if has_triton_package():
if has_triton():
try:
backend_hash = torch.utils._triton.triton_hash_with_backend()
except RuntimeError:

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import collections
import dataclasses
import functools
import inspect
import itertools
import logging
import math
@@ -30,12 +31,14 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import free_symbol_is_type, SymT
from torch.utils._triton import has_triton
from . import comms, config, dependencies, ir, metrics
from .analyze_preserves_zero_mask import can_codegen_without_upcasts
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .exc import GPUTooOldForTriton, TritonMissing
from .ir import (
ComputedBuffer,
get_device_type,
@@ -3907,14 +3910,20 @@ class Scheduler:
)
V.graph.add_device_info(device)
device_scheduling_type = get_scheduling_for_device(device.type)
if device_scheduling_type is None:
device_scheduling = get_scheduling_for_device(device.type)
if device_scheduling is None:
raise RuntimeError(f"Unsupported device type: {device.type}")
scheduling = device_scheduling_type(self)
scheduling.raise_if_unavailable(device)
if not has_triton():
if (
device.type == "cuda"
and (device_props := torch.cuda.get_device_properties(device)).major < 7
):
raise GPUTooOldForTriton(device_props, inspect.currentframe())
elif is_gpu(device.type) and not device.type == "mps":
raise TritonMissing(inspect.currentframe())
return scheduling
return device_scheduling(self)
def get_backend(self, device: Optional[torch.device]) -> BaseScheduling:
assert device is not None
@@ -4363,17 +4372,6 @@ class BaseScheduling:
"""Return a set of .codegen.common.BackendFeature()"""
return OrderedSet()
@classmethod
def raise_if_unavailable(
cls, device: Union[str, torch.device, None] = None
) -> None:
"""
Raises a RuntimeError if the given device does not support this codegen or required
prerequisites are not available with a useful description for the user. If None is given,
the default device is checked.
"""
return None
def can_fuse_vertical(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:

View File

@@ -1337,7 +1337,7 @@ def use_triton_template(
def use_triton_tma_template(*matrices: IRNode) -> bool:
from torch.utils._triton import has_triton_tma
from torch.utils._triton import has_triton_tma_device
from .virtualized import V
@@ -1362,7 +1362,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
return (
config.triton.enable_persistent_tma_matmul
and has_triton_tma()
and has_triton_tma_device()
and all(_is_tma_compatible(m) for m in matrices)
)

View File

@@ -8,7 +8,7 @@ from typing import Optional
import torch
from torch._dynamo.utils import warn_once
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton
from ._triton_ops_meta import get_meta
@@ -1323,7 +1323,7 @@ def bsr_dense_addmm(
return out_backup
if has_triton_package():
if has_triton():
import triton
import triton.language as tl

View File

@@ -97,8 +97,7 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) ->
from torch._dynamo.utils import is_compile_supported
device_type = storage.device.type
# FIXME: MPS does not yet support some of the ops required for hashing
if stable_hash or not is_compile_supported(device_type) or device_type == "mps":
if stable_hash or not is_compile_supported(device_type):
cpu_storage = storage.cpu()
# TODO: make storage support buffer protocol so this isn't
# necessary

View File

@@ -1,11 +1,6 @@
# mypy: allow-untyped-defs
import functools
import hashlib
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.types import Device
@functools.lru_cache(None)
@@ -14,47 +9,11 @@ def has_triton_package() -> bool:
from triton.compiler.compiler import triton_key
return triton_key is not None
except (ImportError, RuntimeError):
except ImportError:
return False
@functools.lru_cache(None)
def has_triton(device: "Device" = None) -> bool:
"""
Determine if Triton is available for use on this system for a given device
(if device is not None) or any available device type if no device is given.
"""
import torch
from torch._dynamo.device_interface import (
DeviceInterface,
get_interface_for_device,
get_registered_device_interfaces,
)
if not has_triton_package():
except RuntimeError:
return False
def device_has_triton(di: type[DeviceInterface]) -> bool:
if not di.is_available():
return False
try:
di.raise_if_triton_unavailable(device)
except RuntimeError:
return False
return True
if device is None:
return any(
device_has_triton(di) for _, di in get_registered_device_interfaces()
)
if not isinstance(device, (str, torch.device)):
device = torch.device(device)
return device_has_triton(get_interface_for_device(device))
@functools.lru_cache(None)
def has_triton_tma():
@@ -102,6 +61,40 @@ def has_triton_tma_device():
return False
@functools.lru_cache(None)
def has_triton() -> bool:
if not has_triton_package():
return False
from torch._dynamo.device_interface import get_interface_for_device
def cuda_extra_check(device_interface):
return device_interface.Worker.get_device_properties().major >= 7
def cpu_extra_check(device_interface):
import triton.backends
return "cpu" in triton.backends.backends
def _return_true(device_interface):
return True
triton_supported_devices = {
"cuda": cuda_extra_check,
"xpu": _return_true,
"cpu": cpu_extra_check,
}
def is_device_compatible_with_triton():
for device, extra_check in triton_supported_devices.items():
device_interface = get_interface_for_device(device)
if device_interface.is_available() and extra_check(device_interface):
return True
return False
return is_device_compatible_with_triton()
@functools.lru_cache(None)
def triton_backend():
from triton.compiler.compiler import make_backend
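Note that the restored helpers in this file are wrapped in functools.lru_cache(None), so they are computed once per process; Triton or a device that becomes available later is not re-detected. A small sketch for tests that need re-evaluation (cache_clear() is the standard lru_cache API, nothing added by this commit):

    from torch.utils import _triton

    _triton.has_triton_package.cache_clear()
    _triton.has_triton.cache_clear()
    print(_triton.has_triton())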