Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Make Context to be Device-agnostic Step by Step (3/N) (#137578)
Detailed descriptions:
- Use the unified device-agnostic API to create a new generator for an accelerator.
- Add deprecation info for GeneratorForPrivateuseone.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137578
Approved by: https://github.com/cyyever, https://github.com/ezyang
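What this means for an out-of-tree (PrivateUse1) backend: instead of registering a generator factory through the now-deprecated REGISTER_GENERATOR_PRIVATEUSE1 macro, the backend exposes generator creation through its PrivateUse1HooksInterface, which the core reaches via the device-agnostic accelerator-hooks API. The sketch below only illustrates that direction; the names MyBackendGeneratorImpl, MyBackendHooks and register_my_backend_hooks are hypothetical and not part of this commit, and it assumes the existing at::RegisterPrivateUse1HooksInterface entry point named in the new error messages.

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

// Hypothetical generator for a PrivateUse1 backend, built on CPUGeneratorImpl
// the same way the test extension in this diff does.
class MyBackendGeneratorImpl : public at::CPUGeneratorImpl {
 public:
  explicit MyBackendGeneratorImpl(c10::DeviceIndex device_index) {
    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
  }
};

// Hooks object that hands out new generators; this replaces the deprecated
// REGISTER_GENERATOR_PRIVATEUSE1(...) registration.
struct MyBackendHooks : public at::PrivateUse1HooksInterface {
  at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
    return at::make_generator<MyBackendGeneratorImpl>(device_index);
  }
};

void register_my_backend_hooks() {
  // May only be registered once per process; afterwards generator creation for
  // the PrivateUse1 device resolves through these hooks.
  static MyBackendHooks hooks;
  at::RegisterPrivateUse1HooksInterface(&hooks);
}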
@@ -1,6 +1,7 @@
-#include <mutex>
 #include <ATen/core/GeneratorForPrivateuseone.h>
+
+#include <mutex>
 
 namespace at {
 
 static std::mutex _generator_mutex_lock;
@@ -12,6 +13,11 @@ std::optional<GeneratorFuncType>& GetGeneratorPrivate() {
 
 _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
   std::lock_guard<std::mutex> lock(_generator_mutex_lock);
+
+  TORCH_WARN_DEPRECATION(
+      "REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. \
+      Please derive PrivateUse1HooksInterface to implement getNewGenerator instead.")
+
   TORCH_CHECK(
       !GetGeneratorPrivate().has_value(),
       "Only can register a generator to the PrivateUse1 dispatch key once!");
@@ -21,6 +27,10 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
 }
 
 at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
+  TORCH_WARN_DEPRECATION(
+      "GetGeneratorForPrivateuse1() is deprecated. Please use \
+      globalContext().getAcceleratorHooksInterface(device_type).getNewGenerator() instead.")
+
   TORCH_CHECK(
       GetGeneratorPrivate().has_value(),
       "Please register a generator to the PrivateUse1 dispatch key, \
@@ -7,7 +7,7 @@ namespace at {
 
 using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
 
-std::optional<GeneratorFuncType>& GetGeneratorPrivate();
+TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
 
 class TORCH_API _GeneratorRegister {
  public:
@@ -106,6 +106,10 @@ const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const
   return at::cuda::detail::getDefaultCUDAGenerator(device_index);
 }
 
+Generator CUDAHooks::getNewGenerator(DeviceIndex device_index) const {
+  return make_generator<at::CUDAGeneratorImpl>(device_index);
+}
+
 Device CUDAHooks::getDeviceFromPtr(void* data) const {
   return at::cuda::getDeviceFromPtr(data);
 }
@@ -23,6 +23,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool isPinnedPtr(const void* data) const override;
   const Generator& getDefaultGenerator(
       DeviceIndex device_index = -1) const override;
+  Generator getNewGenerator(
+      DeviceIndex device_index = -1) const override;
   bool hasCUDA() const override;
   bool hasMAGMA() const override;
   bool hasCuDNN() const override;
@@ -74,6 +74,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
         CUDA_HELP);
   }
 
+  Generator getNewGenerator(
+      [[maybe_unused]] DeviceIndex device_index = -1) const override {
+    TORCH_CHECK(
+        false,
+        "Cannot get CUDA generator without ATen_cuda library. ",
+        CUDA_HELP);
+  }
+
   Device getDeviceFromPtr(void* /*data*/) const override {
     TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
   }
@@ -35,6 +35,10 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
       [[maybe_unused]] DeviceIndex device_index = -1) const override {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
+  Generator getNewGenerator(
+      [[maybe_unused]] DeviceIndex device_index) const override {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
   virtual Allocator* getMPSDeviceAllocator() const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
@@ -1,6 +1,8 @@
 #pragma once
 
+#include <ATen/core/GeneratorForPrivateuseone.h>
 #include <ATen/detail/AcceleratorHooksInterface.h>
+
 #include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
 #include <c10/core/Storage.h>
@@ -11,19 +13,32 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
 namespace at {
 
 struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
+#define FAIL_PRIVATEUSE1HOOKS_FUNC(func)                          \
+  TORCH_CHECK_NOT_IMPLEMENTED(                                    \
+      false,                                                      \
+      "You should register `PrivateUse1HooksInterface`",          \
+      "by `RegisterPrivateUse1HooksInterface` and implement `",   \
+      func,                                                       \
+      "` at the same time for PrivateUse1.");
+
   ~PrivateUse1HooksInterface() override = default;
 
   const at::Generator& getDefaultGenerator(
       c10::DeviceIndex device_index) const override {
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false,
-        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
   }
 
+  Generator getNewGenerator(
+      [[maybe_unused]] DeviceIndex device_index = -1) const override {
+    // TODO(FFFrog): Preserved for BC and will be removed in the future.
+    if (at::GetGeneratorPrivate().has_value())
+      return at::GetGeneratorForPrivateuse1(device_index);
+
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
+  }
+
   at::Device getDeviceFromPtr(void* data) const override {
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false,
-        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
   }
 
   bool isPinnedPtr(const void* data) const override {
@@ -31,25 +46,21 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
   }
 
   Allocator* getPinnedMemoryAllocator() const override {
-    TORCH_CHECK(
-        false,
-        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
   }
 
   bool hasPrimaryContext(DeviceIndex device_index) const override {
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false,
-        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
   }
 
   void init() const override {}
   virtual void resizePrivateUse1Bytes(
       const c10::Storage& storage,
       size_t newsize) const {
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false,
-        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
+    FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
   }
+
+#undef FAIL_PRIVATEUSE1HOOKS_FUNC
 };
 
 struct TORCH_API PrivateUse1HooksArgs {};
@@ -66,4 +77,5 @@ TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
 } // namespace detail
 
 } // namespace at
+
 C10_DIAGNOSTIC_POP()
@@ -21,6 +21,7 @@ struct MPSHooks : public at::MPSHooksInterface {
   // MPSGeneratorImpl interface
   const Generator& getDefaultGenerator(
       DeviceIndex device_index = -1) const override;
+  Generator getNewGenerator(DeviceIndex device_index = -1) const override;
 
   // MPSStream interface
   void deviceSynchronize() const override;
@@ -69,6 +69,10 @@ const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex devi
   return at::mps::detail::getDefaultMPSGenerator();
 }
 
+Generator MPSHooks::getNewGenerator([[maybe_unused]] DeviceIndex device_index) const {
+  return make_generator<at::MPSGeneratorImpl>();
+}
+
 void MPSHooks::deviceSynchronize() const {
   at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
 }
@@ -1,28 +1,28 @@
#include <unordered_map>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/EmptyTensor.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/abs_native.h>
#include <ATen/ops/view.h>

#include <unordered_map>

static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
@@ -551,8 +551,15 @@ bool custom_add_called() {
   return called;
 }
 
+void set_custom_device_index(c10::DeviceIndex device_index) {
+  custom_device_index = device_index;
+}
+
+// a global flag used for dummy pin_memory of custom device
+bool custom_pinned_flag = false;
+
 class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
-public:
+ public:
   // Constructors
   PrivateGeneratorImpl(c10::DeviceIndex device_index) {
     device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
@@ -561,45 +568,33 @@ public:
   ~PrivateGeneratorImpl() override = default;
 };
 
-// this is used to register generator
-at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
-  return at::make_generator<PrivateGeneratorImpl>(device_index);
-}
-
-void register_generator_first() {
-  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
-}
-
-void register_generator_second() {
-  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
-}
-
-void set_custom_device_index(c10::DeviceIndex device_index) {
-  custom_device_index = device_index;
-}
-
-// a global flag used for dummy pin_memory of custom device
-bool custom_pinned_flag = false;
-
 struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
 
 struct FooHooksInterface : public at::PrivateUse1HooksInterface {
-  FooHooksInterface(FooHooksArgs) {}
-  ~FooHooksInterface() override = default;
-  const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
-    static auto device_gen = make_generator_privateuse1(device_index);
-    return device_gen;
-  }
-  // this is a simple implementation, custom_pinned_flag will be set as true
-  // once tensor.pin_memory() is called. And then tensor.is_pinned()
-  // always return true no matter what tensor it's called on.
-  bool isPinnedPtr(const void* data) const override {
-    return custom_pinned_flag;
-  }
-  c10::Allocator* getPinnedMemoryAllocator() const override {
-    custom_pinned_flag = true;
-    return c10::GetCPUAllocator();
-  }
+  FooHooksInterface(FooHooksArgs) {}
+  ~FooHooksInterface() override = default;
+
+  const at::Generator& getDefaultGenerator(
+      c10::DeviceIndex device_index) const override {
+    static auto device_gen = at::make_generator<PrivateGeneratorImpl>(device_index);
+    return device_gen;
+  }
+
+  at::Generator getNewGenerator(c10::DeviceIndex device_index) const {
+    return at::make_generator<PrivateGeneratorImpl>(device_index);
+  }
+
+  // this is a simple implementation, custom_pinned_flag will be set as true
+  // once tensor.pin_memory() is called. And then tensor.is_pinned()
+  // always return true no matter what tensor it's called on.
+  bool isPinnedPtr(const void* data) const override {
+    return custom_pinned_flag;
+  }
+
+  c10::Allocator* getPinnedMemoryAllocator() const override {
+    custom_pinned_flag = true;
+    return c10::GetCPUAllocator();
+  }
 };
 
 TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
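The Python test further below drives this extension through its register_hook()/is_register_hook() bindings, whose bodies are outside the hunks shown here. As a rough sketch only, such a registration helper for the FooHooksInterface above could look like the following, assuming the at::RegisterPrivateUse1HooksInterface entry point named in the new error messages (the real test file may differ):

static bool hook_registered = false;

bool is_register_hook() {
  return hook_registered;
}

void register_hook() {
  // Hand the FooHooksInterface defined above to ATen; from then on the core's
  // accelerator-hooks lookup for PrivateUse1 resolves to it, and generator
  // creation no longer depends on the deprecated REGISTER_GENERATOR_PRIVATEUSE1 path.
  static FooHooksInterface hooks{FooHooksArgs{}};
  at::RegisterPrivateUse1HooksInterface(&hooks);
  hook_registered = true;
}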
@@ -682,8 +677,6 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("custom_device", &get_custom_device, "get custom device object");
   m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
-  m.def("register_generator_first", &register_generator_first, "register generator for custom device firstly");
-  m.def("register_generator_second", &register_generator_second, "register generator for custom device secondly");
   m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
   m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
   m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
@@ -1,16 +1,15 @@
-#include <c10/core/impl/alloc_cpu.h>
 #include <c10/core/Allocator.h>
+#include <c10/core/impl/alloc_cpu.h>
 
-#include <torch/csrc/Device.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <c10/macros/Macros.h>
+#include <torch/csrc/Device.h>
 #include <torch/extension.h>
 
-#include <ATen/native/cpu/Loops.h>
+#include <ATen/EmptyTensor.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/Resize.h>
-#include <ATen/EmptyTensor.h>
-#include <ATen/core/GeneratorForPrivateuseone.h>
+#include <ATen/native/cpu/Loops.h>
 
 static uint64_t op_counter = 0;
 static uint64_t last_saved_value = 0;
@@ -179,25 +178,6 @@ bool custom_op_called() {
   return called;
 }
 
-class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
-public:
-  // Constructors
-  PrivateGeneratorImpl(c10::DeviceIndex device_index) {
-    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
-    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
-  }
-  ~PrivateGeneratorImpl() override = default;
-};
-
-// this is used to register generator
-at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
-  return at::make_generator<PrivateGeneratorImpl>(device_index);
-}
-
-void register_generator() {
-  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
-}
-
 // Here, we're exposing a custom device object that corresponds to our custom backend.
 // We do this using pybind: exposing an "extension_name.custom_device()" function in python,
 // that's implemented in C++.
@@ -205,5 +185,4 @@ void register_generator() {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("custom_device", &get_custom_device, "get custom device object");
   m.def("custom_op_called", &custom_op_called, "check if our custom function was called");
-  m.def("register_generator", &register_generator, "register generator for custom device");
 }
@@ -173,23 +173,16 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # check generator registered before using
         with self.assertRaisesRegex(
             RuntimeError,
-            "Please register a generator to the PrivateUse1 dispatch key",
+            "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first",
         ):
             torch.Generator(device=device)
 
-        self.module.register_generator_first()
+        if self.module.is_register_hook() is False:
+            self.module.register_hook()
+
         gen = torch.Generator(device=device)
         self.assertTrue(gen.device == device)
 
-        # generator can be registered only once
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Only can register a generator to the PrivateUse1 dispatch key once",
-        ):
-            self.module.register_generator_second()
-
-        if self.module.is_register_hook() is False:
-            self.module.register_hook()
         default_gen = self.module.default_generator(0)
         self.assertTrue(
             default_gen.device.type == torch._C._get_privateuse1_backend_name()
@@ -1,13 +1,6 @@
#include <torch/csrc/Generator.h>

#include <ATen/ATen.h>
#include <ATen/CPUGeneratorImpl.h>
#include <structmember.h>

#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
@@ -15,16 +8,13 @@
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/tensor_types.h>

#include <ATen/ATen.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/detail/XPUHooksInterface.h>

#include <structmember.h>
#include <utility>

#ifdef USE_CUDA
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#ifdef USE_MPS
#include <ATen/mps/MPSGeneratorImpl.h>
#endif

using namespace at;
using namespace torch;
@@ -60,31 +50,16 @@ static PyObject* THPGenerator_pynew(
   auto device = r.deviceWithDefault(0, at::Device(at::kCPU));
+
   THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0));
-  if (device.type() == at::kCPU) {
+
+  c10::DeviceType device_type = device.type();
+  if (device_type == at::kCPU) {
     self->cdata = make_generator<CPUGeneratorImpl>();
-  }
-#ifdef USE_CUDA
-  else if (device.type() == at::kCUDA) {
-    self->cdata = make_generator<CUDAGeneratorImpl>(device.index());
-  }
-#elif USE_MPS
-  else if (device.type() == at::kMPS) {
-    self->cdata = make_generator<MPSGeneratorImpl>();
-  }
-#endif
-  else if (device.type() == at::kXPU) {
-    self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index());
-  } else if (device.type() == at::kIPU) {
-    self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index());
-  } else if (device.type() == at::kPrivateUse1) {
-    self->cdata = at::GetGeneratorForPrivateuse1(device.index());
   } else {
-    TORCH_CHECK(
-        false,
-        "Device type ",
-        c10::DeviceTypeName(device.type()),
-        " is not supported for torch.Generator() api.");
+    self->cdata = globalContext()
+                      .getAcceleratorHooksInterface(device_type)
+                      .getNewGenerator(device.index());
   }
 
   return (PyObject*)self.release();
   END_HANDLE_TH_ERRORS
 }
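The last hunk is the user-visible payoff: THPGenerator_pynew, the C++ entry point behind torch.Generator(device=...), no longer special-cases CUDA, MPS, XPU, IPU and PrivateUse1, and instead asks the per-device accelerator hooks for a new generator. A small standalone helper following the same pattern (illustrative only, not part of the patch):

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Context.h>
#include <ATen/core/Generator.h>

at::Generator new_generator_for(const at::Device& device) {
  if (device.type() == at::kCPU) {
    return at::make_generator<at::CPUGeneratorImpl>();
  }
  // Every registered accelerator (CUDA, MPS, XPU, PrivateUse1, ...) is reached
  // through its hooks interface, mirroring the new THPGenerator_pynew body.
  return at::globalContext()
      .getAcceleratorHooksInterface(device.type())
      .getNewGenerator(device.index());
}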