Revert "Use the device interface for detecting Triton availability (#139171)"
This reverts commit 940b60db974f08a31c746eec2f9c399fc8a861ee. Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @jansel can you please help get these changes working? See D70946254 for more details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2715392451))
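For context, a minimal sketch of the behaviour this revert restores (not part of the commit itself): Triton availability is again reported by `torch.utils._triton.has_triton()`, which takes no device argument, instead of the per-device `raise_if_triton_unavailable` hooks introduced by #139171. The `can_use_triton` helper below is purely illustrative.

```python
# Illustrative only -- assumes the post-revert API where has_triton() takes no arguments.
def can_use_triton() -> bool:
    from torch.utils._triton import has_triton, has_triton_package

    # has_triton_package() only checks that the `triton` module imports;
    # has_triton() additionally checks for a Triton-capable device
    # (e.g. a CUDA GPU with compute capability >= 7.0, XPU, or a Triton CPU backend).
    return has_triton_package() and has_triton()


if __name__ == "__main__":
    print("Triton usable:", can_use_triton())
```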
@@ -142,7 +142,7 @@ def check_rocm():
    return rocm_ver if torch.version.hip else "None"


def check_dynamo(backend: str, device: str, err_msg: str) -> None:
def check_dynamo(backend, device, err_msg) -> None:
    import torch

    if device == "cuda" and not torch.cuda.is_available():
@@ -151,15 +151,17 @@ def check_dynamo(backend: str, device: str, err_msg: str) -> None:

    try:
        import torch._dynamo as dynamo
        from torch._dynamo.eval_frame import raise_if_inductor_unavailable

        try:
            raise_if_inductor_unavailable(device)
        except RuntimeError as e:
            print(
                f"WARNING: Inductor not available for {device} ({e}). Skipping check."
            )
            return
        if device == "cuda":
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    f"WARNING: CUDA available but triton cannot be used. "
                    f"Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        dynamo.reset()

@@ -203,8 +205,6 @@ _SANITY_CHECK_ARGS = (


def main() -> None:
    from torch._dynamo.eval_frame import is_dynamo_supported

    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
@@ -215,10 +215,10 @@ def main() -> None:
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    if not is_dynamo_supported():
        warnings.warn("Dynamo is not supported on this platform. Skipping check.")
        return
    for args in _SANITY_CHECK_ARGS:
        if sys.version_info >= (3, 13):
            warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
            continue
        check_dynamo(*args)
    print("All required checks passed")
@@ -17,7 +17,6 @@ The abstraction layer enables device-agnostic code in TorchDynamo while allowing
specialized implementations for each hardware backend's unique features.
"""

import inspect
import time
from collections.abc import Iterable
from dataclasses import dataclass
@@ -32,6 +31,8 @@ if torch.cuda._is_compiled():
else:
    get_cuda_stream = None

_device_t = Union[torch.device, str, int, None]

# Recording the device properties in the main process but used in worker process.
caching_worker_device_properties: dict[str, Any] = {}
caching_worker_current_devices: dict[str, int] = {}
@@ -44,7 +45,7 @@ class DeviceInterface:
    """

    class device:
        def __new__(cls, device: torch.types.Device):
        def __new__(cls, device: _device_t):
            raise NotImplementedError

    class Event:
@@ -76,7 +77,7 @@ class DeviceInterface:
        raise NotImplementedError

    @staticmethod
    def get_device_properties(device: torch.types.Device = None):
    def get_device_properties(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
@@ -84,7 +85,7 @@ class DeviceInterface:
        raise NotImplementedError

    @staticmethod
    def set_device(device: torch.types.Device):
    def set_device(device: _device_t):
        raise NotImplementedError

    @staticmethod
@@ -124,15 +125,15 @@ class DeviceInterface:
        raise NotImplementedError

    @staticmethod
    def synchronize(device: torch.types.Device = None):
    def synchronize(device: _device_t = None):
        raise NotImplementedError

    @classmethod
    def get_device_properties(cls, device: torch.types.Device = None):
    def get_device_properties(cls, device: _device_t = None):
        return cls.Worker.get_device_properties(device)

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None):
    def get_compute_capability(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
@@ -146,30 +147,9 @@ class DeviceInterface:
        return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation)

    @staticmethod
    def memory_allocated(device: torch.types.Device = None) -> int:
    def memory_allocated(device: _device_t = None) -> int:
        raise NotImplementedError

    @staticmethod
    def is_triton_capable(device: torch.types.Device = None) -> bool:
        """
        Returns True if the device has Triton support, False otherwise, even if
        the appropriate Triton backend is not available.
        """
        return False

    @classmethod
    def raise_if_triton_unavailable(cls, device: torch.types.Device = None) -> None:
        """
        Raises a `RuntimeError` with the appropriate human-readable instructions
        to resolve the issue if Triton is not available for the given device, or
        the default device if `device` is `None`.

        The caller should ensure the presence of the 'triton' package before
        calling this method.
        """
        if not cls.is_triton_capable():
            raise RuntimeError("This device is not capable of supporting Triton")


class DeviceGuard:
    """
@@ -218,7 +198,7 @@ class CudaInterface(DeviceInterface):
        return torch.cuda.current_device()

    @staticmethod
    def get_device_properties(device: torch.types.Device = None):
    def get_device_properties(device: _device_t = None):
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
@@ -258,36 +238,13 @@ class CudaInterface(DeviceInterface):
        return torch.cuda.is_available()

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None):
    def get_compute_capability(device: _device_t = None):
        if torch.version.hip is None:
            major, min = torch.cuda.get_device_capability(device)
            return major * 10 + min
        else:
            return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]

    @staticmethod
    def is_triton_capable(device: torch.types.Device = None) -> bool:
        return (
            torch.version.hip is not None
            or torch.cuda.get_device_properties(device).major >= 7
        )

    @staticmethod
    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
        from torch._inductor.exc import GPUTooOldForTriton

        if not CudaInterface.is_triton_capable(device):
            device_props = torch.cuda.get_device_properties(device)
            raise GPUTooOldForTriton(device_props, inspect.currentframe())

        import triton.backends

        if torch.version.hip is not None:
            if "amd" not in triton.backends.backends:
                raise RuntimeError("triton not built with the 'amd' backend")
        elif "nvidia" not in triton.backends.backends:
            raise RuntimeError("triton not built with the 'nvidia' backend")


get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
@@ -313,7 +270,7 @@ class XpuInterface(DeviceInterface):
        return torch.xpu.current_device()

    @staticmethod
    def get_device_properties(device: torch.types.Device = None):
    def get_device_properties(device: _device_t = None):
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
@@ -352,7 +309,7 @@ class XpuInterface(DeviceInterface):
        return torch.xpu.is_available()

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None):
    def get_compute_capability(device: _device_t = None):
        cc = torch.xpu.get_device_capability(device)
        return cc

@@ -360,17 +317,6 @@ class XpuInterface(DeviceInterface):
    def is_bf16_supported(including_emulation: bool = False) -> bool:
        return torch.xpu.is_bf16_supported()

    @staticmethod
    def is_triton_capable(device: torch.types.Device = None) -> bool:
        return True

    @staticmethod
    def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None:
        import triton.backends

        if "intel" not in triton.backends.backends:
            raise RuntimeError("triton not built with the 'intel' backend")


@dataclass
class CpuDeviceProperties:
@@ -388,14 +334,6 @@ class CpuInterface(DeviceInterface):
        def record(self, stream=None):
            self.time = time.perf_counter()

    class Worker:
        @staticmethod
        def get_device_properties(device: torch.types.Device = None):
            import multiprocessing

            cpu_count = multiprocessing.cpu_count()
            return CpuDeviceProperties(cpu_count)

    @staticmethod
    def is_available() -> bool:
        return True
@@ -405,7 +343,7 @@ class CpuInterface(DeviceInterface):
        return True

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None) -> str:
    def get_compute_capability(device: _device_t = None) -> str:
        return ""

    @staticmethod
@@ -417,19 +355,16 @@ class CpuInterface(DeviceInterface):
        return 0

    @staticmethod
    def synchronize(device: torch.types.Device = None):
    def synchronize(device: _device_t = None):
        pass

    @staticmethod
    def is_triton_capable(device: torch.types.Device = None) -> bool:
        return True
    class Worker:
        @staticmethod
        def get_device_properties(device: _device_t = None):
            import multiprocessing

    @staticmethod
    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
        import triton.backends

        if "cpu" not in triton.backends.backends:
            raise RuntimeError("triton not built with the 'cpu' backend")
            cpu_count = multiprocessing.cpu_count()
            return CpuDeviceProperties(cpu_count)


class MpsInterface(DeviceInterface):
@@ -454,16 +389,16 @@ class MpsInterface(DeviceInterface):
        return 0

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None) -> str:
    def get_compute_capability(device: _device_t = None) -> str:
        return ""

    @staticmethod
    def synchronize(device: torch.types.Device = None):
    def synchronize(device: _device_t = None):
        torch.mps.synchronize()

    class Worker:
        @staticmethod
        def get_device_properties(device: torch.types.Device = None):
        def get_device_properties(device: _device_t = None):
            return {}

        @staticmethod
@@ -893,7 +893,7 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
        return fn


def raise_if_dynamo_unavailable() -> None:
def check_if_dynamo_supported():
    if sys.version_info >= (3, 14):
        raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
    elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
@@ -902,40 +902,21 @@ def raise_if_dynamo_unavailable() -> None:
        )


def is_dynamo_supported() -> bool:
def is_dynamo_supported():
    try:
        raise_if_dynamo_unavailable()
        check_if_dynamo_supported()
        return True
    except Exception:
        return False


def raise_if_inductor_unavailable(device: torch.device | str | None = None) -> None:
    from torch._inductor.codegen.common import (
        get_scheduling_for_device,
        init_backend_registration,
    )

    raise_if_dynamo_unavailable()

    init_backend_registration()

    if device is None:
        device = torch.get_default_device()
    elif isinstance(device, str):
        device = torch.device(device)

    scheduling_factory = get_scheduling_for_device(device.type)
    if scheduling_factory is None:
        raise RuntimeError(
            f"No Inductor scheduling factory registered for {device.type}"
        )
    scheduling_factory(None).raise_if_unavailable(device)
def check_if_inductor_supported():
    check_if_dynamo_supported()


def is_inductor_supported(device: torch.device | str | None = None) -> bool:
def is_inductor_supported():
    try:
        raise_if_inductor_unavailable(device)
        check_if_inductor_supported()
        return True
    except Exception:
        return False
@@ -998,7 +979,7 @@ def _optimize(
        @torch._dynamo.optimize()
        def toy_example(a, b): ...
    """
    raise_if_dynamo_unavailable()
    check_if_dynamo_supported()
    check_for_incompatible_configs()
    # Note: The hooks object could be global instead of passed around, *however* that would make
    # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
@@ -1566,7 +1547,7 @@ def export(
    f = _f
    specialize_float = _specialize_float
    assume_static_by_default = _assume_static_by_default
    raise_if_dynamo_unavailable()
    check_if_dynamo_supported()
    torch._C._log_api_usage_once("torch._dynamo.export")
    if decomposition_table is not None:
        assert aten_graph, (
@@ -44,6 +44,7 @@ _step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here

if not disable_progress:
    try:

@@ -89,7 +89,7 @@ from torch._utils_internal import (
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.monitor import _WaitCounter
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton, has_triton_package
from torch.utils.hooks import RemovableHandle


@@ -1489,7 +1489,7 @@ def record_compilation_metrics(
        "dynamo_config": _get_dynamo_config_for_logging(),
        "inductor_config": _scrubbed_inductor_config_for_logging(),
        "cuda_version": torch.version.cuda,
        "triton_version": triton.__version__ if has_triton_package() else "",
        "triton_version": triton.__version__ if has_triton() else "",
        "remote_cache_version": remote_cache_version,
        "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
    }
@@ -3744,10 +3744,17 @@ def build_checkpoint_variable(**options):
    )


def is_compile_supported(device_type: str) -> bool:
    from .eval_frame import is_inductor_supported
def is_compile_supported(device_type):
    from .eval_frame import is_dynamo_supported

    return is_inductor_supported(device_type)
    compile_supported = is_dynamo_supported()
    if device_type == "cpu":
        pass
    elif device_type in ["cuda", "xpu"] and compile_supported:
        compile_supported = has_triton()
    else:
        compile_supported = False
    return compile_supported


# The following 3.11 source code functions are adapted from
@@ -567,9 +567,9 @@ class VariableBuilder:

    def _wrap(self, value):
        # import here to avoid circular dependencies
        from torch.utils._triton import has_triton_package, has_triton_tma
        from torch.utils._triton import has_triton, has_triton_tma

        if has_triton_package():
        if has_triton():
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        else:

@@ -38,14 +38,14 @@ if TYPE_CHECKING:
    from torch._dynamo.variables.functions import TritonKernelVariable
    from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
    from torch.fx.proxy import Proxy
    from torch.utils._triton import has_triton_package
    from torch.utils._triton import has_triton

TritonMetaParamsType = dict[str, int]
TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]

if has_triton_package():
if has_triton():
    from triton.runtime.autotuner import Autotuner, Config as TritonConfig
    from triton.runtime.jit import JITFunction
else:
@@ -45,15 +45,8 @@ class CUDACombinedScheduling(BaseScheduling):
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)

    @classmethod
    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
        return TritonScheduling.get_backend_features(device)

    @classmethod
    def raise_if_unavailable(
        cls, device: Union[str, torch.device, None] = None
    ) -> None:
        TritonScheduling.raise_if_unavailable(device)
    def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
        return self._triton_scheduling.get_backend_features(device)

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):

@@ -596,7 +596,7 @@ class MetalScheduling(SIMDScheduling):

    def __init__(self, scheduler: Optional[Scheduler]) -> None:
        super().__init__(scheduler)
        wrapper = getattr(V.graph, "wrapper_code", None)
        wrapper = V.graph.wrapper_code
        if wrapper is not None:
            wrapper.header.splice(
                "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
@@ -5,7 +5,6 @@ import collections
import contextlib
import dataclasses
import functools
import inspect
import itertools
import logging
import math
@@ -33,7 +32,6 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache
from ..exc import TritonMissing
from ..ops_handler import DefaultHandler
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
@@ -4073,20 +4071,6 @@ class TritonScheduling(SIMDScheduling):
            )
        return cls.backend_features

    @classmethod
    def raise_if_unavailable(
        cls, device: Union[str, torch.device, None] = None
    ) -> None:
        if not has_triton_package():
            raise TritonMissing(inspect.currentframe())

        from torch._dynamo.device_interface import get_interface_for_device

        if device is None:
            device = torch.get_default_device()

        get_interface_for_device(device).raise_if_triton_unavailable(device)

    def codegen_comment(self, node_schedule):
        wrapper = V.graph.wrapper_code
        origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
@@ -93,7 +93,7 @@ from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton_package
from ..utils._triton import has_triton
from . import config, metrics
from .debug import DebugContext
from .decomposition import select_decomp_table
@@ -1664,7 +1664,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
        "triton.autotune_at_compile_time": (
            config.triton.autotune_at_compile_time
            if config.triton.autotune_at_compile_time is not None
            else has_triton_package()
            else has_triton()
        ),
        "triton.autotune_cublasLt": False,
        "triton.cudagraphs": False, # TODO: to be removed

@@ -24,9 +24,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch

from ..codegen.common import get_scheduling_for_device
from ..codegen.cuda_combined_scheduling import CUDACombinedScheduling
from ..codegen.triton import TritonScheduling
from ...utils._triton import has_triton
from ..pattern_matcher import (
    fwd_only,
    gen_register_replacement,
@@ -460,11 +458,7 @@ def _should_pad_bench(
    ):
        return True

    scheduling_factory = get_scheduling_for_device(mat1.device.type)
    if scheduling_factory is None or not isinstance(
        scheduling_factory(None),
        (TritonScheduling, CUDACombinedScheduling),
    ):
    if not has_triton():
        return False

    if not is_mm_compute_bound(m, k, n, mat1.dtype):
@@ -11,7 +11,7 @@ from typing_extensions import override

import torch
from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton

from ..remote_cache import (
    create_cache,
@@ -36,7 +36,7 @@ def inductor_meta_from_config() -> _InductorMetaTy:
    from torch._inductor import config

    backend_hash = None
    if has_triton_package():
    if has_triton():
        try:
            backend_hash = torch.utils._triton.triton_hash_with_backend()
        except RuntimeError:
@@ -3,6 +3,7 @@ from __future__ import annotations
import collections
import dataclasses
import functools
import inspect
import itertools
import logging
import math
@@ -30,12 +31,14 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import free_symbol_is_type, SymT
from torch.utils._triton import has_triton

from . import comms, config, dependencies, ir, metrics
from .analyze_preserves_zero_mask import can_codegen_without_upcasts
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .exc import GPUTooOldForTriton, TritonMissing
from .ir import (
    ComputedBuffer,
    get_device_type,
@@ -3907,14 +3910,20 @@ class Scheduler:
            )
        V.graph.add_device_info(device)

        device_scheduling_type = get_scheduling_for_device(device.type)
        if device_scheduling_type is None:
        device_scheduling = get_scheduling_for_device(device.type)
        if device_scheduling is None:
            raise RuntimeError(f"Unsupported device type: {device.type}")

        scheduling = device_scheduling_type(self)
        scheduling.raise_if_unavailable(device)
        if not has_triton():
            if (
                device.type == "cuda"
                and (device_props := torch.cuda.get_device_properties(device)).major < 7
            ):
                raise GPUTooOldForTriton(device_props, inspect.currentframe())
            elif is_gpu(device.type) and not device.type == "mps":
                raise TritonMissing(inspect.currentframe())

        return scheduling
        return device_scheduling(self)

    def get_backend(self, device: Optional[torch.device]) -> BaseScheduling:
        assert device is not None
@@ -4363,17 +4372,6 @@ class BaseScheduling:
        """Return a set of .codegen.common.BackendFeature()"""
        return OrderedSet()

    @classmethod
    def raise_if_unavailable(
        cls, device: Union[str, torch.device, None] = None
    ) -> None:
        """
        Raises a RuntimeError if the given device does not support this codegen or required
        prerequisites are not available with a useful description for the user. If None is given,
        the default device is checked.
        """
        return None

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
@@ -1337,7 +1337,7 @@ def use_triton_template(


def use_triton_tma_template(*matrices: IRNode) -> bool:
    from torch.utils._triton import has_triton_tma
    from torch.utils._triton import has_triton_tma_device

    from .virtualized import V

@@ -1362,7 +1362,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:

    return (
        config.triton.enable_persistent_tma_matmul
        and has_triton_tma()
        and has_triton_tma_device()
        and all(_is_tma_compatible(m) for m in matrices)
    )


@@ -8,7 +8,7 @@ from typing import Optional

import torch
from torch._dynamo.utils import warn_once
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton

from ._triton_ops_meta import get_meta

@@ -1323,7 +1323,7 @@ def bsr_dense_addmm(
    return out_backup


if has_triton_package():
if has_triton():
    import triton
    import triton.language as tl


@@ -97,8 +97,7 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) ->
    from torch._dynamo.utils import is_compile_supported

    device_type = storage.device.type
    # FIXME: MPS does not yet support some of the ops required for hashing
    if stable_hash or not is_compile_supported(device_type) or device_type == "mps":
    if stable_hash or not is_compile_supported(device_type):
        cpu_storage = storage.cpu()
        # TODO: make storage support buffer protocol so this isn't
        # necessary
@@ -1,11 +1,6 @@
# mypy: allow-untyped-defs
import functools
import hashlib
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from torch.types import Device


@functools.lru_cache(None)
@@ -14,47 +9,11 @@ def has_triton_package() -> bool:
        from triton.compiler.compiler import triton_key

        return triton_key is not None
    except (ImportError, RuntimeError):
    except ImportError:
        return False


@functools.lru_cache(None)
def has_triton(device: "Device" = None) -> bool:
    """
    Determine if Triton is available for use on this system for a given device
    (if device is not None) or any available device type if no device is given.
    """
    import torch
    from torch._dynamo.device_interface import (
        DeviceInterface,
        get_interface_for_device,
        get_registered_device_interfaces,
    )

    if not has_triton_package():
    except RuntimeError:
        return False

    def device_has_triton(di: type[DeviceInterface]) -> bool:
        if not di.is_available():
            return False

        try:
            di.raise_if_triton_unavailable(device)
        except RuntimeError:
            return False

        return True

    if device is None:
        return any(
            device_has_triton(di) for _, di in get_registered_device_interfaces()
        )

    if not isinstance(device, (str, torch.device)):
        device = torch.device(device)

    return device_has_triton(get_interface_for_device(device))


@functools.lru_cache(None)
def has_triton_tma():
@@ -102,6 +61,40 @@ def has_triton_tma_device():
    return False


@functools.lru_cache(None)
def has_triton() -> bool:
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface):
        return device_interface.Worker.get_device_properties().major >= 7

    def cpu_extra_check(device_interface):
        import triton.backends

        return "cpu" in triton.backends.backends

    def _return_true(device_interface):
        return True

    triton_supported_devices = {
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
    }

    def is_device_compatible_with_triton():
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.lru_cache(None)
def triton_backend():
    from triton.compiler.compiler import make_backend