mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Port /test/cpp_extensions/rng_extension.cpp to new operator registration API (#39459)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39459 Update to this PR: this code isn't going to fully solve https://github.com/pytorch/pytorch/issues/37010. The changes required for 37010 is more than this PR initially planned. Instead, this PR switches op registration of rng related tests to use the new API (similar to what was done in #36925) Test Plan: 1) unit tests Imported from OSS Reviewed By: ezyang Differential Revision: D22264889 fbshipit-source-id: 82488ac6e3b762a756818434e22c2a0f9cb9dd47
This commit is contained in:
committed by
Facebook GitHub Bot
parent
24a8614cac
commit
47c72be3d7
@ -29,12 +29,12 @@ Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
|
||||
return a;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
|
||||
m.impl_UNBOXED("aten::empty.memory_format", empty_override);
|
||||
m.impl_UNBOXED("aten::add.Tensor", add_override);
|
||||
}
|
||||
|
||||
TEST(BackendExtensionTest, TestRegisterOp) {
|
||||
EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
|
||||
auto registry1 = torch::RegisterOperators()
|
||||
.op(torch::RegisterOperators::options()
|
||||
.schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
|
||||
.impl_unboxedOnlyKernel<decltype(empty_override), &empty_override>(DispatchKey::MSNPU));
|
||||
Tensor a = empty({5, 5}, at::kMSNPU);
|
||||
ASSERT_EQ(a.device().type(), at::kMSNPU);
|
||||
ASSERT_EQ(a.device().index(), 1);
|
||||
@ -46,11 +46,6 @@ TEST(BackendExtensionTest, TestRegisterOp) {
|
||||
ASSERT_EQ(b.device().index(), 1);
|
||||
ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
|
||||
|
||||
EXPECT_ANY_THROW(add(a, b));
|
||||
auto registry2 = torch::RegisterOperators()
|
||||
.op(torch::RegisterOperators::options()
|
||||
.schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
|
||||
.impl_unboxedOnlyKernel<decltype(add_override), &add_override>(DispatchKey::MSNPU));
|
||||
add(a, b);
|
||||
ASSERT_EQ(test_int, 2);
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/Generator.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/native/DistributionTemplates.h>
|
||||
#include <ATen/native/cpu/DistributionTemplates.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <memory>
|
||||
|
||||
using namespace at;
|
||||
@ -53,21 +53,13 @@ size_t getInstanceCount() {
|
||||
return instance_count;
|
||||
}
|
||||
|
||||
void registerOps() {
|
||||
static auto registry = torch::RegisterOperators()
|
||||
.op(torch::RegisterOperators::options()
|
||||
.schema("aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)")
|
||||
.impl_unboxedOnlyKernel<decltype(random_from_to), &random_from_to>(DispatchKey::CustomRNGKeyId))
|
||||
.op(torch::RegisterOperators::options()
|
||||
.schema("aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)")
|
||||
.impl_unboxedOnlyKernel<decltype(random_to), &random_to>(DispatchKey::CustomRNGKeyId))
|
||||
.op(torch::RegisterOperators::options()
|
||||
.schema("aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)")
|
||||
.impl_unboxedOnlyKernel<decltype(random_), &random_>(DispatchKey::CustomRNGKeyId));
|
||||
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
|
||||
m.impl_UNBOXED("aten::random_.from", random_from_to);
|
||||
m.impl_UNBOXED("aten::random_.to", random_to);
|
||||
m.impl_UNBOXED("aten::random_", random_);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("registerOps", ®isterOps);
|
||||
m.def("createTestCPUGenerator", &createTestCPUGenerator);
|
||||
m.def("getInstanceCount", &getInstanceCount);
|
||||
m.def("identity", &identity);
|
||||
|
@ -141,7 +141,6 @@ class TestRNGExtension(common.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRNGExtension, self).setUp()
|
||||
rng_extension.registerOps()
|
||||
|
||||
def test_rng(self):
|
||||
fourty_two = torch.full((10,), 42, dtype=torch.int64)
|
||||
|
Reference in New Issue
Block a user