diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 99ed7c53fc10..931cd86e77d9 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -68,8 +68,6 @@ class TORCH_API Context {
       return at::detail::getMPSHooks();
     } else if (device_type == at::kPrivateUse1) {
       return at::detail::getPrivateUse1Hooks();
-    } else if (device_type == at::kMTIA) {
-      return at::detail::getMTIAHooks();
     } else {
       AT_ERROR(
           c10::DeviceTypeName(device_type), " device type not an accelerator.");
@@ -154,9 +152,6 @@ class TORCH_API Context {
   void lazyInitXPU() {
     c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
   }
-  void lazyInitMTIA() {
-    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
-  }
   void lazyInitPrivateUse1() {
     c10::call_once(thp_init, [&] {
       if (isPrivateUse1HooksRegistered()) {
@@ -347,7 +342,6 @@ class TORCH_API Context {
   c10::once_flag thc_init;
   c10::once_flag thh_init;
   c10::once_flag thx_init;
-  c10::once_flag th_mtia_init;
   c10::once_flag thp_init;
   bool enabled_cudnn = true;
   bool deterministic_cudnn = false;
diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp
index ec3cd2a2f552..05327cc219ef 100644
--- a/aten/src/ATen/DeviceAccelerator.cpp
+++ b/aten/src/ATen/DeviceAccelerator.cpp
@@ -10,9 +10,6 @@ C10_API std::optional<c10::DeviceType> getAccelerator(bool checked) {
 #define CHECK_NO_PU1 \
   TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
 
-#define CHECK_NO_MTIA \
-  TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices");
-
   if (is_privateuse1_backend_registered()) {
     // We explicitly allow PrivateUse1 and another device at the same time
     // as we use this for testing.
@@ -20,12 +17,7 @@
     return kPrivateUse1;
   } else if (at::hasCUDA()) {
     CHECK_NO_PU1
-    CHECK_NO_MTIA
     return kCUDA;
-  } else if (at::hasMTIA()) {
-    CHECK_NO_CUDA
-    CHECK_NO_PU1
-    return kMTIA;
   } else {
     TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
     return std::nullopt;
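Note: with the MTIA branch gone, `getAccelerator` only arbitrates between CUDA and a registered PrivateUse1 backend. A Python paraphrase of the surviving decision logic, for reference (the `has_cuda`/`has_pu1` flags stand in for the C++ probes and are not real APIs):

```python
from typing import Optional

def get_accelerator(has_cuda: bool, has_pu1: bool, checked: bool = False) -> Optional[str]:
    # A registered PrivateUse1 backend wins; it is explicitly allowed to
    # coexist with other devices because the test suite relies on that.
    if has_pu1:
        return "privateuse1"
    if has_cuda:
        return "cuda"
    if checked:
        raise RuntimeError("Cannot access accelerator device when none is available.")
    return None  # CPU-only build
```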
diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h
index 96e15e1f69da..c099c9f59a61 100644
--- a/aten/src/ATen/detail/AcceleratorHooksInterface.h
+++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include
-#include
+
 namespace at {
 
 // AcceleratorHooksInterface is a shared interface provided by all
@@ -16,29 +16,6 @@ struct TORCH_API AcceleratorHooksInterface {
   // Whether the device at device_index is fully initialized or not.
   virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
-
-  virtual DeviceIndex deviceCount() const {
-    return 0;
-  }
-
-  virtual void setCurrentDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()");
-  }
-
-  virtual DeviceIndex getCurrentDevice() const {
-    TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()");
-    return -1;
-  }
-
-  virtual DeviceIndex exchangeDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support exchangeDevice()");
-    return -1;
-  }
-
-  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()");
-    return -1;
-  }
 };
 
 } // namespace at
diff --git a/aten/src/ATen/detail/MTIAHooksInterface.cpp b/aten/src/ATen/detail/MTIAHooksInterface.cpp
index 096388171386..6b69fdb03f3d 100644
--- a/aten/src/ATen/detail/MTIAHooksInterface.cpp
+++ b/aten/src/ATen/detail/MTIAHooksInterface.cpp
@@ -8,22 +8,19 @@ namespace at {
 namespace detail {
-const MTIAHooksInterface& getMTIAHooks() {
-  static std::unique_ptr<MTIAHooksInterface> mtia_hooks = nullptr;
+
+const MTIAHooksInterface &getMTIAHooks() {
+  static MTIAHooksInterface* MTIA_hooks = nullptr;
   static c10::once_flag once;
   c10::call_once(once, [] {
-    mtia_hooks = MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{});
-    if (!mtia_hooks) {
-      mtia_hooks = std::make_unique<MTIAHooksInterface>();
+    MTIA_hooks =
+        MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}).release();
+    if (!MTIA_hooks) {
+      MTIA_hooks = new MTIAHooksInterface();
     }
   });
-  return *mtia_hooks;
+  return *MTIA_hooks;
 }
-
-bool isMTIAHooksBuilt() {
-  return MTIAHooksRegistry()->Has("MTIAHooks");
-}
-
 } // namespace detail
 
 C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)
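`getMTIAHooks()` keeps the usual hooks pattern: consult the registry once, and fall back to the base interface (whose methods now throw) when no extension registered itself. A rough Python analogue of that lookup, with hypothetical names:

```python
_HOOKS_REGISTRY = {}  # key -> factory; stands in for MTIAHooksRegistry

class MTIAHooksBase:
    """Fallback used when no MTIA extension is loaded."""
    def has_mtia(self) -> bool:
        return False

_hooks = None

def get_mtia_hooks() -> MTIAHooksBase:
    # Equivalent of c10::call_once: build on first use, cache forever.
    global _hooks
    if _hooks is None:
        factory = _HOOKS_REGISTRY.get("MTIAHooks")
        _hooks = factory() if factory is not None else MTIAHooksBase()
    return _hooks
```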
diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h
index 1da1bda4e613..c843ca52c2b4 100644
--- a/aten/src/ATen/detail/MTIAHooksInterface.h
+++ b/aten/src/ATen/detail/MTIAHooksInterface.h
@@ -1,9 +1,7 @@
 #pragma once
 
-#include
 #include
-#include
 #include
 #include
 
@@ -22,72 +20,33 @@ constexpr const char* MTIA_HELP =
     "to use some MTIA's functionality without MTIA extension included.";
 
 struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
-// this fails the implementation if MTIAHooks functions are called, but
-// MTIA backend is not present.
-#define FAIL_MTIAHOOKS_FUNC(func) \
-  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
-
   virtual ~MTIAHooksInterface() override = default;
 
   virtual void initMTIA() const {
-    // Avoid logging here, since MTIA needs init devices first then it will know
-    // how many devices are available. Make it as no-op if mtia extension is not
-    // dynamically loaded.
-    return;
+    TORCH_CHECK(
+        false,
+        "Cannot initialize MTIA without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
 
   virtual bool hasMTIA() const {
     return false;
   }
 
-  virtual DeviceIndex deviceCount() const override {
-    return 0;
-  }
-
-  virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
-
   virtual std::string showConfig() const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
+    TORCH_CHECK(
+        false,
+        "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
 
   virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    return false;
+    TORCH_CHECK(
+        false,
+        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
 
-  virtual void setCurrentDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
-
-  virtual DeviceIndex getCurrentDevice() const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual DeviceIndex exchangeDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual c10::Stream getCurrentStream(DeviceIndex device) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
-  }
-
-  virtual c10::Stream getDefaultStream(DeviceIndex device) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
-  }
-
-  virtual void setCurrentStream(const c10::Stream& stream) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
 };
 
 struct TORCH_API MTIAHooksArgs {};
 
@@ -98,6 +57,5 @@ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
 
 namespace detail {
 TORCH_API const MTIAHooksInterface& getMTIAHooks();
-TORCH_API bool isMTIAHooksBuilt();
 } // namespace detail
 } // namespace at
diff --git a/build_variables.bzl b/build_variables.bzl
index 36e54ffda40f..c7bddeaa3bbb 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -822,7 +822,6 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
     "torch/csrc/mps/Module.cpp",
-    "torch/csrc/mtia/Module.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
diff --git a/docs/source/index.rst b/docs/source/index.rst
index a7afe60bc287..9e7cc6a9a6dd 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -69,7 +69,6 @@ Features described in this documentation are classified by release status:
    torch.cuda.memory
    mps
    xpu
-   mtia
    meta
    torch.backends
    torch.export
diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst
deleted file mode 100644
index f2f5b5195dcb..000000000000
--- a/docs/source/mtia.rst
+++ /dev/null
@@ -1,34 +0,0 @@
-torch.mtia
-===================================
-
-The MTIA backend is implemented out of the tree, only interfaces are defined here.
-
-.. automodule:: torch.mtia
-.. currentmodule:: torch.mtia
-
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    StreamContext
-    current_device
-    current_stream
-    default_stream
-    device_count
-    init
-    is_available
-    is_initialized
-    set_stream
-    stream
-    synchronize
-    device
-    DeferredMtiaCallError
-
-Streams and events
-------------------
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    Event
-    Stream
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 32bcadc15452..b65a7a523983 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -684,7 +684,6 @@ Utilities
     set_float32_matmul_precision
     get_float32_matmul_precision
     set_warn_always
-    get_device_module
     is_warn_always_enabled
     vmap
     _assert
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 583bd384ed11..4d3f2b64ff7f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1700,24 +1700,6 @@ _TensorBase = TensorBase
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...
 
-# Defined in torch/csrc/Module.cpp
-def _accelerator_hooks_device_count() -> _int: ...
-def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
-def _accelerator_hooks_get_current_device() -> _int: ...
-def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
-def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
-def _get_accelerator(check: _bool = False) -> _device: ...
-
-# Defined in torch/csrc/mtia/Module.cpp
-def _mtia_init() -> None: ...
-def _mtia_isBuilt() -> _bool: ...
-def _mtia_isInBadFork() -> _bool: ...
-def _mtia_deviceSynchronize() -> None: ...
-def _mtia_getCurrentStream(device: _int) -> Stream: ...
-def _mtia_setCurrentStream(stream: Stream) -> None: ...
-def _mtia_getDefaultStream(device: _int) -> Stream: ...
-
-
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
 def _mps_get_default_generator() -> Generator: ...
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index 7e503a8e90ea..92b21f96dff6 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -23,7 +23,6 @@ class DeviceType(Enum):
     FPGA = ...
     ORT = ...
     XLA = ...
-    MTIA = ...
     MPS = ...
    HPU = ...
     Meta = ...
diff --git a/torch/__init__.py b/torch/__init__.py
index 846038e35103..9a7249f22026 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -58,7 +58,6 @@ __all__ = [
     'SymBool', 'sym_not', 'unravel_index',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
     'export', 'autocast', 'cond', 'GradScaler',
-    'get_device_module',
 ]
 
 ################################################################################
@@ -1580,7 +1579,6 @@ from torch import cuda as cuda
 from torch import cpu as cpu
 from torch import mps as mps
 from torch import xpu as xpu
-from torch import mtia as mtia
 from torch import autograd as autograd
 from torch.autograd import (
     no_grad as no_grad,
@@ -2018,27 +2016,6 @@ else:
         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
 
-def get_device_module(device: Optional[Union[torch.device, str]] = None):
-    """
-    Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
-    If no device is given, return the module for the current accelerator or CPU if none is present.
-    """
-    if isinstance(device, torch.device):
-        device_module_name = device.type
-    elif isinstance(device, str):
-        device_module_name = torch.device(device).type
-    elif device is None:
-        # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
-        device_module_name = torch._C._get_accelerator().type
-    else:
-        raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None")
-    device_module = getattr(torch, device_module_name, None)
-    if device_module is None:
-        raise RuntimeError(
-            f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
-        )
-    return device_module
-
 
 def _constrain_as_value(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None):
     """
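For context, the removed `torch.get_device_module` resolved a device spec to its `torch.<backend>` module. Typical call sites looked like this (illustrative only; the API no longer exists after this revert):

```python
import torch

mod = torch.get_device_module("cuda")                  # -> torch.cuda
mod = torch.get_device_module(torch.device("xpu", 0))  # -> torch.xpu
mod = torch.get_device_module()  # current accelerator's module, or torch.cpu
if mod.is_available():
    print(mod.device_count())
```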
diff --git a/torch/_utils.py b/torch/_utils.py
index 43c6284d2414..7f9a1af43fe4 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -713,8 +713,6 @@ def _get_available_device_type():
         return "cuda"
     if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
         return "xpu"
-    if hasattr(torch, "mtia") and torch.mtia.is_available():
-        return "mtia"
     custom_backend_name = torch._C._get_privateuse1_backend_name()
     custom_device_mod = getattr(torch, custom_backend_name, None)
     if custom_device_mod and custom_device_mod.is_available():
@@ -729,8 +727,6 @@ def _get_device_attr(get_member):
         return get_member(torch.cuda)
     if device_type and device_type.lower() == "xpu":
         return get_member(torch.xpu)  # type: ignore[attr-defined]
-    if device_type and device_type.lower() == "mtia":
-        return get_member(torch.mtia)
     if device_type == torch._C._get_privateuse1_backend_name():
         return get_member(getattr(torch, device_type))
     # add more available device types here
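After this change, the probe order in `_get_available_device_type` is CUDA, then XPU, then a registered PrivateUse1 backend. Sketched in plain Python (a paraphrase of the function above, not a new API):

```python
import torch

def available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    # A PrivateUse1 backend registers itself under a custom name.
    name = torch._C._get_privateuse1_backend_name()
    mod = getattr(torch, name, None)
    if mod is not None and mod.is_available():
        return name
    return None  # CPU only
```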
"_accelerator_hooks_maybe_exchange_device", - [](c10::DeviceIndex device_index) { - auto device_type = at::getAccelerator(); - if (device_type.has_value()) { - return at::globalContext() - .getAcceleratorHooksInterface(device_type.value()) - .maybeExchangeDevice(device_index); - } - return c10::DeviceIndex(-1); - }); - - py_module.def( - "_get_accelerator", - [](c10::optional check = c10::nullopt) { - return c10::Device( - at::getAccelerator(check.value_or(false)) - .value_or(c10::DeviceType::CPU), - -1); - }, - py::arg("check") = nullptr); - py_module.def( "_construct_storage_from_data_pointer", [](int64_t data_ptr, c10::Device device, size_t size_bytes) { diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp deleted file mode 100644 index 84cc11f71875..000000000000 --- a/torch/csrc/mtia/Module.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#ifndef WIN32 -#include -#endif - -namespace torch { -namespace mtia { - -static bool in_bad_fork = false; // True for children forked after mtia init - -#ifndef WIN32 -// Called in the forked child if mtia has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kMTIA, true); -} -#endif - -// Should be called before the first mtia call. -// Note: This is distinct from initExtension because a stub mtia implementation -// has some working functions (e.g. device_count) but cannot fully initialize. -static void poison_fork() { -#ifndef WIN32 - static c10::once_flag flag; - c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); -#endif -} - -void initModule(PyObject* module) { - auto m = py::handle(module).cast(); - - m.def("_mtia_init", []() { - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); - at::globalContext().lazyInitMTIA(); - }); - - m.def("_mtia_isBuilt", []() { - // Check if the MTIAHooks class has been registered with the registry. 
diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp
deleted file mode 100644
index 84cc11f71875..000000000000
--- a/torch/csrc/mtia/Module.cpp
+++ /dev/null
@@ -1,81 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-#ifndef WIN32
-#include
-#endif
-
-namespace torch {
-namespace mtia {
-
-static bool in_bad_fork = false; // True for children forked after mtia init
-
-#ifndef WIN32
-// Called in the forked child if mtia has already been initialized
-static void forked_child() {
-  in_bad_fork = true;
-  torch::utils::set_requires_device_init(at::kMTIA, true);
-}
-#endif
-
-// Should be called before the first mtia call.
-// Note: This is distinct from initExtension because a stub mtia implementation
-// has some working functions (e.g. device_count) but cannot fully initialize.
-static void poison_fork() {
-#ifndef WIN32
-  static c10::once_flag flag;
-  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
-#endif
-}
-
-void initModule(PyObject* module) {
-  auto m = py::handle(module).cast<py::module>();
-
-  m.def("_mtia_init", []() {
-    TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
-    poison_fork();
-    at::globalContext().lazyInitMTIA();
-  });
-
-  m.def("_mtia_isBuilt", []() {
-    // Check if the MTIAHooks class has been registered with the registry.
-    return at::detail::isMTIAHooksBuilt();
-  });
-
-  m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
-
-  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    return at::detail::getMTIAHooks().getCurrentStream(device_index);
-  });
-
-  m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    at::detail::getMTIAHooks().deviceSynchronize(
-        at::detail::getMTIAHooks().getCurrentDevice());
-  });
-
-  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    return at::detail::getMTIAHooks().getDefaultStream(device_index);
-  });
-
-  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    auto device = at::detail::getMTIAHooks().getCurrentDevice();
-    if (device != stream.device_index()) {
-      at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
-    }
-    at::detail::getMTIAHooks().setCurrentStream(stream);
-  });
-}
-
-} // namespace mtia
-} // namespace torch
diff --git a/torch/csrc/mtia/Module.h b/torch/csrc/mtia/Module.h
deleted file mode 100644
index 96a98ed448e1..000000000000
--- a/torch/csrc/mtia/Module.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#pragma once
-
-#include
-
-namespace torch {
-namespace mtia {
-
-// PyMethodDef* python_functions();
-void initModule(PyObject* module);
-
-} // namespace mtia
-} // namespace torch
diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h
index 1a4e7bb26fc0..36cb83659aa2 100644
--- a/torch/csrc/utils/pybind.h
+++ b/torch/csrc/utils/pybind.h
@@ -194,12 +194,6 @@ struct type_caster<c10::Stream> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream"));
 
-  // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream
-  // cannot be default-initialized, we provide this constructor to explicitly
-  // initialize that field. The value doesn't matter as it will be overwritten
-  // after a successful call to load.
-  type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {}
-
   bool load(handle src, bool) {
     PyObject* obj = src.ptr();
     if (THPStream_Check(obj)) {
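The `poison_fork`/`forked_child` mechanism in the deleted `torch/csrc/mtia/Module.cpp` above marks child processes as unusable once the device has been initialized in the parent. A rough Python equivalent of the same guard (hypothetical sketch, using `os.register_at_fork` in place of `pthread_atfork`):

```python
import os

_in_bad_fork = False

def _forked_child():
    # The child inherits device state it cannot safely use; flag it so a
    # later init attempt can raise a clear "use the 'spawn' start method" error.
    global _in_bad_fork
    _in_bad_fork = True

# Registered once, before the first device call (POSIX only).
os.register_at_fork(after_in_child=_forked_child)
```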
diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py
deleted file mode 100644
index 4007f0e584f2..000000000000
--- a/torch/mtia/__init__.py
+++ /dev/null
@@ -1,262 +0,0 @@
-r"""
-This package enables an interface for accessing the MTIA backend in Python
-"""
-
-import threading
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-
-from torch.types import Device
-
-from .. import device as _device, Tensor
-from .._utils import _dummy_type, _LazySeedTracker, classproperty
-from ._utils import _get_device_index
-
-_device_t = Union[_device, str, int, None]
-
-# torch.mtia.Event/Stream is alias of torch.Event/Stream
-Event = torch.Event
-Stream = torch.Stream
-
-_initialized = False
-_queued_calls: List[
-    Tuple[Callable[[], None], List[str]]
-] = []  # don't invoke these until initialization occurs
-_tls = threading.local()
-_initialization_lock = threading.Lock()
-_lazy_seed_tracker = _LazySeedTracker()
-
-
-def init():
-    _lazy_init()
-
-
-def is_initialized():
-    r"""Return whether PyTorch's MTIA state has been initialized."""
-    return _initialized and not _is_in_bad_fork()
-
-
-def _is_in_bad_fork() -> bool:
-    return torch._C._mtia_isInBadFork()
-
-
-def _lazy_init() -> None:
-    global _initialized, _queued_calls
-    if is_initialized() or hasattr(_tls, "is_initializing"):
-        return
-    with _initialization_lock:
-        # We be double-checked locking, boys! This is OK because
-        # the above test was GIL protected anyway. The inner test
-        # is for when a thread blocked on some other thread which was
-        # doing the initialization; when they get the lock, they will
-        # find there is nothing left to do.
-        if is_initialized():
-            return
-        # It is important to prevent other threads from entering _lazy_init
-        # immediately, while we are still guaranteed to have the GIL, because some
-        # of the C calls we make below will release the GIL
-        if _is_in_bad_fork():
-            raise RuntimeError(
-                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
-                "multiprocessing, you must use the 'spawn' start method"
-            )
-        if not _is_compiled():
-            raise AssertionError("Torch not compiled with MTIA enabled")
-
-        torch._C._mtia_init()
-        # Some of the queued calls may reentrantly call _lazy_init();
-        # we need to just return without initializing in that case.
-        # However, we must not let any *other* threads in!
-        _tls.is_initializing = True
-
-        for calls in _lazy_seed_tracker.get_calls():
-            if calls:
-                _queued_calls.append(calls)
-
-        try:
-            for queued_call, orig_traceback in _queued_calls:
-                try:
-                    queued_call()
-                except Exception as e:
-                    msg = (
-                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
-                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
-                    )
-                    raise DeferredMtiaCallError(msg) from e
-        finally:
-            delattr(_tls, "is_initializing")
-        _initialized = True
-
-
-class DeferredMtiaCallError(Exception):
-    pass
-
-
-def _is_compiled() -> bool:
-    r"""Return true if compiled with MTIA support."""
-    return torch._C._mtia_isBuilt()
-
-
-def is_available() -> bool:
-    r"""Return true if an MTIA device is available."""
-    if not _is_compiled():
-        return False
-    # MTIA has to init devices first to know if any devices are available.
-    return device_count() > 0
-
-
-def synchronize() -> None:
-    r"""Waits for all jobs in all streams on an MTIA device to complete."""
-    return torch._C._mtia_deviceSynchronize()
-
-
-def device_count() -> int:
-    r"""Return the number of MTIA devices available."""
-    return torch._C._accelerator_hooks_device_count()
-
-
-def current_device() -> int:
-    r"""Return the index of the currently selected device."""
-    return torch._C._accelerator_hooks_get_current_device()
-
-
-def current_stream(device: Optional[_device_t] = None) -> Stream:
-    r"""Return the currently selected :class:`Stream` for a given device.
-
-    Args:
-        device (torch.device or int, optional): selected device. Returns
-        the currently selected :class:`Stream` for the current device, given
-        by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
-        (default).
-    """
-    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))
-
-
-def default_stream(device: Optional[_device_t] = None) -> Stream:
-    r"""Return the default :class:`Stream` for a given device.
-
-    Args:
-        device (torch.device or int, optional): selected device. Returns
-        the default :class:`Stream` for the current device, given by
-        :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
-        (default).
-    """
-    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))
-
-
-def set_stream(stream: Stream):
-    r"""Set the current stream. This is a wrapper API to set the stream.
-        Usage of this function is discouraged in favor of the ``stream``
-        context manager.
-
-    Args:
-        stream (Stream): selected stream. This function is a no-op
-            if this argument is ``None``.
-    """
-    if stream is None:
-        return
-    torch._C._mtia_setCurrentStream(stream)
-
-
-class device:
-    r"""Context-manager that changes the selected device.
-
-    Args:
-        device (torch.device or int): device index to select. It's a no-op if
-            this argument is a negative integer or ``None``.
-    """
-
-    def __init__(self, device: Any):
-        self.idx = _get_device_index(device, optional=True)
-        self.prev_idx = -1
-
-    def __enter__(self):
-        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)
-
-    def __exit__(self, type: Any, value: Any, traceback: Any):
-        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
-        return False
-
-
-class StreamContext:
-    r"""Context-manager that selects a given stream.
-
-    All MTIA kernels queued within its context will be enqueued on a selected
-    stream.
-
-    Args:
-        Stream (Stream): selected stream. This manager is a no-op if it's
-            ``None``.
-    .. note:: Streams are per-device.
- """ - - cur_stream: Optional["torch.mtia.Stream"] - - def __init__(self, stream: Optional["torch.mtia.Stream"]): - self.stream = stream - self.idx = _get_device_index(None, True) - if not torch.jit.is_scripting(): - if self.idx is None: - self.idx = -1 - - self.src_prev_stream = ( - None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) - ) - self.dst_prev_stream = ( - None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) - ) - - def __enter__(self): - # Local cur_stream variable for type refinement - cur_stream = self.stream - # Return if stream is None or MTIA device not available - if cur_stream is None or self.idx == -1: - return - self.src_prev_stream = torch.mtia.current_stream(None) - - # If the stream is not on the current device, then - # set the current stream on the device - if self.src_prev_stream.device != cur_stream.device: - with device(cur_stream.device): - self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device) - torch.mtia.set_stream(cur_stream) - - def __exit__(self, type: Any, value: Any, traceback: Any): - # Local cur_stream variable for type refinement - cur_stream = self.stream - # If stream is None or no MTIA device available, return - if cur_stream is None or self.idx == -1: - return - - # Reset the stream on the original device - # and destination device - if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] - torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type] - torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] - - -def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: - r"""Wrap around the Context-manager StreamContext that selects a given stream. - - Arguments: - stream (Stream): selected stream. This manager is a no-op if it's - ``None``. - ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream - """ - return StreamContext(stream) - - -__all__ = [ - "init", - "is_available", - "is_initialized", - "synchronize", - "device_count", - "current_device", - "current_stream", - "default_stream", - "set_stream", - "stream", - "device", -] diff --git a/torch/mtia/_utils.py b/torch/mtia/_utils.py deleted file mode 100644 index 090e26f32123..000000000000 --- a/torch/mtia/_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Any - -import torch - -# The _get_device_index has been moved to torch.utils._get_device_index -from torch._utils import _get_device_index as _torch_get_device_index - - -def _get_device_index( - device: Any, optional: bool = False, allow_cpu: bool = False -) -> int: - r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. - - If :attr:`device` is a torch.device object, returns the device index if it - is a MTIA device. Note that for a MTIA device without a specified index, - i.e., ``torch.device('mtia')``, this will return the current default MTIA - device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, - CPU devices will be accepted and ``-1`` will be returned in this case. - - If :attr:`device` is a Python integer, it is returned as is. - - If :attr:`device` is ``None``, this will return the current default MTIA - device if :attr:`optional` is ``True``. 
- """ - if isinstance(device, int): - return device - if isinstance(device, str): - device = torch.device(device) - if isinstance(device, torch.device): - if allow_cpu: - if device.type not in ["mtia", "cpu"]: - raise ValueError(f"Expected a mtia or cpu device, but got: {device}") - elif device.type != "mtia": - raise ValueError(f"Expected a mtia device, but got: {device}") - if not torch.jit.is_scripting(): - if isinstance(device, torch.mtia.device): - return device.idx - return _torch_get_device_index(device, optional, allow_cpu) diff --git a/torch/overrides.py b/torch/overrides.py index 4ce254880019..6a5d3e891dc8 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -281,7 +281,6 @@ def get_ignored_functions() -> Set[Callable]: torch.use_deterministic_algorithms, torch.is_deterministic_algorithms_warn_only_enabled, torch.set_deterministic_debug_mode, - torch.get_device_module, torch.get_deterministic_debug_mode, torch.set_float32_matmul_precision, torch.get_float32_matmul_precision,