Revert "torch.mtia module for MTIA device backend (#123612)"
This reverts commit d7e1bf9ff908d2a9c20d5354426d34c539fcb7a1. Reverted https://github.com/pytorch/pytorch/pull/123612 on behalf of https://github.com/jeffdaily due to This broke ROCm. see test_overrides.py ([comment](https://github.com/pytorch/pytorch/pull/123611#issuecomment-2067363780))
@@ -68,8 +68,6 @@ class TORCH_API Context {
       return at::detail::getMPSHooks();
     } else if (device_type == at::kPrivateUse1) {
       return at::detail::getPrivateUse1Hooks();
-    } else if (device_type == at::kMTIA) {
-      return at::detail::getMTIAHooks();
     } else {
       AT_ERROR(
           c10::DeviceTypeName(device_type), " device type not an accelerator.");
@@ -154,9 +152,6 @@ class TORCH_API Context {
   void lazyInitXPU() {
     c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
   }
-  void lazyInitMTIA() {
-    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
-  }
   void lazyInitPrivateUse1() {
     c10::call_once(thp_init, [&] {
       if (isPrivateUse1HooksRegistered()) {
@@ -347,7 +342,6 @@ class TORCH_API Context {
   c10::once_flag thc_init;
   c10::once_flag thh_init;
   c10::once_flag thx_init;
-  c10::once_flag th_mtia_init;
   c10::once_flag thp_init;
   bool enabled_cudnn = true;
   bool deterministic_cudnn = false;
@@ -10,9 +10,6 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
 #define CHECK_NO_PU1 \
   TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
 
-#define CHECK_NO_MTIA \
-  TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices");
-
   if (is_privateuse1_backend_registered()) {
     // We explicitly allow PrivateUse1 and another device at the same time
     // as we use this for testing.
@@ -20,12 +17,7 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
     return kPrivateUse1;
   } else if (at::hasCUDA()) {
     CHECK_NO_PU1
-    CHECK_NO_MTIA
     return kCUDA;
-  } else if (at::hasMTIA()) {
-    CHECK_NO_CUDA
-    CHECK_NO_PU1
-    return kMTIA;
   } else {
     TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
     return std::nullopt;
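For orientation: the selection order above is PrivateUse1, then CUDA, then MTIA, with the `CHECK_NO_*` macros enforcing mutual exclusion. A minimal sketch of how this surfaces in Python through the `_get_accelerator` binding that this revert also removes (see the Module.cpp hunk further down); behavior as in the pre-revert tree:

```python
import torch

# Sketch against the pre-revert private binding removed later in this diff.
dev = torch._C._get_accelerator()      # torch.device for the detected
print(dev.type)                        # accelerator: "cuda", "mtia", ...;
                                       # falls back to "cpu" when none exists
torch._C._get_accelerator(check=True)  # raises instead of falling back
```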
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <c10/core/Device.h>
-#include <c10/core/Stream.h>
 
 namespace at {
 
 // AcceleratorHooksInterface is a shared interface provided by all
@@ -16,29 +16,6 @@ struct TORCH_API AcceleratorHooksInterface {
 
   // Whether the device at device_index is fully initialized or not.
   virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
-
-  virtual DeviceIndex deviceCount() const {
-    return 0;
-  }
-
-  virtual void setCurrentDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()");
-  }
-
-  virtual DeviceIndex getCurrentDevice() const {
-    TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()");
-    return -1;
-  }
-
-  virtual DeviceIndex exchangeDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support exchangeDevice()");
-    return -1;
-  }
-
-  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const {
-    TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()");
-    return -1;
-  }
 };
 
 } // namespace at
@@ -8,22 +8,19 @@
 namespace at {
 namespace detail {
 
-const MTIAHooksInterface& getMTIAHooks() {
-  static std::unique_ptr<MTIAHooksInterface> mtia_hooks = nullptr;
+const MTIAHooksInterface &getMTIAHooks() {
+  static MTIAHooksInterface* MTIA_hooks = nullptr;
   static c10::once_flag once;
   c10::call_once(once, [] {
-    mtia_hooks = MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{});
-    if (!mtia_hooks) {
-      mtia_hooks = std::make_unique<MTIAHooksInterface>();
+    MTIA_hooks =
+        MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}).release();
+    if (!MTIA_hooks) {
+      MTIA_hooks = new MTIAHooksInterface();
     }
   });
-  return *mtia_hooks;
+  return *MTIA_hooks;
 }
 
-bool isMTIAHooksBuilt() {
-  return MTIAHooksRegistry()->Has("MTIAHooks");
-}
-
 } // namespace detail
 
 C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)
@@ -1,9 +1,7 @@
 #pragma once
 
 #include <c10/core/Device.h>
 #include <c10/util/Exception.h>
-
-#include <c10/core/Stream.h>
 #include <c10/util/Registry.h>
 
 #include <ATen/detail/AcceleratorHooksInterface.h>
@@ -22,72 +20,33 @@ constexpr const char* MTIA_HELP =
     "to use some MTIA's functionality without MTIA extension included.";
 
 struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
-// this fails the implementation if MTIAHooks functions are called, but
-// MTIA backend is not present.
-#define FAIL_MTIAHOOKS_FUNC(func) \
-  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
-
   virtual ~MTIAHooksInterface() override = default;
 
   virtual void initMTIA() const {
-    // Avoid logging here, since MTIA needs init devices first then it will know
-    // how many devices are available. Make it as no-op if mtia extension is not
-    // dynamically loaded.
-    return;
+    TORCH_CHECK(
+        false,
+        "Cannot initialize MTIA without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
 
   virtual bool hasMTIA() const {
     return false;
   }
 
-  virtual DeviceIndex deviceCount() const override {
-    return 0;
-  }
-
-  virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
-
   virtual std::string showConfig() const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
+    TORCH_CHECK(
+        false,
+        "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
 
   virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    return false;
+    TORCH_CHECK(
+        false,
+        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
+        MTIA_HELP);
   }
-
-  virtual void setCurrentDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
-
-  virtual DeviceIndex getCurrentDevice() const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual DeviceIndex exchangeDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return -1;
-  }
-
-  virtual c10::Stream getCurrentStream(DeviceIndex device) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
-  }
-
-  virtual c10::Stream getDefaultStream(DeviceIndex device) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
-  }
-
-  virtual void setCurrentStream(const c10::Stream& stream) const {
-    FAIL_MTIAHOOKS_FUNC(__func__);
-  }
 };
 
 struct TORCH_API MTIAHooksArgs {};
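What these stubs mean in practice (a sketch, assuming the pre-revert tree with no MTIA extension loaded): methods guarded by `FAIL_MTIAHOOKS_FUNC` raise once reached, so user-facing calls fail with an explicit message rather than crashing:

```python
import torch

try:
    torch.mtia.synchronize()  # routes to MTIAHooksInterface::deviceSynchronize
except Exception as e:
    # Either "Torch not compiled with MTIA enabled" from the Python lazy init,
    # or "Cannot execute deviceSynchronize() without MTIA backend." from the
    # C++ stub, depending on how far initialization gets.
    print(e)
```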
@@ -98,6 +57,5 @@ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
 
 namespace detail {
 TORCH_API const MTIAHooksInterface& getMTIAHooks();
-TORCH_API bool isMTIAHooksBuilt();
 } // namespace detail
 } // namespace at
@@ -822,7 +822,6 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
     "torch/csrc/mps/Module.cpp",
-    "torch/csrc/mtia/Module.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
@@ -69,7 +69,6 @@ Features described in this documentation are classified by release status:
    torch.cuda.memory <torch_cuda_memory>
    mps
    xpu
-   mtia
    meta
    torch.backends <backends>
    torch.export <export>
@@ -1,34 +0,0 @@
-torch.mtia
-===================================
-
-The MTIA backend is implemented out of tree; only the interfaces are defined here.
-
-.. automodule:: torch.mtia
-.. currentmodule:: torch.mtia
-
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    StreamContext
-    current_device
-    current_stream
-    default_stream
-    device_count
-    init
-    is_available
-    is_initialized
-    set_stream
-    stream
-    synchronize
-    device
-    DeferredMtiaCallError
-
-Streams and events
-------------------
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    Event
-    Stream
-
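The documented surface deliberately mirrors `torch.cuda`. A usage sketch of the API listed above (hypothetical, assuming an MTIA-capable build of the pre-revert tree, which stock wheels do not provide):

```python
import torch

if torch.mtia.is_available():
    torch.mtia.init()                # optional; calls initialize lazily
    print(torch.mtia.device_count(), torch.mtia.current_device())
    s = torch.mtia.default_stream()  # torch.mtia.Stream aliases torch.Stream
    with torch.mtia.device(0):       # select device 0 for this block
        with torch.mtia.stream(s):   # enqueue kernels on stream s
            ...
    torch.mtia.synchronize()         # wait for all queued work
```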
@@ -684,7 +684,6 @@ Utilities
     set_float32_matmul_precision
     get_float32_matmul_precision
     set_warn_always
-    get_device_module
     is_warn_always_enabled
     vmap
     _assert
@@ -1700,24 +1700,6 @@ _TensorBase = TensorBase
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...
-
-# Defined in torch/csrc/Module.cpp
-def _accelerator_hooks_device_count() -> _int: ...
-def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
-def _accelerator_hooks_get_current_device() -> _int: ...
-def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
-def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
-def _get_accelerator(check: _bool = False) -> _device: ...
-
-# Defined in torch/csrc/mtia/Module.cpp
-def _mtia_init() -> None: ...
-def _mtia_isBuilt() -> _bool: ...
-def _mtia_isInBadFork() -> _bool: ...
-def _mtia_deviceSynchronize() -> None: ...
-def _mtia_getCurrentStream(device: _int) -> Stream: ...
-def _mtia_setCurrentStream(stream: Stream) -> None: ...
-def _mtia_getDefaultStream(device: _int) -> Stream: ...
-
 
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
 def _mps_get_default_generator() -> Generator: ...
@@ -23,7 +23,6 @@ class DeviceType(Enum):
     FPGA = ...
     ORT = ...
     XLA = ...
-    MTIA = ...
     MPS = ...
     HPU = ...
     Meta = ...
@@ -58,7 +58,6 @@ __all__ = [
     'SymBool', 'sym_not', 'unravel_index',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
     'export', 'autocast', 'cond', 'GradScaler',
-    'get_device_module',
 ]
 
 ################################################################################
@@ -1580,7 +1579,6 @@ from torch import cuda as cuda
 from torch import cpu as cpu
 from torch import mps as mps
 from torch import xpu as xpu
-from torch import mtia as mtia
 from torch import autograd as autograd
 from torch.autograd import (
     no_grad as no_grad,
@@ -2018,27 +2016,6 @@ else:
 
     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
-def get_device_module(device: Optional[Union[torch.device, str]] = None):
-    """
-    Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
-    If no device is given, return the module for the current accelerator or CPU if none is present.
-    """
-    if isinstance(device, torch.device):
-        device_module_name = device.type
-    elif isinstance(device, str):
-        device_module_name = torch.device(device).type
-    elif device is None:
-        # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
-        device_module_name = torch._C._get_accelerator().type
-    else:
-        raise RuntimeError(f"Invalid value of device '{device}', expected torch.device, str, or None")
-    device_module = getattr(torch, device_module_name, None)
-    if device_module is None:
-        raise RuntimeError(
-            f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
-        )
-    return device_module
-
 
 def _constrain_as_value(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None):
     """
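For context, a short sketch of how the removed helper behaved in the pre-revert tree:

```python
import torch

torch.get_device_module("cuda")               # -> torch.cuda
torch.get_device_module(torch.device("cpu"))  # -> torch.cpu
torch.get_device_module()                     # module for the current
                                              # accelerator, else torch.cpu
```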
@@ -713,8 +713,6 @@ def _get_available_device_type():
         return "cuda"
     if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
         return "xpu"
-    if hasattr(torch, "mtia") and torch.mtia.is_available():
-        return "mtia"
     custom_backend_name = torch._C._get_privateuse1_backend_name()
     custom_device_mod = getattr(torch, custom_backend_name, None)
     if custom_device_mod and custom_device_mod.is_available():
@@ -729,8 +727,6 @@ def _get_device_attr(get_member):
         return get_member(torch.cuda)
     if device_type and device_type.lower() == "xpu":
         return get_member(torch.xpu)  # type: ignore[attr-defined]
-    if device_type and device_type.lower() == "mtia":
-        return get_member(torch.mtia)
     if device_type == torch._C._get_privateuse1_backend_name():
         return get_member(getattr(torch, device_type))
     # add more available device types here
@@ -1,4 +1,3 @@
-#include <ATen/DeviceAccelerator.h>
 #include <c10/util/Optional.h>
 #include <fmt/core.h>
 #include <sys/types.h>
@@ -16,12 +15,10 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/native/Normalization.h>
-#include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/util/AbortHandler.h>
 #include <c10/util/Backtrace.h>
@@ -74,7 +71,6 @@
 #include <torch/csrc/lazy/python/init.h>
 #include <torch/csrc/monitor/python_init.h>
 #include <torch/csrc/mps/Module.h>
-#include <torch/csrc/mtia/Module.h>
 #include <torch/csrc/multiprocessing/init.h>
 #include <torch/csrc/onnx/init.h>
 #include <torch/csrc/profiler/python/init.h>
@@ -1644,7 +1640,6 @@ PyObject* initModule() {
 #ifdef USE_XPU
   torch::xpu::initModule(module);
 #endif
-  torch::mtia::initModule(module);
   torch::cpu::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
@@ -1940,70 +1935,6 @@ Call this whenever a new thread is created in order to propagate values from
         return at::globalContext().linalgPreferredBackend();
       });
 
-  py_module.def("_accelerator_hooks_device_count", []() {
-    auto device_type = at::getAccelerator();
-    if (device_type.has_value()) {
-      return at::globalContext()
-          .getAcceleratorHooksInterface(device_type.value())
-          .deviceCount();
-    }
-    return c10::DeviceIndex(-1);
-  });
-
-  py_module.def(
-      "_accelerator_hooks_set_current_device",
-      [](c10::DeviceIndex device_index) {
-        auto device_type = at::getAccelerator();
-        if (device_type.has_value()) {
-          at::globalContext()
-              .getAcceleratorHooksInterface(device_type.value())
-              .setCurrentDevice(device_index);
-        }
-      });
-
-  py_module.def("_accelerator_hooks_get_current_device", []() {
-    auto device_type = at::getAccelerator();
-    if (device_type.has_value()) {
-      return at::globalContext()
-          .getAcceleratorHooksInterface(device_type.value())
-          .getCurrentDevice();
-    }
-    return c10::DeviceIndex(-1);
-  });
-
-  py_module.def(
-      "_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) {
-        auto device_type = at::getAccelerator();
-        if (device_type.has_value()) {
-          return at::globalContext()
-              .getAcceleratorHooksInterface(device_type.value())
-              .exchangeDevice(device_index);
-        }
-        return c10::DeviceIndex(-1);
-      });
-
-  py_module.def(
-      "_accelerator_hooks_maybe_exchange_device",
-      [](c10::DeviceIndex device_index) {
-        auto device_type = at::getAccelerator();
-        if (device_type.has_value()) {
-          return at::globalContext()
-              .getAcceleratorHooksInterface(device_type.value())
-              .maybeExchangeDevice(device_index);
-        }
-        return c10::DeviceIndex(-1);
-      });
-
-  py_module.def(
-      "_get_accelerator",
-      [](c10::optional<bool> check = c10::nullopt) {
-        return c10::Device(
-            at::getAccelerator(check.value_or(false))
-                .value_or(c10::DeviceType::CPU),
-            -1);
-      },
-      py::arg("check") = nullptr);
-
   py_module.def(
       "_construct_storage_from_data_pointer",
       [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
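These bindings back the generic device-switching helpers in the Python module removed below; roughly (a sketch against the pre-revert tree):

```python
import torch

# torch.mtia.device (deleted below) is built on this exchange pattern:
prev = torch._C._accelerator_hooks_maybe_exchange_device(0)  # enter device 0
try:
    ...  # run work on device 0
finally:
    torch._C._accelerator_hooks_maybe_exchange_device(prev)  # restore previous
```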
@@ -1,81 +0,0 @@
-#include <ATen/ATen.h>
-#include <c10/util/CallOnce.h>
-#include <torch/csrc/Generator.h>
-#include <torch/csrc/Stream.h>
-#include <torch/csrc/python_headers.h>
-#include <torch/csrc/utils/device_lazy_init.h>
-#include <torch/csrc/utils/pybind.h>
-
-#include <c10/core/DeviceType.h>
-#include <c10/core/Stream.h>
-#ifndef WIN32
-#include <pthread.h>
-#endif
-
-namespace torch {
-namespace mtia {
-
-static bool in_bad_fork = false; // True for children forked after mtia init
-
-#ifndef WIN32
-// Called in the forked child if mtia has already been initialized
-static void forked_child() {
-  in_bad_fork = true;
-  torch::utils::set_requires_device_init(at::kMTIA, true);
-}
-#endif
-
-// Should be called before the first mtia call.
-// Note: This is distinct from initExtension because a stub mtia implementation
-// has some working functions (e.g. device_count) but cannot fully initialize.
-static void poison_fork() {
-#ifndef WIN32
-  static c10::once_flag flag;
-  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
-#endif
-}
-
-void initModule(PyObject* module) {
-  auto m = py::handle(module).cast<py::module>();
-
-  m.def("_mtia_init", []() {
-    TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
-    poison_fork();
-    at::globalContext().lazyInitMTIA();
-  });
-
-  m.def("_mtia_isBuilt", []() {
-    // Check if the MTIAHooks class has been registered with the registry.
-    return at::detail::isMTIAHooksBuilt();
-  });
-
-  m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
-
-  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    return at::detail::getMTIAHooks().getCurrentStream(device_index);
-  });
-
-  m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    at::detail::getMTIAHooks().deviceSynchronize(
-        at::detail::getMTIAHooks().getCurrentDevice());
-  });
-
-  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    return at::detail::getMTIAHooks().getDefaultStream(device_index);
-  });
-
-  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
-    torch::utils::device_lazy_init(at::kMTIA);
-    auto device = at::detail::getMTIAHooks().getCurrentDevice();
-    if (device != stream.device_index()) {
-      at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
-    }
-    at::detail::getMTIAHooks().setCurrentStream(stream);
-  });
-}
-
-} // namespace mtia
-} // namespace torch
@@ -1,12 +0,0 @@
-#pragma once
-
-#include <torch/csrc/python_headers.h>
-
-namespace torch {
-namespace mtia {
-
-// PyMethodDef* python_functions();
-void initModule(PyObject* module);
-
-} // namespace mtia
-} // namespace torch
@@ -194,12 +194,6 @@ struct type_caster<c10::Stream> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream"));
 
-  // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream
-  // cannot be default-initialized, we provide this constructor to explicitly
-  // initialize that field. The value doesn't matter as it will be overwritten
-  // after a successful call to load.
-  type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {}
-
   bool load(handle src, bool) {
     PyObject* obj = src.ptr();
     if (THPStream_Check(obj)) {
@@ -1,262 +0,0 @@
-r"""
-This package enables an interface for accessing the MTIA backend in Python.
-"""
-
-import threading
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-
-from torch.types import Device
-
-from .. import device as _device, Tensor
-from .._utils import _dummy_type, _LazySeedTracker, classproperty
-from ._utils import _get_device_index
-
-_device_t = Union[_device, str, int, None]
-
-# torch.mtia.Event/Stream are aliases of torch.Event/Stream
-Event = torch.Event
-Stream = torch.Stream
-
-_initialized = False
-_queued_calls: List[
-    Tuple[Callable[[], None], List[str]]
-] = []  # don't invoke these until initialization occurs
-_tls = threading.local()
-_initialization_lock = threading.Lock()
-_lazy_seed_tracker = _LazySeedTracker()
-
-
-def init():
-    _lazy_init()
-
-
-def is_initialized():
-    r"""Return whether PyTorch's MTIA state has been initialized."""
-    return _initialized and not _is_in_bad_fork()
-
-
-def _is_in_bad_fork() -> bool:
-    return torch._C._mtia_isInBadFork()
-
-
-def _lazy_init() -> None:
-    global _initialized, _queued_calls
-    if is_initialized() or hasattr(_tls, "is_initializing"):
-        return
-    with _initialization_lock:
-        # We be double-checked locking, boys! This is OK because
-        # the above test was GIL protected anyway. The inner test
-        # is for when a thread blocked on some other thread which was
-        # doing the initialization; when they get the lock, they will
-        # find there is nothing left to do.
-        if is_initialized():
-            return
-        # It is important to prevent other threads from entering _lazy_init
-        # immediately, while we are still guaranteed to have the GIL, because some
-        # of the C calls we make below will release the GIL
-        if _is_in_bad_fork():
-            raise RuntimeError(
-                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
-                "multiprocessing, you must use the 'spawn' start method"
-            )
-        if not _is_compiled():
-            raise AssertionError("Torch not compiled with MTIA enabled")
-
-        torch._C._mtia_init()
-        # Some of the queued calls may reentrantly call _lazy_init();
-        # we need to just return without initializing in that case.
-        # However, we must not let any *other* threads in!
-        _tls.is_initializing = True
-
-        for calls in _lazy_seed_tracker.get_calls():
-            if calls:
-                _queued_calls.append(calls)
-
-        try:
-            for queued_call, orig_traceback in _queued_calls:
-                try:
-                    queued_call()
-                except Exception as e:
-                    msg = (
-                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
-                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
-                    )
-                    raise DeferredMtiaCallError(msg) from e
-        finally:
-            delattr(_tls, "is_initializing")
-        _initialized = True
-
-
-class DeferredMtiaCallError(Exception):
-    pass
-
-
-def _is_compiled() -> bool:
-    r"""Return true if compiled with MTIA support."""
-    return torch._C._mtia_isBuilt()
-
-
-def is_available() -> bool:
-    r"""Return true if an MTIA device is available."""
-    if not _is_compiled():
-        return False
-    # MTIA has to init devices first to know if there is any devices available.
-    return device_count() > 0
-
-
-def synchronize() -> None:
-    r"""Waits for all jobs in all streams on a MTIA device to complete."""
-    return torch._C._mtia_deviceSynchronize()
-
-
-def device_count() -> int:
-    r"""Return the number of MTIA devices available."""
-    return torch._C._accelerator_hooks_device_count()
-
-
-def current_device() -> int:
-    r"""Return the index of a currently selected device."""
-    return torch._C._accelerator_hooks_get_current_device()
-
-
-def current_stream(device: Optional[_device_t] = None) -> Stream:
-    r"""Return the currently selected :class:`Stream` for a given device.
-
-    Args:
-        device (torch.device or int, optional): selected device. Returns
-            the currently selected :class:`Stream` for the current device, given
-            by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
-            (default).
-    """
-    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))
-
-
-def default_stream(device: Optional[_device_t] = None) -> Stream:
-    r"""Return the default :class:`Stream` for a given device.
-
-    Args:
-        device (torch.device or int, optional): selected device. Returns
-            the default :class:`Stream` for the current device, given by
-            :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
-            (default).
-    """
-    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))
-
-
-def set_stream(stream: Stream):
-    r"""Set the current stream. This is a wrapper API to set the stream.
-        Usage of this function is discouraged in favor of the ``stream``
-        context manager.
-
-    Args:
-        stream (Stream): selected stream. This function is a no-op
-            if this argument is ``None``.
-    """
-    if stream is None:
-        return
-    torch._C._mtia_setCurrentStream(stream)
-
-
-class device:
-    r"""Context-manager that changes the selected device.
-
-    Args:
-        device (torch.device or int): device index to select. It's a no-op if
-            this argument is a negative integer or ``None``.
-    """
-
-    def __init__(self, device: Any):
-        self.idx = _get_device_index(device, optional=True)
-        self.prev_idx = -1
-
-    def __enter__(self):
-        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)
-
-    def __exit__(self, type: Any, value: Any, traceback: Any):
-        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
-        return False
-
-
-class StreamContext:
-    r"""Context-manager that selects a given stream.
-
-    All MTIA kernels queued within its context will be enqueued on a selected
-    stream.
-
-    Args:
-        Stream (Stream): selected stream. This manager is a no-op if it's
-            ``None``.
-    .. note:: Streams are per-device.
-    """
-
-    cur_stream: Optional["torch.mtia.Stream"]
-
-    def __init__(self, stream: Optional["torch.mtia.Stream"]):
-        self.stream = stream
-        self.idx = _get_device_index(None, True)
-        if not torch.jit.is_scripting():
-            if self.idx is None:
-                self.idx = -1
-
-        self.src_prev_stream = (
-            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
-        )
-        self.dst_prev_stream = (
-            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
-        )
-
-    def __enter__(self):
-        # Local cur_stream variable for type refinement
-        cur_stream = self.stream
-        # Return if stream is None or MTIA device not available
-        if cur_stream is None or self.idx == -1:
-            return
-        self.src_prev_stream = torch.mtia.current_stream(None)
-
-        # If the stream is not on the current device, then
-        # set the current stream on the device
-        if self.src_prev_stream.device != cur_stream.device:
-            with device(cur_stream.device):
-                self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
-        torch.mtia.set_stream(cur_stream)
-
-    def __exit__(self, type: Any, value: Any, traceback: Any):
-        # Local cur_stream variable for type refinement
-        cur_stream = self.stream
-        # If stream is None or no MTIA device available, return
-        if cur_stream is None or self.idx == -1:
-            return
-
-        # Reset the stream on the original device
-        # and destination device
-        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
-            torch.mtia.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
-        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
-
-
-def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
-    r"""Wrap around the Context-manager StreamContext that selects a given stream.
-
-    Arguments:
-        stream (Stream): selected stream. This manager is a no-op if it's
-            ``None``.
-    .. note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream
-    """
-    return StreamContext(stream)
-
-
-__all__ = [
-    "init",
-    "is_available",
-    "is_initialized",
-    "synchronize",
-    "device_count",
-    "current_device",
-    "current_stream",
-    "default_stream",
-    "set_stream",
-    "stream",
-    "device",
-]
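The `_lazy_init` above follows the standard double-checked locking idiom, the same pattern `torch.cuda` uses; distilled to its core (an illustration, not a PyTorch API):

```python
import threading

_lock = threading.Lock()
_initialized = False

def lazy_init(do_init):
    global _initialized
    if _initialized:        # fast path: no lock once initialized
        return
    with _lock:
        if _initialized:    # re-check: another thread may have won the race
            return
        do_init()           # exactly one thread runs the real initialization
        _initialized = True
```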
@@ -1,38 +0,0 @@
-from typing import Any
-
-import torch
-
-# The _get_device_index has been moved to torch.utils._get_device_index
-from torch._utils import _get_device_index as _torch_get_device_index
-
-
-def _get_device_index(
-    device: Any, optional: bool = False, allow_cpu: bool = False
-) -> int:
-    r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
-
-    If :attr:`device` is a torch.device object, returns the device index if it
-    is a MTIA device. Note that for a MTIA device without a specified index,
-    i.e., ``torch.device('mtia')``, this will return the current default MTIA
-    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
-    CPU devices will be accepted and ``-1`` will be returned in this case.
-
-    If :attr:`device` is a Python integer, it is returned as is.
-
-    If :attr:`device` is ``None``, this will return the current default MTIA
-    device if :attr:`optional` is ``True``.
-    """
-    if isinstance(device, int):
-        return device
-    if isinstance(device, str):
-        device = torch.device(device)
-    if isinstance(device, torch.device):
-        if allow_cpu:
-            if device.type not in ["mtia", "cpu"]:
-                raise ValueError(f"Expected a mtia or cpu device, but got: {device}")
-        elif device.type != "mtia":
-            raise ValueError(f"Expected a mtia device, but got: {device}")
-    if not torch.jit.is_scripting():
-        if isinstance(device, torch.mtia.device):
-            return device.idx
-    return _torch_get_device_index(device, optional, allow_cpu)
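A behavior sketch of the deleted helper (assuming the pre-revert module is importable), following its docstring:

```python
import torch
from torch.mtia._utils import _get_device_index

_get_device_index(3)                                    # 3: ints pass through
_get_device_index("mtia:1")                             # 1: parsed via torch.device
_get_device_index(torch.device("cpu"), allow_cpu=True)  # -1 for CPU when allowed
_get_device_index(None, optional=True)                  # current default MTIA device
```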
@@ -281,7 +281,6 @@ def get_ignored_functions() -> Set[Callable]:
         torch.use_deterministic_algorithms,
         torch.is_deterministic_algorithms_warn_only_enabled,
         torch.set_deterministic_debug_mode,
-        torch.get_device_module,
         torch.get_deterministic_debug_mode,
         torch.set_float32_matmul_precision,
         torch.get_float32_matmul_precision,
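This last hunk is plausibly where the cited `test_overrides.py` failure comes from: that test requires every public torch callable to either support `__torch_function__` or be explicitly ignored, so `get_device_module` has to leave the ignore list together with the function itself. A rough sketch of the invariant (hedged; the actual test is more thorough):

```python
from torch.overrides import get_ignored_functions, get_testing_overrides

covered = set(get_ignored_functions()) | set(get_testing_overrides())
# test_overrides.py (roughly) walks the public torch namespace and requires
# each callable to appear in `covered`; a stale entry for a deleted function,
# or a new function missing an entry, makes the test fail.
print(len(covered))
```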