Introduce lambda-based kernel API (#18541)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18541

Allow registering lambdas as c10 kernels.

Reviewed By: dzhulgakov
Differential Revision: D14653005
fbshipit-source-id: f867cc776b1339e83b7a2e1935f5cf924cfba44a
Committed by: Facebook Github Bot
Parent: 24752eb7b8
Commit: f4e87e193a
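This change adds a lambda-based overload of c10::kernel() next to the existing functor- and function-based registration paths: a stateless (capture-free) lambda can now be passed directly when registering an operator. A minimal usage sketch, based on the example in the new kernel_lambda.h and the tests below (operator name, lambda body and dispatch key are illustrative):

    #include <ATen/core/op_registration/op_registration.h>

    // Register "my_op" with a kernel implemented as a stateless lambda.
    // The lambda must not capture anything (see is_stateless_lambda below).
    static auto registry = c10::RegisterOperators()
        .op("my_op",
            c10::kernel([] (at::Tensor a) -> at::Tensor { return a; }),
            c10::dispatchKey(CPUTensorId()));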
@@ -138,6 +138,18 @@ namespace detail {
     return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
   }
 };
+
+template<class KernelFunctor, class... ConstructorParameters>
+// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
+inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+  detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+kernelFunctor(ConstructorParameters&&... constructorParameters) {
+  return {
+    &detail::wrap_kernel_functor<KernelFunctor>::call,
+    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
+    detail::FunctionSchemaInferer<KernelFunctor>()
+  };
+}
 }
 
 /**
@@ -181,11 +193,7 @@ template<class KernelFunctor, class... ConstructorParameters>
 inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
   detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
 kernel(ConstructorParameters&&... constructorParameters) {
-  return {
-    &detail::wrap_kernel_functor<KernelFunctor>::call,
-    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
-    detail::FunctionSchemaInferer<KernelFunctor>()
-  };
+  return detail::kernelFunctor<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...);
 }
 
 }
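The hunks above factor the functor-based registration into a detail::kernelFunctor<>() helper so that the existing functor path and the new lambda path share one implementation; the public kernel<KernelFunctor>(...) overload now simply forwards to it. For contrast with the lambda API, a functor-based registration looks roughly like this (schema name and dispatch key are illustrative):

    // A kernel functor must derive from c10::OperatorKernel.
    struct MyKernel final : c10::OperatorKernel {
      int64_t operator()(at::Tensor, int64_t i) { return i + 1; }
    };

    static auto registry = c10::RegisterOperators()
        .op("my_op",
            c10::kernel<MyKernel>(),            // unchanged functor-based overload
            c10::dispatchKey(CPUTensorId()));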
aten/src/ATen/core/op_registration/kernel_lambda.h (new file, 60 lines)
@@ -0,0 +1,60 @@
#pragma once

#include <ATen/core/op_registration/kernel_functor.h>
#include <c10/util/TypeTraits.h>

namespace c10 {

namespace detail {
  // WrapRuntimeKernelFunctor: Wraps any runtime functor into a functor that
  // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
  // This can, for example, be used for lambdas, functors or even function pointers.
  // In the case of function pointers, since it is a runtime function pointer,
  // there is an overhead for calling it whenever the kernel is invoked.
  template<class FuncType, class ReturnType, class ParameterList> class WrapRuntimeKernelFunctor_ {};
  template<class FuncType, class ReturnType, class... Parameters>
  class WrapRuntimeKernelFunctor_<FuncType, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
  public:
    template<class FuncType_>
    explicit WrapRuntimeKernelFunctor_(FuncType_&& kernel_func)
    : kernel_func_(std::forward<FuncType_>(kernel_func)) {}

    auto operator()(Parameters&&... args) -> decltype(std::declval<FuncType>()(std::forward<Parameters>(args)...)) {
      return kernel_func_(std::forward<Parameters>(args)...);
    }

  private:
    FuncType kernel_func_;
  };
  template<class FuncType>
  using WrapRuntimeKernelFunctor = WrapRuntimeKernelFunctor_<
    FuncType,
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename guts::infer_function_traits_t<FuncType>::parameter_types
  >;
}

/**
 * Use this to register an operator whose kernel is implemented as a stateless lambda.
 *
 * Example:
 *
 * > static auto registry = c10::RegisterOperators()
 * >     .op("my_op",
 * >         c10::kernel([] (Tensor a) -> Tensor{...}),
 * >         c10::dispatchKey(CPUTensorId()));
 */
template<class Lambda>
inline constexpr auto kernel(Lambda&& functor) ->
  guts::enable_if_t<guts::is_stateless_lambda<guts::decay_t<Lambda>>::value,
    decltype(detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor)))> {
  // We don't support stateful lambdas (i.e. lambdas with a capture), because their
  // behavior would be non-obvious. A functor kernel with a cache gets a new instance of
  // its cache each time the kernel is looked up from the dispatch table.
  // A lambda with a capture would be global and share its capture between all kernel lookups.
  // So, instead of making users have to think about it (including the thread-safety
  // issues this causes), let's just forbid stateful lambdas altogether.
  return detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor));
}

}
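To see how the pieces fit together: kernel(lambda) deduces the lambda's call signature via guts::infer_function_traits_t, wraps the lambda in WrapRuntimeKernelFunctor (which derives from c10::OperatorKernel), and hands that functor type to the shared detail::kernelFunctor<>() machinery. A rough illustration of what gets instantiated (spelled out only for exposition, this is not code from the diff):

    auto lambda = [] (at::Tensor, int64_t i) { return i + 1; };
    using LambdaT = c10::guts::decay_t<decltype(lambda)>;

    // c10::kernel(lambda) effectively registers this wrapper type: an OperatorKernel
    // whose operator()(Tensor, int64_t) forwards to the stored lambda.
    using Wrapped = c10::detail::WrapRuntimeKernelFunctor<LambdaT>;
    static_assert(std::is_base_of<c10::OperatorKernel, Wrapped>::value, "");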
aten/src/ATen/core/op_registration/kernel_lambda_test.cpp (new file, 803 lines)
@@ -0,0 +1,803 @@
#include <gtest/gtest.h>
#include <ATen/core/op_registration/test_helpers.h>

#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>

using c10::RegisterOperators;
using c10::FunctionSchema;
using c10::Argument;
using c10::IntType;
using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
using c10::KernelCache;
using c10::Stack;
using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::ArrayRef;
using std::unique_ptr;
using at::Tensor;

namespace {

C10_DECLARE_TENSOR_TYPE(TensorType1);
C10_DEFINE_TENSOR_TYPE(TensorType1);
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);

FunctionSchema errorOpSchema(
    "_test::error",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", IntType::get())}),
    (std::vector<Argument>{Argument("output", IntType::get())}));

FunctionSchema opSchema(
    "_test::my_op",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", IntType::get())}),
    (std::vector<Argument>{Argument("output", IntType::get())}));

void expectCallsIncrement(TensorTypeId type_id) {
  // assert that schema and cpu kernel are present
  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
  ASSERT_TRUE(op.has_value());
  auto result = callOp(*op, dummyTensor(type_id), 5);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(6, result[0].toInt());
}

void expectCallsDecrement(TensorTypeId type_id) {
  // assert that schema and cpu kernel are present
  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
  ASSERT_TRUE(op.has_value());
  auto result = callOp(*op, dummyTensor(type_id), 5);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(4, result[0].toInt());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
  expectCallsIncrement(TensorType1());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
  auto registrar = RegisterOperators()
      .op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
      .op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()))
      .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()))
      .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
  expectCallsIncrement(TensorType1());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
  auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
  auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
  auto registrar3 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()));
  auto registrar4 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
  expectCallsIncrement(TensorType1());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
  {
    auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
    {
      auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2()));

      // assert that schema and cpu kernel are present
      expectCallsIncrement(TensorType1());
      expectCallsDecrement(TensorType2());
    }

    // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
    expectCallsIncrement(TensorType1());
    expectDoesntFindKernel("_test::my_op", TensorType2());
  }

  // now both registrars are destructed. Assert that the whole schema is gone
  expectDoesntFindOperator("_test::my_op");
}

bool was_called = false;

FunctionSchema opWithoutOutputSchema(
    "_test::no_return",
    "",
    (std::vector<Argument>{Argument("dummy")}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators().op(opWithoutOutputSchema,
      kernel([] (const Tensor&) -> void {was_called = true;}),
      dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
  ASSERT_TRUE(op.has_value());
  was_called = false;
  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_TRUE(was_called);
  EXPECT_EQ(0, result.size());
}

FunctionSchema opWithZeroOutputsSchema(
    "_test::zero_outputs",
    "",
    (std::vector<Argument>{Argument("dummy")}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators().op(opWithZeroOutputsSchema,
      kernel([] (const Tensor&) -> std::tuple<> {was_called = true; return {};}),
      dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
  ASSERT_TRUE(op.has_value());
  was_called = false;
  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_TRUE(was_called);
  EXPECT_EQ(0, result.size());
}

FunctionSchema opWithIntOutputSchema(
    "_test::int_output",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("a", IntType::get()),
                           Argument("b", IntType::get())}),
    (std::vector<Argument>{Argument("sum", IntType::get())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntOutputSchema,
          kernel([] (Tensor, int64_t a, int64_t b) {return a+b;}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(9, result[0].toInt());
}

FunctionSchema opWithTensorOutput(
    "_test::returning_tensor",
    "",
    (std::vector<Argument>{Argument("input")}),
    (std::vector<Argument>{Argument("output")}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorOutput,
          kernel([] (const Tensor& a) {return a;}),
          dispatchKey(TensorType1()))
      .op(opWithTensorOutput,
          kernel([] (const Tensor& a) {return a;}),
          dispatchKey(TensorType2()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());

  result = callOp(*op, dummyTensor(TensorType2()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}

FunctionSchema opWithTensorListOutputSchema(
    "_test::list_output",
    "",
    (std::vector<Argument>{Argument("input1"),
                           Argument("input2"),
                           Argument("input3")}),
    (std::vector<Argument>{Argument("output", ListType::ofTensors())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorListOutputSchema,
          kernel([] (const Tensor& a, const Tensor& b, const Tensor& c) -> std::vector<Tensor> {return {a, b, c};}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(3, result[0].toTensorListRef().size());
  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id());
  EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id());
  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}

FunctionSchema opWithIntListOutputSchema(
    "_test::list_output",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input1", IntType::get()),
                           Argument("input2", IntType::get()),
                           Argument("input3", IntType::get())}),
    (std::vector<Argument>{Argument("output", ListType::ofInts())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntListOutputSchema,
          kernel([] (const Tensor&, int64_t a, int64_t b, int64_t c) -> std::vector<int64_t> {return {a,b,c};}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(3, result[0].toIntListRef().size());
  EXPECT_EQ(2, result[0].toIntListRef()[0]);
  EXPECT_EQ(4, result[0].toIntListRef()[1]);
  EXPECT_EQ(6, result[0].toIntListRef()[2]);
}

FunctionSchema opWithMultipleOutputsSchema(
    "_test::multiple_outputs",
    "",
    (std::vector<Argument>{Argument("dummy")}),
    (std::vector<Argument>{Argument("output1"),
                           Argument("output2", IntType::get()),
                           Argument("output3", ListType::ofTensors())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithMultipleOutputsSchema,
          kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
            return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
              dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
            );
          }),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(3, result.size());
  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
  EXPECT_EQ(5, result[1].toInt());
  EXPECT_EQ(2, result[2].toTensorListRef().size());
  EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id());
  EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
}

FunctionSchema opWithTensorInputWithOutput(
    "_test::tensor_input",
    "",
    (std::vector<Argument>{Argument("input")}),
    (std::vector<Argument>{Argument("output")}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorInputWithOutput,
          kernel([] (const Tensor& a) {return a;}),
          dispatchKey(TensorType1()))
      .op(opWithTensorInputWithOutput,
          kernel([] (const Tensor& a) {return a;}),
          dispatchKey(TensorType2()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());

  result = callOp(*op, dummyTensor(TensorType2()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorInputWithOutput,
          kernel([] (Tensor a) {return a;}),
          dispatchKey(TensorType1()))
      .op(opWithTensorInputWithOutput,
          kernel([] (Tensor a) {return a;}),
          dispatchKey(TensorType2()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
  ASSERT_TRUE(op.has_value());

  auto result = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());

  result = callOp(*op, dummyTensor(TensorType2()));
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}

Tensor captured_input;

FunctionSchema opWithTensorInputWithoutOutput(
    "_test::tensor_input",
    "",
    (std::vector<Argument>{Argument("input")}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorInputWithoutOutput,
          kernel([] (const Tensor& a) -> void {captured_input = a;}),
          dispatchKey(TensorType1()))
      .op(opWithTensorInputWithoutOutput,
          kernel([] (const Tensor& a) -> void {captured_input = a;}),
          dispatchKey(TensorType2()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
  ASSERT_TRUE(op.has_value());

  auto outputs = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(TensorType1(), captured_input.type_id());

  outputs = callOp(*op, dummyTensor(TensorType2()));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(TensorType2(), captured_input.type_id());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorInputWithoutOutput,
          kernel([] (Tensor a) -> void {captured_input = a;}),
          dispatchKey(TensorType1()))
      .op(opWithTensorInputWithoutOutput,
          kernel([] (Tensor a) -> void {captured_input = a;}),
          dispatchKey(TensorType2()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
  ASSERT_TRUE(op.has_value());

  auto outputs = callOp(*op, dummyTensor(TensorType1()));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(TensorType1(), captured_input.type_id());

  outputs = callOp(*op, dummyTensor(TensorType2()));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(TensorType2(), captured_input.type_id());
}

int64_t captured_int_input = 0;

FunctionSchema opWithIntInputWithoutOutput(
    "_test::int_input",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", IntType::get())}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntInputWithoutOutput,
          kernel([] (Tensor, int64_t a) -> void {captured_int_input = a;}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
  ASSERT_TRUE(op.has_value());

  captured_int_input = 0;
  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(3, captured_int_input);
}

FunctionSchema opWithIntInputWithOutput(
    "_test::int_input",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", IntType::get())}),
    (std::vector<Argument>{Argument("output", IntType::get())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntInputWithOutput,
          kernel([] (Tensor, int64_t a) {return a + 1;}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
  ASSERT_TRUE(op.has_value());

  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(4, outputs[0].toInt());
}

int64_t captured_input_list_size = 0;

FunctionSchema opWithIntListInputWithoutOutput(
    "_test::int_list_input",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", ListType::ofInts())}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntListInputWithoutOutput,
          kernel([] (Tensor, ArrayRef<int64_t> a) {captured_input_list_size = a.size();}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
  ASSERT_TRUE(op.has_value());

  captured_input_list_size = 0;
  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(3, captured_input_list_size);
}

FunctionSchema opWithIntListInputWithOutput(
    "_test::int_list_input",
    "",
    (std::vector<Argument>{Argument("dummy"),
                           Argument("input", ListType::ofInts())}),
    (std::vector<Argument>{Argument("output", IntType::get())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithIntListInputWithOutput,
          kernel([] (Tensor, ArrayRef<int64_t> a) -> int64_t {return a.size();}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
  ASSERT_TRUE(op.has_value());

  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(3, outputs[0].toInt());
}

FunctionSchema opWithTensorListInputWithoutOutput(
    "_test::tensor_list_input",
    "",
    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
    (std::vector<Argument>{}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorListInputWithoutOutput,
          kernel([] (ArrayRef<Tensor> a) -> void {captured_input_list_size = a.size();}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
  ASSERT_TRUE(op.has_value());

  captured_input_list_size = 0;
  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
  EXPECT_EQ(0, outputs.size());
  EXPECT_EQ(2, captured_input_list_size);
}

FunctionSchema opWithTensorListInputWithOutput(
    "_test::tensor_list_input",
    "",
    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
    (std::vector<Argument>{Argument("output", IntType::get())}));

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
  auto registrar = RegisterOperators()
      .op(opWithTensorListInputWithOutput,
          kernel([] (ArrayRef<Tensor> a) -> int64_t {return a.size();}),
          dispatchKey(TensorType1()));

  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
  ASSERT_TRUE(op.has_value());

  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(2, outputs[0].toInt());
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret", IntType::get())})
      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
            (std::vector<Argument>{Argument("ret", IntType::get())})
        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
          (std::vector<Argument>{})
      ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{}),
            (std::vector<Argument>{})
        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{})
        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
            (std::vector<Argument>{})
        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
    c10::Error
  );
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
          (std::vector<Argument>{Argument("ret", IntType::get())})
      ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
            (std::vector<Argument>{Argument("ret", IntType::get())})
        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
            (std::vector<Argument>{Argument("ret", IntType::get())})
        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret", IntType::get())})
      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{})
        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret1", IntType::get()),
                                   Argument("ret2", IntType::get())})
        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{})
      ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret")})
        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret"), Argument("ret2")})
        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
    c10::Error
  );

  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
      ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{})
        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret1")})
        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );
}

TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret", IntType::get())})
      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret")})
        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret", FloatType::get())})
        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret")})
      ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret", FloatType::get())})
        ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  // assert this does not fail because it matches
  RegisterOperators()
      .op(FunctionSchema(
          "_test::mismatch",
          "",
          (std::vector<Argument>{Argument("arg")}),
          (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
      ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));

  // and now a set of mismatching schemas
  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );

  EXPECT_THROW(
    RegisterOperators()
        .op(FunctionSchema(
            "_test::mismatch",
            "",
            (std::vector<Argument>{Argument("arg")}),
            (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
    c10::Error
  );
}

}
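All of the lambdas in the tests above are capture-free; per the comment in kernel_lambda.h, a capturing lambda is rejected at compile time because the lambda overload of kernel() is constrained on guts::is_stateless_lambda. A hypothetical (intentionally non-compiling) counterpart to the first test, shown only to illustrate the restriction:

    // Would not compile: the lambda captures `offset`, so is_stateless_lambda is false
    // and no kernel() overload is viable.
    //
    // int64_t offset = 5;
    // auto registrar = RegisterOperators().op(opSchema,
    //     kernel([offset] (Tensor, int64_t i) { return i + offset; }),
    //     dispatchKey(TensorType1()));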
@@ -10,6 +10,7 @@
#include <ATen/core/op_registration/kernel_stackbased.h>
#include <ATen/core/op_registration/kernel_functor.h>
#include <ATen/core/op_registration/kernel_function.h>
#include <ATen/core/op_registration/kernel_lambda.h>
#include <ATen/core/op_registration/infer_schema.h>

namespace c10 {
@@ -112,3 +112,44 @@ namespace test_is_type_condition {
  static_assert(!is_type_condition<NotATypeCondition>::value, "");
}
}

namespace test_lambda_is_stateless {
  template<class Result, class... Args>
  struct MyStatelessFunctor final {
    Result operator()(Args...) {}
  };

  template<class Result, class... Args>
  struct MyStatelessConstFunctor final {
    Result operator()(Args...) const {}
  };

  void func() {
    auto stateless_lambda = [] (int a) {return a;};
    static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");

    int b = 4;
    auto stateful_lambda_1 = [&] (int a) {return a + b;};
    static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");

    auto stateful_lambda_2 = [=] (int a) {return a + b;};
    static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");

    auto stateful_lambda_3 = [b] (int a) {return a + b;};
    static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");

    static_assert(!is_stateless_lambda<MyStatelessFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
    static_assert(!is_stateless_lambda<MyStatelessFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
    static_assert(!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
    static_assert(!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");

    class Dummy final {};
    static_assert(!is_stateless_lambda<Dummy>::value, "A non-functor type is also not a lambda");

    static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");

    using Func = int(int);
    static_assert(!is_stateless_lambda<Func>::value, "A function is not a lambda");
    static_assert(!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
  }
}
@@ -7,24 +7,6 @@
#include <c10/util/Array.h>

namespace c10 { namespace guts {
namespace detail {
/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 */

template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail

/**
 * Access information about result type or arguments from a function type.
@@ -43,16 +25,6 @@ struct function_traits<Result (Args...)> {
  static constexpr auto number_of_parameters = sizeof...(Args);
};

/**
 * Evaluates to true_type, iff the given class is a Functor
 * (i.e. has a call operator with some set of arguments)
 */

template<class Functor, class Enable = void>
struct is_functor : std::false_type {};
template<class Functor>
struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};

/**
 * infer_function_traits: creates a `function_traits` type for a simple
 * function (pointer) or functor (lambda/struct). Currently does not support
@@ -49,6 +49,64 @@ template <template <class...> class Template, class... Args>
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
template<template<class...> class Template, class T> using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;

namespace detail {
/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 */

template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail

/**
 * Evaluates to true_type, iff the given class is a Functor
 * (i.e. has a call operator with some set of arguments)
 */

template<class Functor, class Enable = void>
struct is_functor : std::false_type {};
template<class Functor>
struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};

/**
 * is_stateless_lambda<T> is true iff the lambda type T is stateless
 * (i.e. does not have a closure).
 * Example:
 *  auto stateless_lambda = [] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateless_lambda)> // true
 *  auto stateful_lambda = [&] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateful_lambda)> // false
 */
namespace detail {
template<class LambdaType, class FuncType> struct is_stateless_lambda__ final {
  static_assert(!std::is_same<LambdaType, LambdaType>::value, "Base case shouldn't be hit");
};
// implementation idea: According to the C++ standard, stateless lambdas are convertible to function pointers
template<class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const> : std::is_convertible<LambdaType, Result(*)(Args...)> {};
template<class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)> : std::is_convertible<LambdaType, Result(*)(Args...)> {};

// case where LambdaType is not even a functor
template<class LambdaType, class Enable = void> struct is_stateless_lambda_ final : std::false_type {};
// case where LambdaType is a functor
template<class LambdaType> struct is_stateless_lambda_<LambdaType, guts::enable_if_t<is_functor<LambdaType>::value>>
  : is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
}
template<class T>
using is_stateless_lambda = detail::is_stateless_lambda_<guts::decay_t<T>>;


/**
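The detection in is_stateless_lambda__ relies on a standard C++ rule: a capture-less lambda is implicitly convertible to a plain function pointer with the same signature, while a capturing lambda is not. A small, self-contained illustration of that underlying rule (not part of the diff):

    #include <type_traits>

    void check_lambda_conversion() {
      auto stateless = [] (int a) { return a; };
      int (*fp)(int) = stateless;   // OK: a capture-less lambda converts to a function pointer
      (void)fp;

      int b = 4;
      auto stateful = [b] (int a) { return a + b; };
      // int (*fp2)(int) = stateful;  // would not compile: capturing lambdas do not convert

      static_assert(std::is_convertible<decltype(stateless), int (*)(int)>::value, "");
      static_assert(!std::is_convertible<decltype(stateful), int (*)(int)>::value, "");
    }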