Move get accelerator to use build time flags when possible (#146098)

This PR does two main things (they are in a single PR to show how the newly added APIs are used).

- Add isBuilt and isAvailable APIs to the AcceleratorHooksInterface. See the inline documentation for their exact semantics.
- Use the newly added isBuilt for the accelerator detection check to ensure it does not poison fork.
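
A rough sketch of the user-visible effect in Python (based on the binding and frontend changes in the diffs below; the CPU-only behavior noted in the comments is as described in the new docstrings):

```python
import torch

# Compile-time check backed by the new isBuilt hook: fork-safe, never
# initializes the device, and returns None when no accelerator was built in.
built = torch.accelerator.current_accelerator()

# Runtime check: the built backend must also be usable (drivers present,
# at least one device visible). This may poison a later fork on most backends.
ready = torch.accelerator.is_available()
```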

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146098
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/EikanWang, https://github.com/jeromean

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Author: albanD
Date: 2025-03-10 13:17:55 +00:00
Committed by: PyTorch MergeBot
Parent: 098494e9cb
Commit: 68c12ecfe2
15 changed files with 185 additions and 58 deletions

View File

@@ -5,38 +5,53 @@
namespace at::accelerator {
std::optional<c10::DeviceType> getAccelerator(bool checked) {
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
device_type = k##device_name; \
TORCH_CHECK( \
!is_accelerator_detected, \
"Cannot have ", \
device_type.value(), \
" with other accelerators."); \
is_accelerator_detected = true; \
}
// 1. Check PrivateUse1 backends
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
// Note that this check only queries hook registration and thus does NOT initialize
// the device or poison fork.
if (is_privateuse1_backend_registered()) {
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
return kPrivateUse1;
}
// 2. Check runtime backends
// This state is temporary: these runtime checks should be moved to compile time
// once these backends provide the new isBuilt API and we are sure they are never
// in the same binary as another accelerator.
#define DETECT_RUNTIME_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
return k##device_name; \
}
DETECT_RUNTIME_ACCELERATOR(MTIA)
#undef DETECT_RUNTIME_ACCELERATOR
// 3. Check compile-time backends
std::optional<c10::DeviceType> device_type = std::nullopt;
bool is_accelerator_detected = false;
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
DETECT_AND_ASSIGN_ACCELERATOR(HIP)
DETECT_AND_ASSIGN_ACCELERATOR(MPS)
DETECT_AND_ASSIGN_ACCELERATOR(HPU)
#define DETECT_AND_ASSIGN_ACCELERATOR_COMP(device_name) \
if (at::detail::get##device_name##Hooks().isBuilt()) { \
TORCH_CHECK( \
!device_type.has_value(), \
"Cannot have both " #device_name " and ", \
device_type.value(), "."); \
device_type = k##device_name; \
}
DETECT_AND_ASSIGN_ACCELERATOR_COMP(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(XPU)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HPU)
if (checked) {
TORCH_CHECK(
device_type, "Cannot access accelerator device when none is available.")
}
return device_type;
#undef DETECT_AND_ASSIGN_ACCELERATOR
#undef DETECT_AND_ASSIGN_ACCELERATOR_COMP
}
bool isAccelerator(c10::DeviceType device_type) {

View File

@@ -33,6 +33,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool isBuilt() const override {return true;}
bool isAvailable() const override {return hasCUDA();}
bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;

View File

@@ -20,6 +20,23 @@ struct TORCH_API AcceleratorHooksInterface {
// squelch -Werror=non-virtual-dtor
virtual ~AcceleratorHooksInterface() = default;
// Whether this backend was enabled at compilation time.
// This function should NEVER throw.
virtual bool isBuilt() const {
return false;
}
// Whether this backend can be used at runtime, meaning it was built,
// its runtime dependencies are available (driver) and at least one
// supported device can be used.
// This function should NEVER throw. This function should NOT initialize the context
// on any device (result of hasPrimaryContext below should not change).
// While it is acceptable for this function to poison fork, it is
// recommended to avoid doing so whenever possible.
virtual bool isAvailable() const {
return false;
}
// Whether the device at device_index is fully initialized or not.
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

View File

@@ -54,7 +54,12 @@ struct MPSHooks : public at::MPSHooksInterface {
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
const override;
// Compatibility with Accelerator API
bool isBuilt() const override {
return true;
}
bool isAvailable() const override {
return hasMPS();
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
// When MPS is available, it is always in use for the one device.
return true;

View File

@@ -84,9 +84,14 @@ bool XPUHooks::isPinnedPtr(const void* data) const {
sycl::get_pointer_type(data, c10::xpu::get_device_context());
}
bool XPUHooks::isAvailable() const {
return at::xpu::is_available();
}
bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
// The default context is utilized for each device. So it always returns true.
return true;
// The default context is utilized for each device.
// So it always returns true if a device is available.
return isAvailable();
}
DeviceIndex XPUHooks::deviceCount() const {

View File

@@ -19,6 +19,11 @@ struct XPUHooks : public at::XPUHooksInterface {
DeviceIndex current_device() const override;
void deviceSynchronize(DeviceIndex device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool isBuilt() const override {
return true;
}
bool isAvailable() const override;
bool isPinnedPtr(const void* data) const override;
bool hasPrimaryContext(DeviceIndex device_index) const override;
DeviceIndex deviceCount() const override;

View File

@@ -152,7 +152,19 @@ us to use the current accelerator as the default device for relevant concepts su
Stream device_type, FSDP, etc.
As of today, accelerator devices are (in no particular order) :doc:`"CUDA" <cuda>`, :doc:`"MTIA" <mtia>`,
:doc:`"XPU" <xpu>`, and PrivateUse1 (many device not in the PyTorch repo itself).
:doc:`"XPU" <xpu>`, :doc:`"MPS" <mps>`, "HPU", and PrivateUse1 (many device not in the PyTorch repo itself).
Many tools in the PyTorch ecosystem use fork to create subprocesses (for example, dataloading
or intra-op parallelism), so it is important to delay as long as possible any
operation that would prevent further forks. This is especially important here, as most
accelerators' initialization has this effect.
In practice, you should keep in mind that checking :func:`torch.accelerator.current_accelerator`
is a compile-time check by default and is thus always fork-safe.
In contrast, passing the ``check_available=True`` flag to this function or calling
:func:`torch.accelerator.is_available()` will usually prevent later forks.
Some backends provide an experimental opt-in option to make the runtime availability
check fork-safe. For example, ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` can be used for this
purpose when running on a CUDA device.
.. autosummary::
:toctree: generated
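
An illustrative sketch of the guidance above (assuming a Unix platform where the fork start method is available):

```python
import multiprocessing as mp

import torch

# Fork-safe by default: a compile-time check, no device initialization.
acc = torch.accelerator.current_accelerator()
print(f"accelerator built into this binary: {acc}")


def check_runtime():
    # Doing the runtime availability check in the child keeps the parent
    # process fork-safe; the child may initialize the accelerator freely.
    print(f"available at runtime: {torch.accelerator.is_available()}")


if __name__ == "__main__":
    # The fork start method is assumed to exist (Unix only).
    ctx = mp.get_context("fork")
    p = ctx.Process(target=check_runtime)
    p.start()
    p.join()
```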

View File

@@ -3,6 +3,7 @@
import copy
import os
import pickle
import subprocess
import sys
import tempfile
import threading
@@ -1907,6 +1908,35 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
dist.destroy_process_group()
def test_default_process_group(self):
script = """
# Hide all GPUs
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import torch
from torch import distributed as dist
# This should initialize on CPU even though this is a CUDA-enabled build
dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
"""
try:
subprocess.check_output(
[sys.executable, "-c", script],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)),
# It is ok to have an extra long timeout here as a timeout means the test failed
timeout=20,
)
except subprocess.TimeoutExpired:
self.fail(
msg="Example code timed out! See the code sample in the test for details."
)
except subprocess.CalledProcessError as e:
self.fail(f"""Subprocess failed with {e.output.decode("utf-8")}""")
def _call_collective_with_varying_tensors(self, backend, collective, *args):
# call collective with varying tensors to ensure that the tensors are
# correctly dispatched

View File

@@ -3597,8 +3597,9 @@ def fork_and_check_is_pinned():
def worker(conn):
try:
x = torch.randn(10)
x.is_pinned(device="cuda")
x = torch.ones(10, device="cuda")[0].item()
x.is_pinned()
dev = torch.accelerator.current_accelerator()
x = torch.ones(10, device=dev)[0].item()
conn.send(x)
except Exception as e:
conn.send(str(e))
@@ -3618,7 +3619,7 @@ def fork_and_check_is_pinned():
x = torch.randn(10)
# check that is_pinned won't poison future fork
x.is_pinned(device="cuda")
x.is_pinned()
ret = fork_and_check_is_pinned()
print(ret)

View File

@@ -2,6 +2,7 @@ r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""
from typing import Optional
from typing_extensions import deprecated
import torch
@@ -34,7 +35,9 @@ def device_count() -> int:
def is_available() -> bool:
r"""Check if there is an available :ref:`accelerator<accelerators>`.
r"""Check if the current accelerator is available at runtime: it was build, all the
required drivers are available and at least one device is visible.
See :ref:`accelerator<accelerators>` for details.
Returns:
bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.
@@ -43,35 +46,47 @@ def is_available() -> bool:
>>> assert torch.accelerator.is_available(), "No available accelerators detected."
"""
return device_count() > 0
# Why not just check "device_count() > 0" like other is_available calls?
# Because devices like CUDA have a Python implementation of is_available that is
# non-poisoning, and some features like Dataloader rely on it.
# So we are careful to delegate to the Python version of the accelerator here.
acc = current_accelerator()
if acc is None:
return False
mod = torch.get_device_module(acc)
return mod.is_available()
def current_accelerator() -> torch.device:
r"""Return the device of the current :ref:`accelerator<accelerators>`.
def current_accelerator(check_available: bool = False) -> Optional[torch.device]:
r"""Return the device of the accelerator available at compilation time.
If no accelerator was available at compilation time, returns None.
See :ref:`accelerator<accelerators>` for details.
Args:
check_available (bool, optional): if True, also perform a runtime check to see
if the device is available (via :func:`torch.accelerator.is_available`) on top of
the compile-time check.
Default: ``False``
Returns:
torch.device: return the current accelerator as :class:`torch.device`.
.. note:: The index of the returned :class:`torch.device` will be ``None``, please use
:func:`torch.accelerator.current_device_index` to know the current index being used.
And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
accelerator. If there is no available accelerator, this function will raise an exception.
Example::
>>> # xdoctest:
>>> if torch.accelerator.is_available():
>>> current_device = torch.accelerator.current_accelerator()
>>> else:
>>> current_device = torch.device("cpu")
>>> if current_device.type == 'cuda':
>>> is_half_supported = torch.cuda.has_half
>>> elif current_device.type == 'xpu':
>>> is_half_supported = torch.xpu.get_device_properties().has_fp16
>>> elif current_device.type == 'cpu':
>>> is_half_supported = True
>>> # If an accelerator is available, send the model to it
>>> model = torch.nn.Linear(2, 2)
>>> if (current_device := current_accelerator(check_available=True)) is not None:
>>> model.to(current_device)
"""
return torch._C._accelerator_getAccelerator()
if (acc := torch._C._accelerator_getAccelerator()) is not None:
if (not check_available) or (check_available and is_available()):
return acc
return None
def current_device_index() -> int:

View File

@@ -11,7 +11,10 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
device = torch.device(device)
device_index: Optional[int] = None
if isinstance(device, torch.device):
if torch.accelerator.current_accelerator().type != device.type:
acc = torch.accelerator.current_accelerator()
if acc is None:
raise RuntimeError("Accelerator expected")
if acc.type != device.type:
raise ValueError(
f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
)

View File

@@ -6,9 +6,14 @@ namespace torch::accelerator {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_accelerator_getAccelerator", []() {
// If no accelerator is currently available, raise an exception.
return c10::Device(at::getAccelerator(true).value());
m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
// If no accelerator was available at compile time, return None.
auto acc = at::getAccelerator(false);
if (acc.has_value()) {
return acc.value();
} else {
return std::nullopt;
}
});
m.def("_accelerator_deviceCount", []() {

View File

@@ -2355,10 +2355,17 @@ Call this whenever a new thread is created in order to propagate values from
py_module.def(
"_get_accelerator",
[](std::optional<bool> check = std::nullopt) {
return c10::Device(
at::getAccelerator(check.value_or(false))
.value_or(c10::DeviceType::CPU),
-1);
auto acc = at::getAccelerator(check.value_or(false));
if (acc.has_value()) {
bool is_available = at::globalContext()
.getAcceleratorHooksInterface(acc.value())
.isAvailable();
if (!is_available) {
acc = std::nullopt;
}
}
return c10::Device(acc.value_or(c10::DeviceType::CPU), -1);
},
py::arg("check") = nullptr);

View File

@@ -1246,9 +1246,13 @@ def _save(
if (
config.save.use_pinned_memory_for_d2h
and torch.accelerator.is_available()
and torch.accelerator.current_accelerator().type
== storage.device.type
and (
acc := torch.accelerator.current_accelerator(
check_available=True
)
)
is not None
and acc.type == storage.device.type
):
new_storage = torch.empty(
num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True

View File

@@ -672,7 +672,8 @@ class _BaseDataLoaderIter:
# memory allocation for MPS is fixed.
if (
self._pin_memory
and torch.accelerator.current_accelerator().type == "mps"
and (acc := torch.accelerator.current_accelerator()) is not None
and acc.type == "mps"
):
self._pin_memory = False
warn_msg = (