From 8226330af3bdf6e8a1cb9dcb0887ad189f9630e7 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Wed, 15 May 2019 14:47:13 -0700 Subject: [PATCH] Extend testAvailableArgTypes (#20374) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20374 This test case now also tests that the argument type works correctly in kernels that - don't return outputs - return multiple outputs Reviewed By: li-roy Differential Revision: D15298233 fbshipit-source-id: 82ab9d81b55b4f9fb34d66a155cc426af8592e25 --- .../op_registration/op_registration_test.cpp | 114 +++++++++++------- 1 file changed, 71 insertions(+), 43 deletions(-) diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 04bc044dd30c..b44838a78eb3 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -9,6 +9,7 @@ #include #include +#include using c10::RegisterOperators; using c10::OperatorKernel; @@ -414,151 +415,178 @@ struct ArgTypeTestKernel final : OperatorKernel { return output_; } - static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema) { - // test with explicitly specified schema - test_(input, inputExpectation, output, outputExpectation, schema); - - // test with inferred schema - test_(input, inputExpectation, output, outputExpectation, ""); - } - -private: - static void test_(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema = "") { + static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema) { auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel(input, std::move(inputExpectation), std::move(output))); auto op = Dispatcher::singleton().findSchema("_test::my_op", ""); ASSERT_TRUE(op.has_value()); // assert schema is registered auto actualOutput = callOp(*op, std::move(input)); - EXPECT_EQ(1, actualOutput.size()); - outputExpectation(actualOutput[0]); + outputExpectation(actualOutput); } +private: + InputType input_; std::function inputExpectation_; OutputType output_; std::string schema_; }; +template +struct testArgTypes final { + static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema) { + // Test with explicitly specified schema + ArgTypeTestKernel::test( + input, inputExpectation, output, [&] (const c10::Stack& output) { + EXPECT_EQ(1, output.size()); + outputExpectation(output[0]); + }, schema + ); + + // Test with inferred schema + ArgTypeTestKernel::test( + input, inputExpectation, output, [&] (const c10::Stack& output) { + EXPECT_EQ(1, output.size()); + outputExpectation(output[0]); + }, "" + ); + + // Test taking argument and returning nothing + ArgTypeTestKernel>::test( + input, inputExpectation, {}, [] (const c10::Stack&) {}, "" + ); + + // Test taking argument and returning multiple outputs + ArgTypeTestKernel>::test( + input, inputExpectation, std::tuple{3, output}, [&] (const c10::Stack& output) { + EXPECT_EQ(2, output.size()); + EXPECT_EQ(3, output[0].toInt()); + outputExpectation(output[1]); + }, "" + ); + } +}; + TEST(OperatorRegistrationTest, testAvailableArgTypes) { // TODO Test Scalar // primitive types - ArgTypeTestKernel::test( + testArgTypes::test( 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, "(float a) -> float"); - ArgTypeTestKernel::test( + testArgTypes::test( 1, [] (const int64_t& v) {EXPECT_EQ(1, v);}, 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, "(int a) -> int"); - ArgTypeTestKernel::test( + testArgTypes::test( true, [] (const bool& v) {EXPECT_EQ(true, v);}, false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, "(bool a) -> bool"); - ArgTypeTestKernel::test( + testArgTypes::test( false, [] (const bool& v) {EXPECT_EQ(false, v);}, true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, "(bool a) -> bool"); - ArgTypeTestKernel::test( + testArgTypes::test( "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);}, "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str a) -> str"); - ArgTypeTestKernel::test( + testArgTypes::test( dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());}, dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, "(Tensor a) -> Tensor"); // optional types (with has_value() == true) - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(1.5), [] (const c10::optional& v) {EXPECT_EQ(1.5, v.value());}, c10::optional(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, "(float? a) -> float?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(1), [] (const c10::optional& v) {EXPECT_EQ(1, v.value());}, c10::optional(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, "(int? a) -> int?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(true), [] (const c10::optional& v) {EXPECT_EQ(true, v.value());}, c10::optional(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, "(bool? a) -> bool?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(false), [] (const c10::optional& v) {EXPECT_EQ(false, v.value());}, c10::optional(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, "(bool? a) -> bool?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional("string1"), [] (const c10::optional& v) {EXPECT_EQ("string1", v.value());}, c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str? a) -> str?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(dummyTensor(TensorType1())), [] (const c10::optional& v) {EXPECT_EQ(TensorType1(), v.value().type_id());}, c10::optional(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, "(Tensor? a) -> Tensor?"); // optional types (with has_value() == false) - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(float? a) -> float?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(int? a) -> int?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(str? a) -> str?"); - ArgTypeTestKernel>::test( + testArgTypes>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(Tensor? a) -> Tensor?"); // list types (with empty list) - ArgTypeTestKernel, std::vector>::test( + testArgTypes, std::vector>::test( c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());}, "(float[] a) -> float[]"); - ArgTypeTestKernel, std::vector>::test( + testArgTypes, std::vector>::test( c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}, "(int[] a) -> int[]"); // TODO Converting std::vector to ArrayRef doesn't work, so we // need to find an alternative - // ArgTypeTestKernel, std::vector>::test( + // testArgTypes, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, // "(bool[] a) -> bool[]"); - // ArgTypeTestKernel, std::vector>::test( + // testArgTypes, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, // "(bool[] a) -> bool[]"); // TODO We currently don't support str[] (i.e. string list) as type. Do we want to? - // ArgTypeTestKernel, std::vector>::test( + // testArgTypes, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());}, // "(str[] a) -> str[]"); // list types (with non-empty list) - ArgTypeTestKernel, std::vector>::test( + testArgTypes, std::vector>::test( c10::ArrayRef({1.5, 2.5}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1.5, 2.5}), v);}, std::vector({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector({3.5, 4.5}), v.toDoubleListRef());}, "(float[] a) -> float[]"); - ArgTypeTestKernel, std::vector>::test( + testArgTypes, std::vector>::test( c10::ArrayRef({1, 2}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v);}, std::vector({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}, "(int[] a) -> int[]"); // TODO When fixing bool[] and str[] (see above), also add them here - ArgTypeTestKernel, std::vector>::test( + testArgTypes, std::vector>::test( c10::ArrayRef({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef v) { EXPECT_EQ(2, v.size()); EXPECT_EQ(TensorType1(), v[0].type_id()); @@ -572,19 +600,19 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { "(Tensor[] a) -> Tensor[]"); // Test optional of list (with nullopt) - ArgTypeTestKernel>, c10::optional>>::test( + testArgTypes>, c10::optional>>::test( c10::optional>(c10::nullopt), [] (c10::optional> v) {EXPECT_FALSE(v.has_value());}, c10::optional>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(int[]? a) -> int[]?"); // Test optional of list (with empty list) - ArgTypeTestKernel>, c10::optional>>::test( + testArgTypes>, c10::optional>>::test( c10::optional>(c10::ArrayRef{}), [] (c10::optional> v) {EXPECT_EQ(0, v.value().size());}, c10::optional>(std::vector{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}, "(int[]? a) -> int[]?"); // Test optional of list (with values) - ArgTypeTestKernel>, c10::optional>>::test( + testArgTypes>, c10::optional>>::test( c10::optional>({1, 2}), [] (c10::optional> v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v.value());}, c10::optional>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}, "(int[]? a) -> int[]?"); @@ -595,7 +623,7 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::Dict str_dict; str_dict.insert("key1", "value1"); str_dict.insert("key2", "value2"); - ArgTypeTestKernel>::test( + testArgTypes>::test( str_dict, [] (c10::Dict v) { EXPECT_EQ(2, v.size()); EXPECT_EQ("value1", v.at("key1")); @@ -611,7 +639,7 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::Dict tensor_dict; tensor_dict.insert(1, dummyTensor(TensorType1())); tensor_dict.insert(2, dummyTensor(TensorType2())); - ArgTypeTestKernel>::test( + testArgTypes>::test( tensor_dict, [] (c10::Dict v) { EXPECT_EQ(2, v.size()); EXPECT_EQ(TensorType1(), v.at(1).type_id());