Revert "torch.mtia module for MTIA device backend (#123612)"

This reverts commit d7e1bf9ff908d2a9c20d5354426d34c539fcb7a1.

Reverted https://github.com/pytorch/pytorch/pull/123612 on behalf of https://github.com/jeffdaily because it broke ROCm; see test_overrides.py ([comment](https://github.com/pytorch/pytorch/pull/123611#issuecomment-2067363780))
Author: PyTorch MergeBot
Date:   2024-04-19 22:44:26 +00:00
parent 52da03edeb
commit 929242a15c
20 changed files with 21 additions and 655 deletions

View File

@@ -68,8 +68,6 @@ class TORCH_API Context {
return at::detail::getMPSHooks();
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks();
} else if (device_type == at::kMTIA) {
return at::detail::getMTIAHooks();
} else {
AT_ERROR(
c10::DeviceTypeName(device_type), " device type not an accelerator.");
@@ -154,9 +152,6 @@ class TORCH_API Context {
void lazyInitXPU() {
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
}
void lazyInitMTIA() {
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
}
void lazyInitPrivateUse1() {
c10::call_once(thp_init, [&] {
if (isPrivateUse1HooksRegistered()) {
@@ -347,7 +342,6 @@ class TORCH_API Context {
c10::once_flag thc_init;
c10::once_flag thh_init;
c10::once_flag thx_init;
c10::once_flag th_mtia_init;
c10::once_flag thp_init;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;

View File

@@ -10,9 +10,6 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
#define CHECK_NO_PU1 \
TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
#define CHECK_NO_MTIA \
TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices");
if (is_privateuse1_backend_registered()) {
// We explicitly allow PrivateUse1 and another device at the same time
// as we use this for testing.
@@ -20,12 +17,7 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
return kPrivateUse1;
} else if (at::hasCUDA()) {
CHECK_NO_PU1
CHECK_NO_MTIA
return kCUDA;
} else if (at::hasMTIA()) {
CHECK_NO_CUDA
CHECK_NO_PU1
return kMTIA;
} else {
TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
return std::nullopt;

View File

@@ -1,7 +1,7 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
namespace at {
// AcceleratorHooksInterface is a shared interface provided by all
@@ -16,29 +16,6 @@ struct TORCH_API AcceleratorHooksInterface {
// Whether the device at device_index is fully initialized or not.
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
virtual DeviceIndex deviceCount() const {
return 0;
}
virtual void setCurrentDevice(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()");
}
virtual DeviceIndex getCurrentDevice() const {
TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()");
return -1;
}
virtual DeviceIndex exchangeDevice(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support exchangeDevice()");
return -1;
}
virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const {
TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()");
return -1;
}
};
} // namespace at

View File

@@ -8,22 +8,19 @@
namespace at {
namespace detail {
const MTIAHooksInterface& getMTIAHooks() {
static std::unique_ptr<MTIAHooksInterface> mtia_hooks = nullptr;
const MTIAHooksInterface &getMTIAHooks() {
static MTIAHooksInterface* MTIA_hooks = nullptr;
static c10::once_flag once;
c10::call_once(once, [] {
mtia_hooks = MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{});
if (!mtia_hooks) {
mtia_hooks = std::make_unique<MTIAHooksInterface>();
MTIA_hooks =
MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}).release();
if (!MTIA_hooks) {
MTIA_hooks = new MTIAHooksInterface();
}
});
return *mtia_hooks;
return *MTIA_hooks;
}
bool isMTIAHooksBuilt() {
return MTIAHooksRegistry()->Has("MTIAHooks");
}
} // namespace detail
C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)

View File

@@ -1,9 +1,7 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
@@ -22,72 +20,33 @@ constexpr const char* MTIA_HELP =
"to use some MTIA's functionality without MTIA extension included.";
struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
// This fails if an MTIAHooks function is called but the
// MTIA backend is not present.
#define FAIL_MTIAHOOKS_FUNC(func) \
TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
virtual ~MTIAHooksInterface() override = default;
virtual void initMTIA() const {
// Avoid logging here: MTIA needs to initialize its devices first before it
// knows how many are available. Make this a no-op if the MTIA extension is
// not dynamically loaded.
return;
TORCH_CHECK(
false,
"Cannot initialize MTIA without MTIA Extension for PyTorch.",
MTIA_HELP);
}
virtual bool hasMTIA() const {
return false;
}
virtual DeviceIndex deviceCount() const override {
return 0;
}
virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual std::string showConfig() const {
FAIL_MTIAHOOKS_FUNC(__func__);
TORCH_CHECK(
false,
"Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
MTIA_HELP);
}
virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
return false;
TORCH_CHECK(
false,
"Cannot check MTIA primary context without MTIA Extension for PyTorch.",
MTIA_HELP);
}
virtual void setCurrentDevice(DeviceIndex device) const override {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual DeviceIndex getCurrentDevice() const override {
FAIL_MTIAHOOKS_FUNC(__func__);
return -1;
}
virtual DeviceIndex exchangeDevice(DeviceIndex device) const override {
FAIL_MTIAHOOKS_FUNC(__func__);
return -1;
}
virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
FAIL_MTIAHOOKS_FUNC(__func__);
return -1;
}
virtual c10::Stream getCurrentStream(DeviceIndex device) const {
FAIL_MTIAHOOKS_FUNC(__func__);
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
}
virtual c10::Stream getDefaultStream(DeviceIndex device) const {
FAIL_MTIAHOOKS_FUNC(__func__);
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
}
virtual void setCurrentStream(const c10::Stream& stream) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
};
struct TORCH_API MTIAHooksArgs {};
@@ -98,6 +57,5 @@ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
namespace detail {
TORCH_API const MTIAHooksInterface& getMTIAHooks();
TORCH_API bool isMTIAHooksBuilt();
} // namespace detail
} // namespace at

View File

@@ -822,7 +822,6 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/python/init.cpp",

View File

@@ -69,7 +69,6 @@ Features described in this documentation are classified by release status:
torch.cuda.memory <torch_cuda_memory>
mps
xpu
mtia
meta
torch.backends <backends>
torch.export <export>

View File

@@ -1,34 +0,0 @@
torch.mtia
===================================
The MTIA backend is implemented out of tree; only the interfaces are defined here.
.. automodule:: torch.mtia
.. currentmodule:: torch.mtia
.. autosummary::
:toctree: generated
:nosignatures:
StreamContext
current_device
current_stream
default_stream
device_count
init
is_available
is_initialized
set_stream
stream
synchronize
device
DeferredMtiaCallError
Streams and events
------------------
.. autosummary::
:toctree: generated
:nosignatures:
Event
Stream
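For context, a minimal usage sketch of the interfaces listed above (a hypothetical illustration, assuming an out-of-tree MTIA extension is installed and exposes at least one device):

.. code-block:: python

    import torch

    if torch.mtia.is_available():
        torch.mtia.init()
        print(torch.mtia.device_count(), torch.mtia.current_device())
        s = torch.mtia.current_stream()
        with torch.mtia.stream(s):
            ...  # work enqueued here targets stream s
        torch.mtia.synchronize()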

View File

@@ -684,7 +684,6 @@ Utilities
set_float32_matmul_precision
get_float32_matmul_precision
set_warn_always
get_device_module
is_warn_always_enabled
vmap
_assert

View File

@@ -1700,24 +1700,6 @@ _TensorBase = TensorBase
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
# Defined in torch/csrc/Module.cpp
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
def _get_accelerator(check: _bool = False) -> _device: ...
# Defined in torch/csrc/mtia/Module.cpp
def _mtia_init() -> None: ...
def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
# Defined in torch/csrc/mps/Module.cpp
def _mps_deviceSynchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...

View File

@@ -23,7 +23,6 @@ class DeviceType(Enum):
FPGA = ...
ORT = ...
XLA = ...
MTIA = ...
MPS = ...
HPU = ...
Meta = ...

View File

@@ -58,7 +58,6 @@ __all__ = [
'SymBool', 'sym_not', 'unravel_index',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
'export', 'autocast', 'cond', 'GradScaler',
'get_device_module',
]
################################################################################
@@ -1580,7 +1579,6 @@ from torch import cuda as cuda
from torch import cpu as cpu
from torch import mps as mps
from torch import xpu as xpu
from torch import mtia as mtia
from torch import autograd as autograd
from torch.autograd import (
no_grad as no_grad,
@@ -2018,27 +2016,6 @@ else:
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
def get_device_module(device: Optional[Union[torch.device, str]] = None):
"""
Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
If no device is given, return the module for the current accelerator or CPU if none is present.
"""
if isinstance(device, torch.device):
device_module_name = device.type
elif isinstance(device, str):
device_module_name = torch.device(device).type
elif device is None:
# Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
device_module_name = torch._C._get_accelerator().type
else:
raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None")
device_module = getattr(torch, device_module_name, None)
if device_module is None:
raise RuntimeError(
f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
)
return device_module
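# Illustrative sketch (assumes CUDA is the active accelerator):
#   torch.get_device_module("cuda:0")             # -> torch.cuda
#   torch.get_device_module(torch.device("cpu"))  # -> torch.cpu
#   torch.get_device_module()                     # -> module for the current accelerator,
#                                                 #    or torch.cpu if none is present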
def _constrain_as_value(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None):
"""

View File

@@ -713,8 +713,6 @@ def _get_available_device_type():
return "cuda"
if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined]
return "xpu"
if hasattr(torch, "mtia") and torch.mtia.is_available():
return "mtia"
custom_backend_name = torch._C._get_privateuse1_backend_name()
custom_device_mod = getattr(torch, custom_backend_name, None)
if custom_device_mod and custom_device_mod.is_available():
@@ -729,8 +727,6 @@ def _get_device_attr(get_member):
return get_member(torch.cuda)
if device_type and device_type.lower() == "xpu":
return get_member(torch.xpu) # type: ignore[attr-defined]
if device_type and device_type.lower() == "mtia":
return get_member(torch.mtia)
if device_type == torch._C._get_privateuse1_backend_name():
return get_member(getattr(torch, device_type))
# add more available device types here

View File

@@ -1,4 +1,3 @@
#include <ATen/DeviceAccelerator.h>
#include <c10/util/Optional.h>
#include <fmt/core.h>
#include <sys/types.h>
@@ -16,12 +15,10 @@
#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/core/Vitals.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/Normalization.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
@@ -74,7 +71,6 @@
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/mtia/Module.h>
#include <torch/csrc/multiprocessing/init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/profiler/python/init.h>
@@ -1644,7 +1640,6 @@ PyObject* initModule() {
#ifdef USE_XPU
torch::xpu::initModule(module);
#endif
torch::mtia::initModule(module);
torch::cpu::initModule(module);
torch::initVerboseBindings(module);
ASSERT_TRUE(THPStorage_init(module));
@@ -1940,70 +1935,6 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().linalgPreferredBackend();
});
py_module.def("_accelerator_hooks_device_count", []() {
auto device_type = at::getAccelerator();
if (device_type.has_value()) {
return at::globalContext()
.getAcceleratorHooksInterface(device_type.value())
.deviceCount();
}
return c10::DeviceIndex(-1);
});
py_module.def(
"_accelerator_hooks_set_current_device",
[](c10::DeviceIndex device_index) {
auto device_type = at::getAccelerator();
if (device_type.has_value()) {
at::globalContext()
.getAcceleratorHooksInterface(device_type.value())
.setCurrentDevice(device_index);
}
});
py_module.def("_accelerator_hooks_get_current_device", []() {
auto device_type = at::getAccelerator();
if (device_type.has_value()) {
return at::globalContext()
.getAcceleratorHooksInterface(device_type.value())
.getCurrentDevice();
}
return c10::DeviceIndex(-1);
});
py_module.def(
"_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) {
auto device_type = at::getAccelerator();
if (device_type.has_value()) {
return at::globalContext()
.getAcceleratorHooksInterface(device_type.value())
.exchangeDevice(device_index);
}
return c10::DeviceIndex(-1);
});
py_module.def(
"_accelerator_hooks_maybe_exchange_device",
[](c10::DeviceIndex device_index) {
auto device_type = at::getAccelerator();
if (device_type.has_value()) {
return at::globalContext()
.getAcceleratorHooksInterface(device_type.value())
.maybeExchangeDevice(device_index);
}
return c10::DeviceIndex(-1);
});
py_module.def(
"_get_accelerator",
[](c10::optional<bool> check = c10::nullopt) {
return c10::Device(
at::getAccelerator(check.value_or(false))
.value_or(c10::DeviceType::CPU),
-1);
},
py::arg("check") = nullptr);
py_module.def(
"_construct_storage_from_data_pointer",
[](int64_t data_ptr, c10::Device device, size_t size_bytes) {

View File

@@ -1,81 +0,0 @@
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#ifndef WIN32
#include <pthread.h>
#endif
namespace torch {
namespace mtia {
static bool in_bad_fork = false; // True for children forked after mtia init
#ifndef WIN32
// Called in the forked child if mtia has already been initialized
static void forked_child() {
in_bad_fork = true;
torch::utils::set_requires_device_init(at::kMTIA, true);
}
#endif
// Should be called before the first mtia call.
// Note: This is distinct from initExtension because a stub mtia implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static c10::once_flag flag;
c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_mtia_init", []() {
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
at::globalContext().lazyInitMTIA();
});
m.def("_mtia_isBuilt", []() {
// Check if the MTIAHooks class has been registered with the registry.
return at::detail::isMTIAHooksBuilt();
});
m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
torch::utils::device_lazy_init(at::kMTIA);
return at::detail::getMTIAHooks().getCurrentStream(device_index);
});
m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) {
torch::utils::device_lazy_init(at::kMTIA);
at::detail::getMTIAHooks().deviceSynchronize(
at::detail::getMTIAHooks().getCurrentDevice());
});
m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
torch::utils::device_lazy_init(at::kMTIA);
return at::detail::getMTIAHooks().getDefaultStream(device_index);
});
m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
torch::utils::device_lazy_init(at::kMTIA);
auto device = at::detail::getMTIAHooks().getCurrentDevice();
if (device != stream.device_index()) {
at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
}
at::detail::getMTIAHooks().setCurrentStream(stream);
});
}
} // namespace mtia
} // namespace torch

View File

@@ -1,12 +0,0 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch {
namespace mtia {
// PyMethodDef* python_functions();
void initModule(PyObject* module);
} // namespace mtia
} // namespace torch

View File

@@ -194,12 +194,6 @@ struct type_caster<c10::Stream> {
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream"));
// PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream
// cannot be default-initialized, we provide this constructor to explicitly
// initialize that field. The value doesn't matter as it will be overwritten
// after a successful call to load.
type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {}
bool load(handle src, bool) {
PyObject* obj = src.ptr();
if (THPStream_Check(obj)) {

View File

@@ -1,262 +0,0 @@
r"""
This package provides an interface for accessing the MTIA backend in Python.
"""
import threading
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.types import Device
from .. import device as _device, Tensor
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from ._utils import _get_device_index
_device_t = Union[_device, str, int, None]
# torch.mtia.Event/Stream are aliases of torch.Event/Stream
Event = torch.Event
Stream = torch.Stream
_initialized = False
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
] = [] # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker()
def init():
_lazy_init()
def is_initialized():
r"""Return whether PyTorch's MTIA state has been initialized."""
return _initialized and not _is_in_bad_fork()
def _is_in_bad_fork() -> bool:
return torch._C._mtia_isInBadFork()
def _lazy_init() -> None:
global _initialized, _queued_calls
if is_initialized() or hasattr(_tls, "is_initializing"):
return
with _initialization_lock:
# We be double-checked locking, boys! This is OK because
# the above test was GIL protected anyway. The inner test
# is for when a thread blocked on some other thread which was
# doing the initialization; when they get the lock, they will
# find there is nothing left to do.
if is_initialized():
return
# It is important to prevent other threads from entering _lazy_init
# immediately, while we are still guaranteed to have the GIL, because some
# of the C calls we make below will release the GIL
if _is_in_bad_fork():
raise RuntimeError(
"Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
"multiprocessing, you must use the 'spawn' start method"
)
if not _is_compiled():
raise AssertionError("Torch not compiled with MTIA enabled")
torch._C._mtia_init()
# Some of the queued calls may reentrantly call _lazy_init();
# we need to just return without initializing in that case.
# However, we must not let any *other* threads in!
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
try:
for queued_call, orig_traceback in _queued_calls:
try:
queued_call()
except Exception as e:
msg = (
f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
)
raise DeferredMtiaCallError(msg) from e
finally:
delattr(_tls, "is_initializing")
_initialized = True
class DeferredMtiaCallError(Exception):
pass
def _is_compiled() -> bool:
r"""Return true if compiled with MTIA support."""
return torch._C._mtia_isBuilt()
def is_available() -> bool:
r"""Return true if MTIA device is available"""
if not _is_compiled():
return False
# MTIA has to init devices first to know if there are any devices available.
return device_count() > 0
def synchronize() -> None:
r"""Waits for all jobs in all streams on a MTIA device to complete."""
return torch._C._mtia_deviceSynchronize()
def device_count() -> int:
r"""Return the number of MTIA devices available."""
return torch._C._accelerator_hooks_device_count()
def current_device() -> int:
r"""Return the index of a currently selected device."""
return torch._C._accelerator_hooks_get_current_device()
def current_stream(device: Optional[_device_t] = None) -> Stream:
r"""Return the currently selected :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the currently selected :class:`Stream` for the current device, given
by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
(default).
"""
return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))
def default_stream(device: Optional[_device_t] = None) -> Stream:
r"""Return the default :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the default :class:`Stream` for the current device, given by
:func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
(default).
"""
return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))
def set_stream(stream: Stream):
r"""Set the current stream.This is a wrapper API to set the stream.
Usage of this function is discouraged in favor of the ``stream``
context manager.
Args:
stream (Stream): selected stream. This function is a no-op
if this argument is ``None``.
"""
if stream is None:
return
torch._C._mtia_setCurrentStream(stream)
class device:
r"""Context-manager that changes the selected device.
Args:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device: Any):
self.idx = _get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
return False
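# Illustrative sketch (assumes at least two MTIA devices are present):
#   with torch.mtia.device(1):
#       ...  # ops launched here default to device 1
#   # the previously selected device index is restored on exit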
class StreamContext:
r"""Context-manager that selects a given stream.
All MTIA kernels queued within its context will be enqueued on a selected
stream.
Args:
Stream (Stream): selected stream. This manager is a no-op if it's
``None``.
.. note:: Streams are per-device.
"""
cur_stream: Optional["torch.mtia.Stream"]
def __init__(self, stream: Optional["torch.mtia.Stream"]):
self.stream = stream
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1
self.src_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
)
self.dst_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
)
def __enter__(self):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# Return if stream is None or MTIA device not available
if cur_stream is None or self.idx == -1:
return
self.src_prev_stream = torch.mtia.current_stream(None)
# If the stream is not on the current device, then
# set the current stream on the device
if self.src_prev_stream.device != cur_stream.device:
with device(cur_stream.device):
self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
torch.mtia.set_stream(cur_stream)
def __exit__(self, type: Any, value: Any, traceback: Any):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# If stream is None or no MTIA device available, return
if cur_stream is None or self.idx == -1:
return
# Reset the stream on the original device
# and destination device
if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type]
def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
r"""Wrap around the Context-manager StreamContext that selects a given stream.
Arguments:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
.. note:: In eager mode ``stream`` is a ``Stream`` object; ``torch.mtia.stream`` is not supported in JIT.
"""
return StreamContext(stream)
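# Illustrative sketch (assumes an MTIA device is present and ``s`` is an
# existing torch.mtia.Stream):
#   prev = torch.mtia.current_stream()
#   with torch.mtia.stream(s):
#       ...  # work enqueued here targets ``s``
#   # the current stream is restored to ``prev`` on exit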
__all__ = [
"init",
"is_available",
"is_initialized",
"synchronize",
"device_count",
"current_device",
"current_stream",
"default_stream",
"set_stream",
"stream",
"device",
]

View File

@@ -1,38 +0,0 @@
from typing import Any
import torch
# The _get_device_index has been moved to torch.utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index
def _get_device_index(
device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
If :attr:`device` is a torch.device object, returns the device index if it
is an MTIA device. Note that for an MTIA device without a specified index,
i.e., ``torch.device('mtia')``, this will return the current default MTIA
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
CPU devices will be accepted and ``-1`` will be returned in this case.
If :attr:`device` is a Python integer, it is returned as is.
If :attr:`device` is ``None``, this will return the current default MTIA
device if :attr:`optional` is ``True``.
"""
if isinstance(device, int):
return device
if isinstance(device, str):
device = torch.device(device)
if isinstance(device, torch.device):
if allow_cpu:
if device.type not in ["mtia", "cpu"]:
raise ValueError(f"Expected a mtia or cpu device, but got: {device}")
elif device.type != "mtia":
raise ValueError(f"Expected a mtia device, but got: {device}")
if not torch.jit.is_scripting():
if isinstance(device, torch.mtia.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
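# Illustrative examples of the rules documented above (assume the current MTIA
# device is index 0):
#   _get_device_index(2)                                    # -> 2 (ints pass through)
#   _get_device_index("mtia:1")                             # -> 1
#   _get_device_index(torch.device("mtia"), optional=True)  # -> 0 (current device)
#   _get_device_index(torch.device("cpu"), allow_cpu=True)  # -> -1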

View File

@@ -281,7 +281,6 @@ def get_ignored_functions() -> Set[Callable]:
torch.use_deterministic_algorithms,
torch.is_deterministic_algorithms_warn_only_enabled,
torch.set_deterministic_debug_mode,
torch.get_device_module,
torch.get_deterministic_debug_mode,
torch.set_float32_matmul_precision,
torch.get_float32_matmul_precision,