mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Extend testAvailableArgTypes (#20374)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20374 This test case now also tests that the argument type works correctly in kernels that - don't return outputs - return multiple outputs Reviewed By: li-roy Differential Revision: D15298233 fbshipit-source-id: 82ab9d81b55b4f9fb34d66a155cc426af8592e25
This commit is contained in:
committed by
Facebook Github Bot
parent
f89ab7b623
commit
8226330af3
@ -9,6 +9,7 @@
|
||||
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <functional>
|
||||
|
||||
using c10::RegisterOperators;
|
||||
using c10::OperatorKernel;
|
||||
@ -414,151 +415,178 @@ struct ArgTypeTestKernel final : OperatorKernel {
|
||||
return output_;
|
||||
}
|
||||
|
||||
static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema) {
|
||||
// test with explicitly specified schema
|
||||
test_(input, inputExpectation, output, outputExpectation, schema);
|
||||
|
||||
// test with inferred schema
|
||||
test_(input, inputExpectation, output, outputExpectation, "");
|
||||
}
|
||||
|
||||
private:
|
||||
static void test_(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema = "") {
|
||||
static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
|
||||
auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel<ArgTypeTestKernel>(input, std::move(inputExpectation), std::move(output)));
|
||||
auto op = Dispatcher::singleton().findSchema("_test::my_op", "");
|
||||
ASSERT_TRUE(op.has_value()); // assert schema is registered
|
||||
auto actualOutput = callOp(*op, std::move(input));
|
||||
EXPECT_EQ(1, actualOutput.size());
|
||||
outputExpectation(actualOutput[0]);
|
||||
outputExpectation(actualOutput);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
InputType input_;
|
||||
std::function<void(const InputType&)> inputExpectation_;
|
||||
OutputType output_;
|
||||
std::string schema_;
|
||||
};
|
||||
|
||||
template<class InputType, class OutputType = InputType>
|
||||
struct testArgTypes final {
|
||||
static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema) {
|
||||
// Test with explicitly specified schema
|
||||
ArgTypeTestKernel<InputType, OutputType>::test(
|
||||
input, inputExpectation, output, [&] (const c10::Stack& output) {
|
||||
EXPECT_EQ(1, output.size());
|
||||
outputExpectation(output[0]);
|
||||
}, schema
|
||||
);
|
||||
|
||||
// Test with inferred schema
|
||||
ArgTypeTestKernel<InputType, OutputType>::test(
|
||||
input, inputExpectation, output, [&] (const c10::Stack& output) {
|
||||
EXPECT_EQ(1, output.size());
|
||||
outputExpectation(output[0]);
|
||||
}, ""
|
||||
);
|
||||
|
||||
// Test taking argument and returning nothing
|
||||
ArgTypeTestKernel<InputType, std::tuple<>>::test(
|
||||
input, inputExpectation, {}, [] (const c10::Stack&) {}, ""
|
||||
);
|
||||
|
||||
// Test taking argument and returning multiple outputs
|
||||
ArgTypeTestKernel<InputType, std::tuple<int64_t, OutputType>>::test(
|
||||
input, inputExpectation, std::tuple<int64_t, OutputType>{3, output}, [&] (const c10::Stack& output) {
|
||||
EXPECT_EQ(2, output.size());
|
||||
EXPECT_EQ(3, output[0].toInt());
|
||||
outputExpectation(output[1]);
|
||||
}, ""
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
||||
// TODO Test Scalar
|
||||
|
||||
// primitive types
|
||||
ArgTypeTestKernel<double>::test(
|
||||
testArgTypes<double>::test(
|
||||
1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
|
||||
2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
|
||||
"(float a) -> float");
|
||||
ArgTypeTestKernel<int64_t>::test(
|
||||
testArgTypes<int64_t>::test(
|
||||
1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
|
||||
2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
|
||||
"(int a) -> int");
|
||||
ArgTypeTestKernel<bool>::test(
|
||||
testArgTypes<bool>::test(
|
||||
true, [] (const bool& v) {EXPECT_EQ(true, v);},
|
||||
false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
|
||||
"(bool a) -> bool");
|
||||
ArgTypeTestKernel<bool>::test(
|
||||
testArgTypes<bool>::test(
|
||||
false, [] (const bool& v) {EXPECT_EQ(false, v);},
|
||||
true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
|
||||
"(bool a) -> bool");
|
||||
ArgTypeTestKernel<std::string>::test(
|
||||
testArgTypes<std::string>::test(
|
||||
"string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
|
||||
"string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
|
||||
"(str a) -> str");
|
||||
ArgTypeTestKernel<Tensor>::test(
|
||||
testArgTypes<Tensor>::test(
|
||||
dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());},
|
||||
dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
|
||||
"(Tensor a) -> Tensor");
|
||||
|
||||
|
||||
// optional types (with has_value() == true)
|
||||
ArgTypeTestKernel<c10::optional<double>>::test(
|
||||
testArgTypes<c10::optional<double>>::test(
|
||||
c10::optional<double>(1.5), [] (const c10::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
|
||||
c10::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
|
||||
"(float? a) -> float?");
|
||||
ArgTypeTestKernel<c10::optional<int64_t>>::test(
|
||||
testArgTypes<c10::optional<int64_t>>::test(
|
||||
c10::optional<int64_t>(1), [] (const c10::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
|
||||
c10::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
|
||||
"(int? a) -> int?");
|
||||
ArgTypeTestKernel<c10::optional<bool>>::test(
|
||||
testArgTypes<c10::optional<bool>>::test(
|
||||
c10::optional<bool>(true), [] (const c10::optional<bool>& v) {EXPECT_EQ(true, v.value());},
|
||||
c10::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
|
||||
"(bool? a) -> bool?");
|
||||
ArgTypeTestKernel<c10::optional<bool>>::test(
|
||||
testArgTypes<c10::optional<bool>>::test(
|
||||
c10::optional<bool>(false), [] (const c10::optional<bool>& v) {EXPECT_EQ(false, v.value());},
|
||||
c10::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
|
||||
"(bool? a) -> bool?");
|
||||
ArgTypeTestKernel<c10::optional<std::string>>::test(
|
||||
testArgTypes<c10::optional<std::string>>::test(
|
||||
c10::optional<std::string>("string1"), [] (const c10::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
|
||||
c10::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
|
||||
"(str? a) -> str?");
|
||||
ArgTypeTestKernel<c10::optional<Tensor>>::test(
|
||||
testArgTypes<c10::optional<Tensor>>::test(
|
||||
c10::optional<Tensor>(dummyTensor(TensorType1())), [] (const c10::optional<Tensor>& v) {EXPECT_EQ(TensorType1(), v.value().type_id());},
|
||||
c10::optional<Tensor>(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
|
||||
"(Tensor? a) -> Tensor?");
|
||||
|
||||
|
||||
// optional types (with has_value() == false)
|
||||
ArgTypeTestKernel<c10::optional<double>>::test(
|
||||
testArgTypes<c10::optional<double>>::test(
|
||||
c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(float? a) -> float?");
|
||||
ArgTypeTestKernel<c10::optional<int64_t>>::test(
|
||||
testArgTypes<c10::optional<int64_t>>::test(
|
||||
c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(int? a) -> int?");
|
||||
ArgTypeTestKernel<c10::optional<bool>>::test(
|
||||
testArgTypes<c10::optional<bool>>::test(
|
||||
c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(bool? a) -> bool?");
|
||||
ArgTypeTestKernel<c10::optional<bool>>::test(
|
||||
testArgTypes<c10::optional<bool>>::test(
|
||||
c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(bool? a) -> bool?");
|
||||
ArgTypeTestKernel<c10::optional<std::string>>::test(
|
||||
testArgTypes<c10::optional<std::string>>::test(
|
||||
c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(str? a) -> str?");
|
||||
ArgTypeTestKernel<c10::optional<Tensor>>::test(
|
||||
testArgTypes<c10::optional<Tensor>>::test(
|
||||
c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(Tensor? a) -> Tensor?");
|
||||
|
||||
|
||||
// list types (with empty list)
|
||||
ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
|
||||
testArgTypes<c10::ArrayRef<double>, std::vector<double>>::test(
|
||||
c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
|
||||
std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());},
|
||||
"(float[] a) -> float[]");
|
||||
ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
|
||||
testArgTypes<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
|
||||
c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
|
||||
std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());},
|
||||
"(int[] a) -> int[]");
|
||||
// TODO Converting std::vector<bool> to ArrayRef<bool> doesn't work, so we
|
||||
// need to find an alternative
|
||||
// ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
|
||||
// testArgTypes<c10::ArrayRef<bool>, std::vector<bool>>::test(
|
||||
// c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
|
||||
// std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
|
||||
// "(bool[] a) -> bool[]");
|
||||
// ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
|
||||
// testArgTypes<c10::ArrayRef<bool>, std::vector<bool>>::test(
|
||||
// c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
|
||||
// std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
|
||||
// "(bool[] a) -> bool[]");
|
||||
// TODO We currently don't support str[] (i.e. string list) as type. Do we want to?
|
||||
// ArgTypeTestKernel<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
|
||||
// testArgTypes<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
|
||||
// c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
|
||||
// std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());},
|
||||
// "(str[] a) -> str[]");
|
||||
|
||||
|
||||
// list types (with non-empty list)
|
||||
ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
|
||||
testArgTypes<c10::ArrayRef<double>, std::vector<double>>::test(
|
||||
c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {EXPECT_EQ(c10::ArrayRef<double>({1.5, 2.5}), v);},
|
||||
std::vector<double>({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector<double>({3.5, 4.5}), v.toDoubleListRef());},
|
||||
"(float[] a) -> float[]");
|
||||
ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
|
||||
testArgTypes<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
|
||||
c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v);},
|
||||
std::vector<int64_t>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());},
|
||||
"(int[] a) -> int[]");
|
||||
// TODO When fixing bool[] and str[] (see above), also add them here
|
||||
ArgTypeTestKernel<c10::ArrayRef<Tensor>, std::vector<Tensor>>::test(
|
||||
testArgTypes<c10::ArrayRef<Tensor>, std::vector<Tensor>>::test(
|
||||
c10::ArrayRef<Tensor>({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef<Tensor> v) {
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ(TensorType1(), v[0].type_id());
|
||||
@ -572,19 +600,19 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
||||
"(Tensor[] a) -> Tensor[]");
|
||||
|
||||
// Test optional of list (with nullopt)
|
||||
ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
testArgTypes<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
c10::optional<c10::ArrayRef<int64_t>>(c10::nullopt), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_FALSE(v.has_value());},
|
||||
c10::optional<std::vector<int64_t>>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
|
||||
"(int[]? a) -> int[]?");
|
||||
|
||||
// Test optional of list (with empty list)
|
||||
ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
testArgTypes<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
c10::optional<c10::ArrayRef<int64_t>>(c10::ArrayRef<int64_t>{}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(0, v.value().size());},
|
||||
c10::optional<std::vector<int64_t>>(std::vector<int64_t>{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());},
|
||||
"(int[]? a) -> int[]?");
|
||||
|
||||
// Test optional of list (with values)
|
||||
ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
testArgTypes<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
|
||||
c10::optional<c10::ArrayRef<int64_t>>({1, 2}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v.value());},
|
||||
c10::optional<std::vector<int64_t>>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());},
|
||||
"(int[]? a) -> int[]?");
|
||||
@ -595,7 +623,7 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
||||
c10::Dict<std::string, std::string> str_dict;
|
||||
str_dict.insert("key1", "value1");
|
||||
str_dict.insert("key2", "value2");
|
||||
ArgTypeTestKernel<c10::Dict<std::string, std::string>>::test(
|
||||
testArgTypes<c10::Dict<std::string, std::string>>::test(
|
||||
str_dict, [] (c10::Dict<std::string, std::string> v) {
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ("value1", v.at("key1"));
|
||||
@ -611,7 +639,7 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
||||
c10::Dict<int64_t, Tensor> tensor_dict;
|
||||
tensor_dict.insert(1, dummyTensor(TensorType1()));
|
||||
tensor_dict.insert(2, dummyTensor(TensorType2()));
|
||||
ArgTypeTestKernel<c10::Dict<int64_t, Tensor>>::test(
|
||||
testArgTypes<c10::Dict<int64_t, Tensor>>::test(
|
||||
tensor_dict, [] (c10::Dict<int64_t, Tensor> v) {
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ(TensorType1(), v.at(1).type_id());
|
||||
|
Reference in New Issue
Block a user