#include #include #include #include #include #include #include #include namespace torch { namespace jit { TEST(CustomOperatorTest, InferredSchema) { torch::RegisterOperators reg( "foo::bar", [](double a, at::Tensor b) { return a + b; }); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); ASSERT_EQ(op->schema().name(), "foo::bar"); ASSERT_EQ(op->schema().arguments().size(), 2); ASSERT_EQ(op->schema().arguments()[0].name(), "_0"); ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); ASSERT_EQ(op->schema().arguments()[1].name(), "_1"); ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); Stack stack; push(stack, 2.0f, at::ones(5)); op->getOperation()(stack); at::Tensor output; pop(stack, output); ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); } TEST(CustomOperatorTest, ExplicitSchema) { torch::RegisterOperators reg( "foo::bar_with_schema(float a, Tensor b) -> Tensor", [](double a, at::Tensor b) { return a + b; }); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); ASSERT_EQ(op->schema().name(), "foo::bar_with_schema"); ASSERT_EQ(op->schema().arguments().size(), 2); ASSERT_EQ(op->schema().arguments()[0].name(), "a"); ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); ASSERT_EQ(op->schema().arguments()[1].name(), "b"); ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); Stack stack; push(stack, 2.0f, at::ones(5)); op->getOperation()(stack); at::Tensor output; pop(stack, output); ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); } TEST(CustomOperatorTest, ListParameters) { // Check that lists work well. torch::RegisterOperators reg( "foo::lists(int[] ints, float[] floats, complex[] complexdoubles, Tensor[] tensors) -> float[]", [](torch::List ints, torch::List floats, torch::List> complexdoubles, torch::List tensors) { return floats; }); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); ASSERT_EQ(op->schema().name(), "foo::lists"); ASSERT_EQ(op->schema().arguments().size(), 4); ASSERT_EQ(op->schema().arguments()[0].name(), "ints"); ASSERT_TRUE( op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts())); ASSERT_EQ(op->schema().arguments()[1].name(), "floats"); ASSERT_TRUE( op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats())); ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles"); ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf( *ListType::ofComplexDoubles())); ASSERT_EQ(op->schema().arguments()[3].name(), "tensors"); ASSERT_TRUE( op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors())); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_TRUE( op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats())); Stack stack; push(stack, c10::List({1, 2})); push(stack, c10::List({1.0, 2.0})); push( stack, c10::List>( {c10::complex(2.4, -5.5), c10::complex(-1.3, 2)})); push(stack, c10::List({at::ones(5)})); op->getOperation()(stack); c10::List output; pop(stack, output); ASSERT_EQ(output.size(), 2); ASSERT_EQ(output.get(0), 1.0); ASSERT_EQ(output.get(1), 2.0); } TEST(CustomOperatorTest, ListParameters2) { torch::RegisterOperators reg( "foo::lists2(Tensor[] tensors) -> Tensor[]", [](torch::List tensors) { return tensors; }); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); ASSERT_EQ(op->schema().name(), "foo::lists2"); ASSERT_EQ(op->schema().arguments().size(), 1); ASSERT_EQ(op->schema().arguments()[0].name(), "tensors"); ASSERT_TRUE( op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors())); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_TRUE( op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors())); Stack stack; push(stack, c10::List({at::ones(5)})); op->getOperation()(stack); c10::List output; pop(stack, output); ASSERT_EQ(output.size(), 1); ASSERT_TRUE(output.get(0).allclose(at::ones(5))); } TEST(CustomOperatorTest, Aliasing) { torch::RegisterOperators reg( "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor { a.add_(b); return a; }); getAllOperatorsFor(Symbol::fromQualString("foo::aliasing")); { auto graph = std::make_shared(); parseIR( R"IR( graph(%x: Tensor, %y: Tensor): %ret : Tensor = foo::aliasing(%x, %y) return (%ret) )IR", graph.get()); auto opNode = *graph->block()->nodes().begin(); AliasDb aliasDb(graph); for (const auto input : opNode->inputs()) { // The custom op writes to all its inputs ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input})); // The output should be a wildcard and thus alias all inputs ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input)); } } { // DCE should not remove a custom op auto graph = std::make_shared(); const auto text = R"IR( graph(%x: Tensor, %y: Tensor): # CHECK: foo::aliasing %ret : Tensor = foo::aliasing(%x, %y) return (%x) )IR"; parseIR(text, graph.get()); EliminateDeadCode(graph); testing::FileCheck().run(text, *graph); } } // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) static constexpr char op_list[] = "foofoo::bar.template;foo::another"; #define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \ torch::detail::SelectiveStr(n) TEST(TestCustomOperator, OperatorGeneratorUndeclared) { // Try to register an op name that does not exist in op_list. // Expected: the op name is not registered. torch::jit::RegisterOperators reg({OperatorGenerator( TORCH_SELECTIVE_NAME_IN_SCHEMA( op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"), [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; at::Tensor b; pop(stack, a, b); push(stack, a + b); }, aliasAnalysisFromSchema())}); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); ASSERT_EQ(ops.size(), 0); } TEST(TestCustomOperator, OperatorGeneratorBasic) { // The operator should be successfully registered since its name is in the // whitelist. torch::jit::RegisterOperators reg({OperatorGenerator( TORCH_SELECTIVE_NAME_IN_SCHEMA( op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"), [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; at::Tensor b; pop(stack, a, b); push(stack, a + b); }, aliasAnalysisFromSchema())}); auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); ASSERT_EQ(op->schema().name(), "foofoo::bar"); ASSERT_EQ(op->schema().arguments().size(), 2); ASSERT_EQ(op->schema().arguments()[0].name(), "a"); ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); ASSERT_EQ(op->schema().arguments()[1].name(), "b"); ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); Stack stack; push(stack, 2.0f, at::ones(5)); op->getOperation()(stack); at::Tensor output; pop(stack, output); ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); } } // namespace jit } // namespace torch