mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)
The previous [diff/PR](https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:

![docstring lint error](https://github.com/user-attachments/assets/216b1720-4002-48da-b5f3-32b5d48aaa54)

I didn't add the docstring because I thought I wasn't supposed to add a docstring to an EXISTING function. So this diff/PR is an exact copy of the previous one, except that it adds the docstring.

-------------

This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code (Triton kernels + Python wrapper code) similar to what it generates for CUDA, and the Triton kernels can be launched eagerly.

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs for device, stream, event, etc.
- Add required torch.mtia APIs, like is_bf16_supported, memory_allocated, set_stream_by_id, etc.
- MTIA-specific codegen logic, for example, loading the MTIA dynamic_library.
- Other necessary changes to integrate with Inductor codegen, following other devices like CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor, and to use the -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as backend, using the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.

Note:
- This approach (option 3) aims to provide a PyTorch-native approach to Inductor integration for MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimization and incurs eager kernel-launch overhead.
- MTIA will support another approach (option 2) to provide the best performance, based on WrapperFxCodegen. We should be able to reuse the fundamental changes of this diff for option 2, like the device interfaces, stream/event APIs, etc., especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion (option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc (option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
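For illustration, the end state this stack targets is plain `torch.compile` usage on MTIA. A minimal sketch, assuming an MTIA-enabled PyTorch build with the Triton `mtia` backend available (the model and tensor shapes are made up):

```python
import torch

# Minimal sketch: compile a small module with the native Inductor backend and
# run it on an MTIA device. Inductor emits Triton kernels plus a Python wrapper,
# and the kernels are launched eagerly, as described above.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).to("mtia")
compiled = torch.compile(model, backend="inductor")

x = torch.randn(4, 16, device="mtia")
out = compiled(x)
torch.mtia.synchronize()
```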
@@ -16,4 +16,5 @@ The MTIA backend is implemented out of the tree, only interfaces are defined
    :nosignatures:

    memory_stats
    memory_allocated
```
@@ -1948,8 +1948,10 @@ def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_getCurrentRawStream(device: _int) -> _int: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _mtia_memoryStats(device: _int) -> dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ...
def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ...
@@ -2,10 +2,10 @@
Device abstraction layer for TorchDynamo and Inductor backends.

This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS) through a common device interface. Key components include:
CPU, MPS, MTIA) through a common device interface. Key components include:

- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
@@ -287,6 +287,87 @@ class CudaInterface(DeviceInterface):
            raise RuntimeError("triton not built with the 'nvidia' backend")


get_mtia_stream: Optional[Callable[[int], int]]
if torch.mtia._is_compiled():
    from torch._C import _mtia_getCurrentRawStream as get_mtia_stream
else:
    get_mtia_stream = None


class MtiaInterface(DeviceInterface):
    device = torch.mtia.device  # type: ignore[assignment]
    Event = torch.mtia.Event  # type: ignore[assignment]
    Stream = torch.mtia.Stream  # type: ignore[assignment]

    class Worker:
        @staticmethod
        def set_device(device: int) -> None:
            caching_worker_current_devices["mtia"] = device

        @staticmethod
        def current_device() -> int:
            if "mtia" in caching_worker_current_devices:
                return caching_worker_current_devices["mtia"]
            return torch.mtia.current_device()

        @staticmethod
        def get_device_properties(device: torch.types.Device = None) -> Any:
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "mtia"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = MtiaInterface.Worker.current_device()

            if "mtia" not in caching_worker_device_properties:
                device_prop = [
                    torch.mtia.get_device_properties(i)
                    for i in range(torch.mtia.device_count())
                ]
                caching_worker_device_properties["mtia"] = device_prop

            return caching_worker_device_properties["mtia"][device]

    current_device = staticmethod(torch.mtia.current_device)
    set_device = staticmethod(torch.mtia.set_device)  # type: ignore[assignment]
    device_count = staticmethod(torch.mtia.device_count)
    stream = staticmethod(torch.mtia.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.mtia.current_stream)
    set_stream = staticmethod(torch.mtia.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.mtia.synchronize)
    get_device_properties = staticmethod(torch.mtia.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_mtia_stream)  # type: ignore[assignment, arg-type]
    exchange_device = staticmethod(torch.mtia._exchange_device)  # type: ignore[arg-type]
    maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device)  # type: ignore[arg-type]
    memory_allocated = staticmethod(torch.mtia.memory_allocated)  # type: ignore[assignment]
    is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported)  # type: ignore[arg-type]

    # Can be mock patched by @patch decorator.
    @staticmethod
    def is_available() -> bool:
        ret = torch.mtia.is_available()
        return ret

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None) -> Any:
        cc = torch.mtia.get_device_capability(device)
        return cc

    @staticmethod
    def is_triton_capable(device: torch.types.Device = None) -> bool:
        return True

    @staticmethod
    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
        import triton.backends

        if "mtia" not in triton.backends.backends:
            raise RuntimeError("triton not built with the 'mtia' backend")


get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
    from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
@@ -509,6 +590,10 @@ def init_device_reg() -> None:
    for i in range(torch.xpu.device_count()):
        register_interface_for_device(f"xpu:{i}", XpuInterface)

    register_interface_for_device("mtia", MtiaInterface)
    for i in range(torch.mtia.device_count()):
        register_interface_for_device(f"mtia:{i}", MtiaInterface)

    register_interface_for_device("cpu", CpuInterface)
    register_interface_for_device("mps", MpsInterface)
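For context, the interfaces registered above are what Dynamo and Inductor later look up by device string. A minimal sketch of that consumer side, assuming an MTIA-enabled build and using `get_interface_for_device`, the lookup helper that lives alongside `register_interface_for_device` in `torch._dynamo.device_interface`:

```python
import torch
from torch._dynamo.device_interface import get_interface_for_device

# Fetch the DeviceInterface registered for "mtia" and exercise its common API;
# the same calls work for the "cuda" or "xpu" interfaces registered above.
di = get_interface_for_device("mtia")
if di.is_available():
    di.set_device(0)
    props = di.get_device_properties(0)
    raw_stream = di.get_raw_stream(di.current_device())
```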
@@ -3967,7 +3967,7 @@ def is_compile_supported(device_type):
    compile_supported = is_dynamo_supported()
    if type == "cpu":
        pass
    elif type in ["cuda", "xpu"] and compile_supported:
    elif type in ["cuda", "xpu", "mtia"] and compile_supported:
        compile_supported = has_triton()
    else:
        compile_supported = False
@@ -16,6 +16,7 @@ from .template_heuristics import (
    BaseConfigHeuristic,
    CPUConfigHeuristic,
    CUDAConfigHeuristic,
    MTIAConfigHeuristic,
    ROCmConfigHeuristic,
    XPUConfigHeuristic,
)
@@ -65,6 +66,8 @@ class InductorChoices:
            return XPUConfigHeuristic()
        elif device_type == "cpu":
            return CPUConfigHeuristic()
        elif device_type == "mtia":
            return MTIAConfigHeuristic()
        else:
            return BaseConfigHeuristic()
@@ -484,6 +484,10 @@ def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:

@functools.cache
def init_backend_registration() -> None:
    """
    Register the backend for different devices, including the scheduling
    for kernel code generation and the host side wrapper code generation.
    """
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
@@ -492,6 +496,7 @@ def init_backend_registration() -> None:
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .mps import MetalScheduling
    from .python_wrapper_mtia import PythonWrapperMtia
    from .triton import TritonScheduling
    from .wrapper import PythonWrapperCodegen

@@ -539,6 +544,14 @@ def init_backend_registration() -> None:
            CppWrapperMps,
        )

    if get_scheduling_for_device("mtia") is None:
        register_backend_for_device(
            "mtia",
            TritonScheduling,
            PythonWrapperMtia,
            CppWrapperGpu,
        )

    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
@@ -584,6 +597,7 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides:
    if not device_op_overrides_dict:
        from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
        from .cuda import device_op_overrides  # noqa: F401
        from .mtia import device_op_overrides as mtia_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict[device]
torch/_inductor/codegen/mtia/__init__.py (new empty file, 0 lines)

torch/_inductor/codegen/mtia/device_op_overrides.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from __future__ import annotations

from ..common import DeviceOpOverrides, register_device_op_overrides


class MTIADeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name: str) -> str:
        return f"from torch._C import _mtia_getCurrentRawStream as {name}"

    def set_device(self, device_idx: int) -> str:
        return f"torch.mtia.set_device({device_idx})"

    def synchronize(self) -> str:
        return "torch.mtia.synchronize()"

    def device_guard(self, device_idx: int) -> str:
        return f"torch.mtia.device({device_idx})"


register_device_op_overrides("mtia", MTIADeviceOpOverrides())
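To make the role of these overrides concrete, here is a small sketch of the source snippets they return; the Inductor wrapper codegen splices these strings into the generated Python wrapper (the output comments are illustrative):

```python
from torch._inductor.codegen.mtia.device_op_overrides import MTIADeviceOpOverrides

overrides = MTIADeviceOpOverrides()
print(overrides.import_get_raw_stream_as("get_raw_stream"))
# from torch._C import _mtia_getCurrentRawStream as get_raw_stream
print(overrides.set_device(0))    # torch.mtia.set_device(0)
print(overrides.device_guard(0))  # torch.mtia.device(0)
print(overrides.synchronize())    # torch.mtia.synchronize()
```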
torch/_inductor/codegen/python_wrapper_mtia.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Optional
from typing_extensions import override

from torch._inductor import ir

from .wrapper import PythonWrapperCodegen


class PythonWrapperMtia(PythonWrapperCodegen):
    """
    A thin wrapper of PythonWrapperCodegen with MTIA specific logic
    """

    @override
    def write_header(self) -> None:
        super().write_header()

        # MTIA specific imports
        self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library")

    @override
    @staticmethod
    def create(
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
    ) -> PythonWrapperCodegen:
        if is_subgraph:
            # Delegate to the parent class to handle the case of subgraph
            return PythonWrapperCodegen.create(
                is_subgraph, subgraph_name, parent_wrapper, partition_signatures
            )
        return PythonWrapperMtia()
@@ -997,6 +997,7 @@ class PythonWrapperCodegen(CodeGen):
                empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
                empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
                empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
                reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
                alloc_from_pool = torch.ops.inductor._alloc_from_pool
                async_compile = AsyncCompile()
@@ -2777,7 +2778,7 @@ class PythonWrapperCodegen(CodeGen):
            allocation_shape
        )
        codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
        if device.type in ("cpu", "cuda", "xpu"):
        if device.type in ("cpu", "cuda", "xpu", "mtia"):
            # optimized path for faster allocations, saving ~2us versus the stuff below
            out = (
                f"{name} = empty_strided_{device.type}("
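For reference, with MTIA added to the fast path above, the generated Python wrapper can allocate output buffers through the new binding directly. A rough sketch of what such a generated allocation looks like, assuming the same `(sizes, strides, dtype)` calling convention as the other `empty_strided_*` helpers (shapes are illustrative):

```python
import torch

# The generated wrapper binds the helper exactly as in the earlier hunk ...
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
# ... and then allocates intermediate buffers with it on the fast path.
buf0 = empty_strided_mtia((64, 64), (64, 1), torch.float32)
```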
@@ -156,6 +156,8 @@ class DeviceProperties(typing.NamedTuple):
        elif device_type == "mps":
            # TODO: Fetch the actual value from ioreg
            multi_processor_count = 8
        elif device_type == "mtia":
            multi_processor_count = 64
        else:
            raise
        return cls(
@@ -1201,3 +1201,9 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
    """
    Placeholder child class for XPU specific overrides.
    """


class MTIAConfigHeuristic(BaseConfigHeuristic):
    """
    Placeholder child class for MTIA specific overrides.
    """
@@ -94,7 +94,7 @@ if TYPE_CHECKING:
    from .scheduler import BaseSchedulerNode, SchedulerBuffer


GPU_TYPES = ["cuda", "mps", "xpu"]
GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
T = TypeVar("T")
@@ -31,6 +31,10 @@
#include <ATen/xpu/EmptyTensor.h>
#endif

#ifdef USE_MTIA
#include <ATen/native/mtia/EmptyTensor.h>
#endif

#include <chrono>
#include <sstream>
#include <tuple>
@@ -1059,6 +1063,12 @@ static PyObject* _empty_strided_device(
    return THPVariable_Wrap(at::detail::empty_strided_xpu(
        sizes, strides, dtype, c10::DeviceType::XPU));
  }
#endif
#ifdef USE_MTIA
  else if (device_type == c10::DeviceType::MTIA) {
    return THPVariable_Wrap(at::detail::empty_strided_mtia(
        sizes, strides, dtype, c10::DeviceType::MTIA));
  }
#endif
  else {
    TORCH_CHECK(
@@ -1084,6 +1094,10 @@ static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
  return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
}

static PyObject* _empty_strided_mtia(PyObject* dummy, PyObject* args) {
  return _empty_strided_device(dummy, args, c10::DeviceType::MTIA);
}

static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  static PythonArgParser parser(
@@ -1115,6 +1129,7 @@ static PyMethodDef _methods[] = {
    {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
    {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
    {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
    {"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr},
    {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
@@ -63,6 +63,18 @@ void initModule(PyObject* module) {
        return at::detail::getMTIAHooks().getDefaultStream(device_index);
      });

  m.def(
      "_mtia_setStream",
      [](int64_t stream_id,
         c10::DeviceIndex device_index,
         int64_t device_type) {
        torch::utils::device_lazy_init(at::kMTIA);
        at::detail::getMTIAHooks().setCurrentStream(c10::Stream::unpack3(
            stream_id,
            device_index,
            static_cast<c10::DeviceType>(device_type)));
      });

  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
    torch::utils::device_lazy_init(at::kMTIA);
    auto device = at::detail::getMTIAHooks().getCurrentDevice();
@@ -204,6 +204,10 @@ def attach_out_of_memory_observer(
    torch._C._mtia_attachOutOfMemoryObserver(observer)


def is_bf16_supported(including_emulation: bool = True):
    return True


def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
    r"""Return capability of a given device as a tuple of (major version, minor version).

@@ -335,6 +339,17 @@ class StreamContext:
        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream specified by the stream id, device index and device type.

    Args:
        stream_id (int): stream id in stream pool
        device_index (int): device index in topo
        device_type (int): enum device type
    """
    torch._C._mtia_setStream(stream_id, device_index, device_type)


def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.
@@ -392,6 +407,7 @@ __all__ = [
    "default_stream",
    "memory_stats",
    "max_memory_allocated",
    "memory_allocated",
    "reset_peak_memory_stats",
    "get_device_capability",
    "get_device_properties",
@@ -405,4 +421,5 @@ __all__ = [
    "device",
    "set_rng_state",
    "get_rng_state",
    "is_bf16_supported",
]
@@ -36,6 +36,19 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int:
    return memory_stats(device).get("dram", 0).get("peak_bytes", 0)


def memory_allocated(device: Optional[_device_t] = None) -> int:
    r"""Return the current MTIA memory occupied by tensors in bytes for a given device.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not is_initialized():
        return 0
    return memory_stats(device).get("dram", 0).get("allocated_bytes", 0)


def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
    r"""Reset the peak memory stats for a given device.

@@ -53,5 +66,6 @@ def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
__all__ = [
    "memory_stats",
    "max_memory_allocated",
    "memory_allocated",
    "reset_peak_memory_stats",
]
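A quick usage sketch for the new memory query, assuming an MTIA-enabled build with at least one device (it returns 0 before MTIA is initialized, per the guard above):

```python
import torch

current_bytes = torch.mtia.memory_allocated()    # bytes currently occupied by tensors
peak_bytes = torch.mtia.max_memory_allocated()   # peak since the last reset
stats = torch.mtia.memory_stats()                # full stats dict, e.g. stats["dram"]
print(current_bytes, peak_bytes)
```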
@@ -135,6 +135,7 @@ def has_triton() -> bool:
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
        "mtia": _return_true,
    }

    def is_device_compatible_with_triton() -> bool: