diff --git a/aten/src/ATen/core/GeneratorForPrivateuseone.cpp b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp index 34d84085ca03..030e9f70851a 100644 --- a/aten/src/ATen/core/GeneratorForPrivateuseone.cpp +++ b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp @@ -1,6 +1,7 @@ -#include #include +#include + namespace at { static std::mutex _generator_mutex_lock; @@ -12,6 +13,11 @@ std::optional& GetGeneratorPrivate() { _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) { std::lock_guard lock(_generator_mutex_lock); + + TORCH_WARN_DEPRECATION( + "REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. " + "Please derive PrivateUse1HooksInterface to implement getNewGenerator instead."); + TORCH_CHECK( !GetGeneratorPrivate().has_value(), "Only can register a generator to the PrivateUse1 dispatch key once!"); @@ -21,6 +27,10 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) { } at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) { + TORCH_WARN_DEPRECATION(
+ "GetGeneratorForPrivateuse1() is deprecated. Please use " + "globalContext().getAcceleratorHooksInterface(device_type).getNewGenerator() instead."); + TORCH_CHECK( GetGeneratorPrivate().has_value(), "Please register a generator to the PrivateUse1 dispatch key, \ diff --git a/aten/src/ATen/core/GeneratorForPrivateuseone.h b/aten/src/ATen/core/GeneratorForPrivateuseone.h index 747c77897ff9..a4879a1f5f5c 100644 --- a/aten/src/ATen/core/GeneratorForPrivateuseone.h +++ b/aten/src/ATen/core/GeneratorForPrivateuseone.h @@ -7,7 +7,7 @@ namespace at { using GeneratorFuncType = std::function; -std::optional& GetGeneratorPrivate(); +TORCH_API std::optional& GetGeneratorPrivate(); class TORCH_API _GeneratorRegister { public: diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index d14ea9d505a2..69388071a628 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -106,6 +106,10 @@ const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const return at::cuda::detail::getDefaultCUDAGenerator(device_index); } +Generator CUDAHooks::getNewGenerator(DeviceIndex device_index) const { + return make_generator(device_index); +} + Device CUDAHooks::getDeviceFromPtr(void* data) const { return at::cuda::getDeviceFromPtr(data); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index ea190c9e1a50..60f99b5e41ae 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -23,6 +23,8 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool isPinnedPtr(const void* data) const override; const Generator& getDefaultGenerator( DeviceIndex device_index = -1) const override; + Generator getNewGenerator( + DeviceIndex device_index = -1) const override; bool hasCUDA() const override; bool hasMAGMA() const override; bool hasCuDNN() const override; diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 
dc7bf51ad72d..9b54a84dd68d 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -74,6 +74,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { CUDA_HELP); } + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, + "Cannot get CUDA generator without ATen_cuda library. ", + CUDA_HELP); + } + Device getDeviceFromPtr(void* /*data*/) const override { TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 50e42fbe798c..01d6281e8afe 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -35,6 +35,10 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { [[maybe_unused]] DeviceIndex device_index = -1) const override { FAIL_MPSHOOKS_FUNC(__func__); } + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } virtual Allocator* getMPSDeviceAllocator() const { FAIL_MPSHOOKS_FUNC(__func__); } diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index 17927046d2e4..69819c764260 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -1,6 +1,8 @@ #pragma once +#include #include + #include #include #include @@ -11,19 +13,32 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") namespace at { struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { +#define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \ + TORCH_CHECK_NOT_IMPLEMENTED( \ + false, \ + "You should register `PrivateUse1HooksInterface` ", \ + "by `RegisterPrivateUse1HooksInterface` and implement `", \ + func, \ + "` at the same time for PrivateUse1."); + 
~PrivateUse1HooksInterface() override = default; const at::Generator& getDefaultGenerator( c10::DeviceIndex device_index) const override { - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`."); + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + // TODO(FFFrog): Preserved for BC and will be removed in the future. + if (at::GetGeneratorPrivate().has_value()) + return at::GetGeneratorForPrivateuse1(device_index); + + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); } at::Device getDeviceFromPtr(void* data) const override { - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`."); + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); } bool isPinnedPtr(const void* data) const override { @@ -31,25 +46,21 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { } Allocator* getPinnedMemoryAllocator() const override { - TORCH_CHECK( - false, - "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`."); + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); } bool hasPrimaryContext(DeviceIndex device_index) const override { - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); } void init() const override {} virtual void resizePrivateUse1Bytes( const c10::Storage& storage, size_t newsize) const { - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`."); + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); } + +#undef FAIL_PRIVATEUSE1HOOKS_FUNC }; struct TORCH_API PrivateUse1HooksArgs {}; @@ -66,4 +77,5 @@ TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks(); } // 
namespace detail } // namespace at + C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 58c0614239de..13f832051a76 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -21,6 +21,7 @@ struct MPSHooks : public at::MPSHooksInterface { // MPSGeneratorImpl interface const Generator& getDefaultGenerator( DeviceIndex device_index = -1) const override; + Generator getNewGenerator(DeviceIndex device_index = -1) const override; // MPSStream interface void deviceSynchronize() const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 9eef2267797c..03c39c957368 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -69,6 +69,10 @@ const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex devi return at::mps::detail::getDefaultMPSGenerator(); } +Generator MPSHooks::getNewGenerator([[maybe_unused]] DeviceIndex device_index) const { + return make_generator(); +} + void MPSHooks::deviceSynchronize() const { at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT); } diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index f857aecc657c..d9eadbbab084 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -1,28 +1,28 @@ -#include -#include #include #include +#include +#include +#include #include #include #include -#include -#include #include -#include -#include +#include +#include +#include #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include #include +#include +#include +#include + +#include static uint64_t add_counter = 0; static uint64_t last_saved_value = 0; @@ -551,8 +551,15 @@ bool custom_add_called() { return called; } +void set_custom_device_index(c10::DeviceIndex device_index) { + custom_device_index = 
device_index; +} + +// a global flag used for dummy pin_memory of custom device +bool custom_pinned_flag = false; + class PrivateGeneratorImpl : public at::CPUGeneratorImpl { -public: + public: // Constructors PrivateGeneratorImpl(c10::DeviceIndex device_index) { device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); @@ -561,45 +568,33 @@ public: ~PrivateGeneratorImpl() override = default; }; -// this is used to register generator -at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -void register_generator_first() { - REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) -} - -void register_generator_second() { - REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) -} - -void set_custom_device_index(c10::DeviceIndex device_index) { - custom_device_index = device_index; -} - -// a global flag used for dummy pin_memory of custom device -bool custom_pinned_flag = false; - struct FooHooksArgs : public at::PrivateUse1HooksArgs {}; struct FooHooksInterface : public at::PrivateUse1HooksInterface { - FooHooksInterface(FooHooksArgs) {} - ~FooHooksInterface() override = default; - const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override { - static auto device_gen = make_generator_privateuse1(device_index); - return device_gen; - } - // this is a simple implementation, custom_pinned_flag will be set as true - // once tensor.pin_memory() is called. And then tensor.is_pinned() - // always return true no matter what tensor it's called on. 
- bool isPinnedPtr(const void* data) const override { - return custom_pinned_flag; - } - c10::Allocator* getPinnedMemoryAllocator() const override { - custom_pinned_flag = true; - return c10::GetCPUAllocator(); - } + FooHooksInterface(FooHooksArgs) {} + ~FooHooksInterface() override = default; + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + static auto device_gen = at::make_generator(device_index); + return device_gen; + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return at::make_generator(device_index); + } + + // this is a simple implementation, custom_pinned_flag will be set as true + // once tensor.pin_memory() is called. And then tensor.is_pinned() + // always return true no matter what tensor it's called on. + bool isPinnedPtr(const void* data) const override { + return custom_pinned_flag; + } + + c10::Allocator* getPinnedMemoryAllocator() const override { + custom_pinned_flag = true; + return c10::GetCPUAllocator(); + } }; TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs); @@ -682,8 +677,6 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_device", &get_custom_device, "get custom device object"); m.def("custom_add_called", &custom_add_called, "check if our custom add function was called"); - m.def("register_generator_first", &register_generator_first, "register generator for custom device firstly"); - m.def("register_generator_second", &register_generator_second, "register generator for custom device secondly"); m.def("set_custom_device_index", &set_custom_device_index, "set custom device index"); m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method"); m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called"); diff --git a/test/inductor/extension_backends/cpp/extension_device.cpp 
b/test/inductor/extension_backends/cpp/extension_device.cpp index 243cdbb156c1..249ab3865668 100644 --- a/test/inductor/extension_backends/cpp/extension_device.cpp +++ b/test/inductor/extension_backends/cpp/extension_device.cpp @@ -1,16 +1,15 @@ -#include #include +#include -#include #include #include +#include #include -#include +#include #include #include -#include -#include +#include static uint64_t op_counter = 0; static uint64_t last_saved_value = 0; @@ -179,25 +178,6 @@ bool custom_op_called() { return called; } -class PrivateGeneratorImpl : public at::CPUGeneratorImpl { -public: - // Constructors - PrivateGeneratorImpl(c10::DeviceIndex device_index) { - device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); - key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); - } - ~PrivateGeneratorImpl() override = default; -}; - -// this is used to register generator -at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -void register_generator() { - REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) -} - // Here, we're exposing a custom device object that corresponds to our custom backend. // We do this using pybind: exposing an "extension_name.custom_device()" function in python, // that's implemented in C++. 
@@ -205,5 +185,4 @@ void register_generator() { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_device", &get_custom_device, "get custom device object"); m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); - m.def("register_generator", &register_generator, "register generator for custom device"); } diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 2c0f8b8b2a09..d427bc2653bf 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -173,23 +173,16 @@ class TestCppExtensionOpenRgistration(common.TestCase): # check generator registered before using with self.assertRaisesRegex( RuntimeError, - "Please register a generator to the PrivateUse1 dispatch key", + "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first", ): torch.Generator(device=device) - self.module.register_generator_first() + if self.module.is_register_hook() is False: + self.module.register_hook() + gen = torch.Generator(device=device) self.assertTrue(gen.device == device) - # generator can be registered only once - with self.assertRaisesRegex( - RuntimeError, - "Only can register a generator to the PrivateUse1 dispatch key once", - ): - self.module.register_generator_second() - - if self.module.is_register_hook() is False: - self.module.register_hook() default_gen = self.module.default_generator(0) self.assertTrue( default_gen.device.type == torch._C._get_privateuse1_backend_name() diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index e3181b2ce13c..ddc6e26966de 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -1,13 +1,6 @@ -#include - -#include -#include -#include - -#include -#include #include #include +#include #include #include #include @@ -15,16 +8,13 @@ #include #include +#include +#include +#include + +#include #include -#ifdef USE_CUDA 
-#include -#endif - -#ifdef USE_MPS -#include -#endif - using namespace at; using namespace torch; @@ -60,31 +50,16 @@ static PyObject* THPGenerator_pynew( auto device = r.deviceWithDefault(0, at::Device(at::kCPU)); THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0)); - if (device.type() == at::kCPU) { + + c10::DeviceType device_type = device.type(); + if (device_type == at::kCPU) { self->cdata = make_generator(); - } -#ifdef USE_CUDA - else if (device.type() == at::kCUDA) { - self->cdata = make_generator(device.index()); - } -#elif USE_MPS - else if (device.type() == at::kMPS) { - self->cdata = make_generator(); - } -#endif - else if (device.type() == at::kXPU) { - self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index()); - } else if (device.type() == at::kIPU) { - self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index()); - } else if (device.type() == at::kPrivateUse1) { - self->cdata = at::GetGeneratorForPrivateuse1(device.index()); } else { - TORCH_CHECK( - false, - "Device type ", - c10::DeviceTypeName(device.type()), - " is not supported for torch.Generator() api."); + self->cdata = globalContext() + .getAcceleratorHooksInterface(device_type) + .getNewGenerator(device.index()); } + return (PyObject*)self.release(); END_HANDLE_TH_ERRORS }