mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Generator register for the privateuse1 backend (#93920)
Fixes #92202 Add generator regiter for the backend of `privateuseone` module: backend Pull Request resolved: https://github.com/pytorch/pytorch/pull/93920 Approved by: https://github.com/bdhirsh
This commit is contained in:
28
aten/src/ATen/core/GeneratorForPrivateuseone.cpp
Normal file
28
aten/src/ATen/core/GeneratorForPrivateuseone.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
#include <mutex>
|
||||
#include <ATen/core/GeneratorForPrivateuseone.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
c10::optional<GeneratorFuncType>& GetGeneratorPrivate() {
|
||||
static c10::optional<GeneratorFuncType> generator_privateuse1 = c10::nullopt;
|
||||
return generator_privateuse1;
|
||||
}
|
||||
|
||||
std::mutex _generator_mutex_lock;
|
||||
_GeneratorRegister::_GeneratorRegister(GeneratorFuncType func) {
|
||||
_generator_mutex_lock.lock();
|
||||
TORCH_CHECK(!GetGeneratorPrivate().has_value(),
|
||||
"Only can register a generator to the PrivateUse1 dispatch key once!");
|
||||
auto& m_generator = GetGeneratorPrivate();
|
||||
m_generator = func;
|
||||
_generator_mutex_lock.unlock();
|
||||
}
|
||||
|
||||
at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
|
||||
TORCH_CHECK(GetGeneratorPrivate().has_value(),
|
||||
"Please register a generator to the PrivateUse1 dispatch key, \
|
||||
using the REGISTER_GENERATOR_PRIVATEUSE1 macro.");
|
||||
return GetGeneratorPrivate().value()(device_index);
|
||||
}
|
||||
|
||||
}
|
36
aten/src/ATen/core/GeneratorForPrivateuseone.h
Normal file
36
aten/src/ATen/core/GeneratorForPrivateuseone.h
Normal file
@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
|
||||
|
||||
c10::optional<GeneratorFuncType>& GetGeneratorPrivate();
|
||||
|
||||
class TORCH_API _GeneratorRegister{
|
||||
public:
|
||||
_GeneratorRegister(GeneratorFuncType func);
|
||||
};
|
||||
|
||||
TORCH_API at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index);
|
||||
|
||||
/**
|
||||
* This is used to register Generator to PyTorch for `privateuse1` key.
|
||||
* Usage: REGISTER_GENERATOR_PRIVATEUSE1(GeneratorForPrivateuse1)
|
||||
* GeneratorForPrivateuse1 func must return a argument with type of at::Generator.
|
||||
* class CustomGeneratorImpl : public c10::GeneratorImpl {
|
||||
* CustomGeneratorImpl(DeviceIndex device_index = -1);
|
||||
* ~CustomGeneratorImpl() override = default;
|
||||
* ...
|
||||
* }
|
||||
* at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
|
||||
* return at::make_generator<CustomGeneratorImpl>(id);
|
||||
* }
|
||||
* REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
|
||||
*/
|
||||
#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
|
||||
auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
|
||||
|
||||
}
|
@ -7,7 +7,7 @@
|
||||
#include <ATen/native/cpu/Loops.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/EmptyTensor.h>
|
||||
|
||||
#include <ATen/core/GeneratorForPrivateuseone.h>
|
||||
|
||||
static uint64_t add_counter = 0;
|
||||
static uint64_t last_saved_value = 0;
|
||||
@ -108,6 +108,25 @@ bool custom_add_called() {
|
||||
return called;
|
||||
}
|
||||
|
||||
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
|
||||
public:
|
||||
// Constructors
|
||||
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
|
||||
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
|
||||
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
|
||||
}
|
||||
~PrivateGeneratorImpl() override = default;
|
||||
};
|
||||
|
||||
// this is used to register generator
|
||||
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
|
||||
return at::make_generator<PrivateGeneratorImpl>(device_index);
|
||||
}
|
||||
|
||||
void register_genertor() {
|
||||
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
|
||||
}
|
||||
|
||||
// Here, we're exposing a custom device object that corresponds to our custom backend.
|
||||
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
|
||||
// that's implemented in C++.
|
||||
@ -115,4 +134,5 @@ bool custom_add_called() {
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("custom_device", &get_custom_device, "get custom device object");
|
||||
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
|
||||
m.def("register_genertor", ®ister_genertor, "register generator for custom device");
|
||||
}
|
||||
|
@ -100,5 +100,20 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
# None of our CPU operations should call the custom add function.
|
||||
self.assertFalse(module.custom_add_called())
|
||||
|
||||
# check generator registered befor use
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Please register a generator to the PrivateUse1 dispatch key"):
|
||||
gen_ = torch.Generator(device=device)
|
||||
|
||||
module.register_genertor()
|
||||
|
||||
gen = torch.Generator(device=device)
|
||||
self.assertTrue(gen.device == device)
|
||||
|
||||
# generator can be registered only once
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Only can register a generator to the PrivateUse1 dispatch key once"):
|
||||
module.register_genertor()
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <ATen/CPUGeneratorImpl.h>
|
||||
#include <structmember.h>
|
||||
|
||||
#include <ATen/core/GeneratorForPrivateuseone.h>
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/THP.h>
|
||||
@ -68,7 +69,9 @@ static PyObject* THPGenerator_pynew(
|
||||
self->cdata = make_generator<MPSGeneratorImpl>();
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
else if (device.type() == at::kPrivateUse1) {
|
||||
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"Device type ",
|
||||
c10::DeviceTypeName(device.type()),
|
||||
|
Reference in New Issue
Block a user