Add Generator register for the privateuse1 backend (#93920)

Fixes #92202
Add generator regiter for the backend of `privateuseone`

module: backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93920
Approved by: https://github.com/bdhirsh
This commit is contained in:
shibo
2023-03-07 03:43:23 +00:00
committed by PyTorch MergeBot
parent e9ca902894
commit 7038458c5b
5 changed files with 104 additions and 2 deletions

View File

@ -0,0 +1,28 @@
#include <mutex>
#include <ATen/core/GeneratorForPrivateuseone.h>
namespace at {
c10::optional<GeneratorFuncType>& GetGeneratorPrivate() {
static c10::optional<GeneratorFuncType> generator_privateuse1 = c10::nullopt;
return generator_privateuse1;
}
std::mutex _generator_mutex_lock;
_GeneratorRegister::_GeneratorRegister(GeneratorFuncType func) {
_generator_mutex_lock.lock();
TORCH_CHECK(!GetGeneratorPrivate().has_value(),
"Only can register a generator to the PrivateUse1 dispatch key once!");
auto& m_generator = GetGeneratorPrivate();
m_generator = func;
_generator_mutex_lock.unlock();
}
at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
TORCH_CHECK(GetGeneratorPrivate().has_value(),
"Please register a generator to the PrivateUse1 dispatch key, \
using the REGISTER_GENERATOR_PRIVATEUSE1 macro.");
return GetGeneratorPrivate().value()(device_index);
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <ATen/core/Generator.h>
#include <c10/util/intrusive_ptr.h>
namespace at {
using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
c10::optional<GeneratorFuncType>& GetGeneratorPrivate();
class TORCH_API _GeneratorRegister{
public:
_GeneratorRegister(GeneratorFuncType func);
};
TORCH_API at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index);
/**
* This is used to register Generator to PyTorch for `privateuse1` key.
* Usage: REGISTER_GENERATOR_PRIVATEUSE1(GeneratorForPrivateuse1)
* GeneratorForPrivateuse1 func must return a argument with type of at::Generator.
* class CustomGeneratorImpl : public c10::GeneratorImpl {
* CustomGeneratorImpl(DeviceIndex device_index = -1);
* ~CustomGeneratorImpl() override = default;
* ...
* }
* at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
* return at::make_generator<CustomGeneratorImpl>(id);
* }
* REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
*/
#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
}

View File

@ -7,7 +7,7 @@
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
@ -108,6 +108,25 @@ bool custom_add_called() {
return called;
}
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
// Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~PrivateGeneratorImpl() override = default;
};
// this is used to register generator
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
void register_genertor() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
@ -115,4 +134,5 @@ bool custom_add_called() {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_device", &get_custom_device, "get custom device object");
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
m.def("register_genertor", &register_genertor, "register generator for custom device");
}

View File

@ -100,5 +100,20 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# None of our CPU operations should call the custom add function.
self.assertFalse(module.custom_add_called())
# check generator registered befor use
with self.assertRaisesRegex(RuntimeError,
"Please register a generator to the PrivateUse1 dispatch key"):
gen_ = torch.Generator(device=device)
module.register_genertor()
gen = torch.Generator(device=device)
self.assertTrue(gen.device == device)
# generator can be registered only once
with self.assertRaisesRegex(RuntimeError,
"Only can register a generator to the PrivateUse1 dispatch key once"):
module.register_genertor()
if __name__ == "__main__":
common.run_tests()

View File

@ -4,6 +4,7 @@
#include <ATen/CPUGeneratorImpl.h>
#include <structmember.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
@ -68,7 +69,9 @@ static PyObject* THPGenerator_pynew(
self->cdata = make_generator<MPSGeneratorImpl>();
}
#endif
else {
else if (device.type() == at::kPrivateUse1) {
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
} else {
AT_ERROR(
"Device type ",
c10::DeviceTypeName(device.type()),