Make Context to be Device-agnostic Step by Step (3/N) (#137578)

Detailed Descriptions:
- Use a unified device-agnostic API to create new generators for accelerators.
- Add deprecation notices for GeneratorForPrivateuseone

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137578
Approved by: https://github.com/cyyever, https://github.com/ezyang
This commit is contained in:
FFFrog
2024-12-18 12:41:13 +08:00
committed by PyTorch MergeBot
parent 80a42399bb
commit f47aac6bc2
13 changed files with 128 additions and 143 deletions

View File

@ -1,6 +1,7 @@
#include <mutex>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <mutex>
namespace at {
static std::mutex _generator_mutex_lock;
@ -12,6 +13,11 @@ std::optional<GeneratorFuncType>& GetGeneratorPrivate() {
_GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
std::lock_guard<std::mutex> lock(_generator_mutex_lock);
TORCH_WARN_DEPRECATION(
"REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. \
Please derive PrivateUse1HooksInterface to implement getNewGenerator instead.")
TORCH_CHECK(
!GetGeneratorPrivate().has_value(),
"Only can register a generator to the PrivateUse1 dispatch key once!");
@ -21,6 +27,10 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
}
at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
TORCH_WARN_DEPRECATION(
"GetGeneratorForPrivateuse1() is deprecated. Please use \
globalContext().getAcceleratorHooksInterface(device_type).getNewGenerator() instead.")
TORCH_CHECK(
GetGeneratorPrivate().has_value(),
"Please register a generator to the PrivateUse1 dispatch key, \

View File

@ -7,7 +7,7 @@ namespace at {
using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
std::optional<GeneratorFuncType>& GetGeneratorPrivate();
TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
class TORCH_API _GeneratorRegister {
public:

View File

@ -106,6 +106,10 @@ const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
}
Generator CUDAHooks::getNewGenerator(DeviceIndex device_index) const {
return make_generator<at::CUDAGeneratorImpl>(device_index);
}
Device CUDAHooks::getDeviceFromPtr(void* data) const {
return at::cuda::getDeviceFromPtr(data);
}

View File

@ -23,6 +23,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool isPinnedPtr(const void* data) const override;
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
Generator getNewGenerator(
DeviceIndex device_index = -1) const override;
bool hasCUDA() const override;
bool hasMAGMA() const override;
bool hasCuDNN() const override;

View File

@ -74,6 +74,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
CUDA_HELP);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false,
"Cannot get CUDA generator without ATen_cuda library. ",
CUDA_HELP);
}
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -35,6 +35,10 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
[[maybe_unused]] DeviceIndex device_index = -1) const override {
FAIL_MPSHOOKS_FUNC(__func__);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index) const override {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual Allocator* getMPSDeviceAllocator() const {
FAIL_MPSHOOKS_FUNC(__func__);
}

View File

@ -1,6 +1,8 @@
#pragma once
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
@ -11,19 +13,32 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
#define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \
TORCH_CHECK_NOT_IMPLEMENTED( \
false, \
"You should register `PrivateUse1HooksInterface`", \
"by `RegisterPrivateUse1HooksInterface` and implement `", \
func, \
"` at the same time for PrivateUse1.");
~PrivateUse1HooksInterface() override = default;
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
// TODO(FFFrog): Preserved for BC and will be removed in the future.
if (at::GetGeneratorPrivate().has_value())
return at::GetGeneratorForPrivateuse1(device_index);
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
at::Device getDeviceFromPtr(void* data) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
bool isPinnedPtr(const void* data) const override {
@ -31,25 +46,21 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
}
Allocator* getPinnedMemoryAllocator() const override {
TORCH_CHECK(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
void init() const override {}
virtual void resizePrivateUse1Bytes(
const c10::Storage& storage,
size_t newsize) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
}
#undef FAIL_PRIVATEUSE1HOOKS_FUNC
};
struct TORCH_API PrivateUse1HooksArgs {};
@ -66,4 +77,5 @@ TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
} // namespace detail
} // namespace at
C10_DIAGNOSTIC_POP()

View File

@ -21,6 +21,7 @@ struct MPSHooks : public at::MPSHooksInterface {
// MPSGeneratorImpl interface
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
Generator getNewGenerator(DeviceIndex device_index = -1) const override;
// MPSStream interface
void deviceSynchronize() const override;

View File

@ -69,6 +69,10 @@ const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex devi
return at::mps::detail::getDefaultMPSGenerator();
}
Generator MPSHooks::getNewGenerator([[maybe_unused]] DeviceIndex device_index) const {
return make_generator<at::MPSGeneratorImpl>();
}
void MPSHooks::deviceSynchronize() const {
at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
}

View File

@ -1,28 +1,28 @@
#include <unordered_map>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/EmptyTensor.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/abs_native.h>
#include <ATen/ops/view.h>
#include <unordered_map>
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
@ -551,8 +551,15 @@ bool custom_add_called() {
return called;
}
void set_custom_device_index(c10::DeviceIndex device_index) {
custom_device_index = device_index;
}
// a global flag used for dummy pin_memory of custom device
bool custom_pinned_flag = false;
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
public:
// Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
@ -561,45 +568,33 @@ public:
~PrivateGeneratorImpl() override = default;
};
// this is used to register generator
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
void register_generator_first() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
void register_generator_second() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
void set_custom_device_index(c10::DeviceIndex device_index) {
custom_device_index = device_index;
}
// a global flag used for dummy pin_memory of custom device
bool custom_pinned_flag = false;
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
FooHooksInterface(FooHooksArgs) {}
~FooHooksInterface() override = default;
const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
static auto device_gen = make_generator_privateuse1(device_index);
return device_gen;
}
// this is a simple implementation, custom_pinned_flag will be set as true
// once tensor.pin_memory() is called. And then tensor.is_pinned()
// always return true no matter what tensor it's called on.
bool isPinnedPtr(const void* data) const override {
return custom_pinned_flag;
}
c10::Allocator* getPinnedMemoryAllocator() const override {
custom_pinned_flag = true;
return c10::GetCPUAllocator();
}
FooHooksInterface(FooHooksArgs) {}
~FooHooksInterface() override = default;
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
static auto device_gen = at::make_generator<PrivateGeneratorImpl>(device_index);
return device_gen;
}
at::Generator getNewGenerator(c10::DeviceIndex device_index) const {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
// this is a simple implementation, custom_pinned_flag will be set as true
// once tensor.pin_memory() is called. And then tensor.is_pinned()
// always return true no matter what tensor it's called on.
bool isPinnedPtr(const void* data) const override {
return custom_pinned_flag;
}
c10::Allocator* getPinnedMemoryAllocator() const override {
custom_pinned_flag = true;
return c10::GetCPUAllocator();
}
};
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
@ -682,8 +677,6 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_device", &get_custom_device, "get custom device object");
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
m.def("register_generator_first", &register_generator_first, "register generator for custom device firstly");
m.def("register_generator_second", &register_generator_second, "register generator for custom device secondly");
m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");

View File

@ -1,16 +1,15 @@
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <torch/csrc/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/Device.h>
#include <torch/extension.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/EmptyTensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/native/cpu/Loops.h>
static uint64_t op_counter = 0;
static uint64_t last_saved_value = 0;
@ -179,25 +178,6 @@ bool custom_op_called() {
return called;
}
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
// Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~PrivateGeneratorImpl() override = default;
};
// this is used to register generator
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
void register_generator() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
@ -205,5 +185,4 @@ void register_generator() {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_device", &get_custom_device, "get custom device object");
m.def("custom_op_called", &custom_op_called, "check if our custom function was called");
m.def("register_generator", &register_generator, "register generator for custom device");
}

View File

@ -173,23 +173,16 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# check generator registered before using
with self.assertRaisesRegex(
RuntimeError,
"Please register a generator to the PrivateUse1 dispatch key",
"Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first",
):
torch.Generator(device=device)
self.module.register_generator_first()
if self.module.is_register_hook() is False:
self.module.register_hook()
gen = torch.Generator(device=device)
self.assertTrue(gen.device == device)
# generator can be registered only once
with self.assertRaisesRegex(
RuntimeError,
"Only can register a generator to the PrivateUse1 dispatch key once",
):
self.module.register_generator_second()
if self.module.is_register_hook() is False:
self.module.register_hook()
default_gen = self.module.default_generator(0)
self.assertTrue(
default_gen.device.type == torch._C._get_privateuse1_backend_name()

View File

@ -1,13 +1,6 @@
#include <torch/csrc/Generator.h>
#include <ATen/ATen.h>
#include <ATen/CPUGeneratorImpl.h>
#include <structmember.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
@ -15,16 +8,13 @@
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/tensor_types.h>
#include <ATen/ATen.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <structmember.h>
#include <utility>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#ifdef USE_MPS
#include <ATen/mps/MPSGeneratorImpl.h>
#endif
using namespace at;
using namespace torch;
@ -60,31 +50,16 @@ static PyObject* THPGenerator_pynew(
auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
if (device.type() == at::kCPU) {
c10::DeviceType device_type = device.type();
if (device_type == at::kCPU) {
self->cdata = make_generator<CPUGeneratorImpl>();
}
#ifdef USE_CUDA
else if (device.type() == at::kCUDA) {
self->cdata = make_generator<CUDAGeneratorImpl>(device.index());
}
#elif USE_MPS
else if (device.type() == at::kMPS) {
self->cdata = make_generator<MPSGeneratorImpl>();
}
#endif
else if (device.type() == at::kXPU) {
self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index());
} else if (device.type() == at::kIPU) {
self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index());
} else if (device.type() == at::kPrivateUse1) {
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
} else {
TORCH_CHECK(
false,
"Device type ",
c10::DeviceTypeName(device.type()),
" is not supported for torch.Generator() api.");
self->cdata = globalContext()
.getAcceleratorHooksInterface(device_type)
.getNewGenerator(device.index());
}
return (PyObject*)self.release();
END_HANDLE_TH_ERRORS
}