Fixes #156052 and #156444.

This PR sets up the PrivateUse1 key in Python to be used as a Python backend for PyTorch. That means that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass with that device to hold arbitrary Python data as "device data" and use `torch.library` to register ops that take that Tensor (a C++ sketch of the analogous registration appears below).

Changes done in this PR:
1. Register a vanilla device guard: NoOpDeviceGuard is extended to allow a device index of 0 and to not raise errors when event-related functions are accessed. Without those changes, calling backward raises errors. (The CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.)
2. A Tensor subclass is allowed to omit `__torch_dispatch__` if the device is not CUDA or CPU. The comment on that check suggests it was there to avoid a segfault when calling into ops that expect a storage. Here we have a different device, so we will not call into those ops.
3. A Python function that invokes the other incantations to set up the PrivateUse1 backend.

This took inspiration from https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
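To make the op-registration half of that flow concrete, here is a minimal C++ sketch of routing an aten op to a PrivateUse1 kernel with TORCH_LIBRARY_IMPL, in the spirit of the wrapped_tensor.cpp example linked above. The kernel name and body are illustrative placeholders, not code from this PR (which does the registration from Python via `torch.library`):

#include <ATen/ATen.h>
#include <torch/library.h>

// Illustrative placeholder kernel: a real backend would unwrap the
// Python-held "device data" from both inputs and compute on it.
at::Tensor my_add(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(
      a.device().type() == c10::DeviceType::PrivateUse1,
      "my_add expects PrivateUse1 tensors");
  return a; // placeholder result; see wrapped_tensor.cpp for a real kernel
}

// Route aten::add.Tensor to my_add whenever the inputs live on PrivateUse1.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", my_add);
}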
47 lines
1.4 KiB
C++
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/FakeGuardImpl.h>

#include <array>

namespace c10::impl {

std::array<
    std::atomic<const DeviceGuardImplInterface*>,
    static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_guard_impl_registry;

void registerDeviceGuard(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}

DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
    DeviceType type,
    const DeviceGuardImplInterface* impl) {
  registerDeviceGuard(type, impl);
}

namespace {
thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
    nullptr;
} // namespace

void ensureCUDADeviceGuardSet() {
  constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA);

  const DeviceGuardImplInterface* p =
      device_guard_impl_registry[cuda_idx].load();

  // A non-null `p` indicates that the CUDA guard is already set up,
  // implying this is a CUDA build.
  if (p && p->deviceCount() == 0) {
    // The CUDA build is enabled but no CUDA devices are available, so
    // override the CUDA guard interface with a no-op device guard.
    tls_fake_device_guard = std::make_unique<FakeGuardImpl<DeviceType::CUDA>>();
    device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get());
  }
}

} // namespace c10::impl
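For context on how DeviceGuardImplRegistrar is used: backends install their guard implementation at static-initialization time, typically through the C10_REGISTER_GUARD_IMPL macro declared alongside this code. A minimal sketch, reusing FakeGuardImpl as a stand-in guard for PrivateUse1 (an illustrative choice, not what the PR registers):

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/FakeGuardImpl.h>

// Expands to a static DeviceGuardImplRegistrar whose constructor calls
// registerDeviceGuard(DeviceType::PrivateUse1, <new guard instance>),
// storing the guard in device_guard_impl_registry above.
C10_REGISTER_GUARD_IMPL(
    PrivateUse1,
    c10::impl::FakeGuardImpl<c10::DeviceType::PrivateUse1>);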