Port /test/cpp_extensions/rng_extension.cpp to new operator registration API (#39459)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39459

Update to this PR: this code isn't going to fully solve https://github.com/pytorch/pytorch/issues/37010. The changes required for 37010 is more than this PR initially planned. Instead, this PR switches op registration of rng related tests to use the new API (similar to what was done in #36925)

Test Plan:
1) unit tests

Imported from OSS

Reviewed By: ezyang

Differential Revision: D22264889

fbshipit-source-id: 82488ac6e3b762a756818434e22c2a0f9cb9dd47
This commit is contained in:
Changji Shi
2020-06-26 16:10:25 -07:00
committed by Facebook GitHub Bot
parent 24a8614cac
commit 47c72be3d7
3 changed files with 10 additions and 24 deletions

View File

@ -29,12 +29,12 @@ Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
return a;
}
TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
m.impl_UNBOXED("aten::empty.memory_format", empty_override);
m.impl_UNBOXED("aten::add.Tensor", add_override);
}
TEST(BackendExtensionTest, TestRegisterOp) {
EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
auto registry1 = torch::RegisterOperators()
.op(torch::RegisterOperators::options()
.schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
.impl_unboxedOnlyKernel<decltype(empty_override), &empty_override>(DispatchKey::MSNPU));
Tensor a = empty({5, 5}, at::kMSNPU);
ASSERT_EQ(a.device().type(), at::kMSNPU);
ASSERT_EQ(a.device().index(), 1);
@ -46,11 +46,6 @@ TEST(BackendExtensionTest, TestRegisterOp) {
ASSERT_EQ(b.device().index(), 1);
ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
EXPECT_ANY_THROW(add(a, b));
auto registry2 = torch::RegisterOperators()
.op(torch::RegisterOperators::options()
.schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
.impl_unboxedOnlyKernel<decltype(add_override), &add_override>(DispatchKey::MSNPU));
add(a, b);
ASSERT_EQ(test_int, 2);

View File

@ -1,9 +1,9 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/Generator.h>
#include <ATen/Tensor.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/cpu/DistributionTemplates.h>
#include <ATen/core/op_registration/op_registration.h>
#include <memory>
using namespace at;
@ -53,21 +53,13 @@ size_t getInstanceCount() {
return instance_count;
}
void registerOps() {
static auto registry = torch::RegisterOperators()
.op(torch::RegisterOperators::options()
.schema("aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)")
.impl_unboxedOnlyKernel<decltype(random_from_to), &random_from_to>(DispatchKey::CustomRNGKeyId))
.op(torch::RegisterOperators::options()
.schema("aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)")
.impl_unboxedOnlyKernel<decltype(random_to), &random_to>(DispatchKey::CustomRNGKeyId))
.op(torch::RegisterOperators::options()
.schema("aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)")
.impl_unboxedOnlyKernel<decltype(random_), &random_>(DispatchKey::CustomRNGKeyId));
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
m.impl_UNBOXED("aten::random_.from", random_from_to);
m.impl_UNBOXED("aten::random_.to", random_to);
m.impl_UNBOXED("aten::random_", random_);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("registerOps", &registerOps);
m.def("createTestCPUGenerator", &createTestCPUGenerator);
m.def("getInstanceCount", &getInstanceCount);
m.def("identity", &identity);

View File

@ -141,7 +141,6 @@ class TestRNGExtension(common.TestCase):
def setUp(self):
super(TestRNGExtension, self).setUp()
rng_extension.registerOps()
def test_rng(self):
fourty_two = torch.full((10,), 42, dtype=torch.int64)