Introduce lambda-based kernel API (#18541)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18541

Allow registering lambdas as c10 kernels.

Reviewed By: dzhulgakov

Differential Revision: D14653005

fbshipit-source-id: f867cc776b1339e83b7a2e1935f5cf924cfba44a
This commit is contained in:
Sebastian Messmer
2019-03-30 00:03:46 -07:00
committed by Facebook Github Bot
parent 24752eb7b8
commit f4e87e193a
7 changed files with 976 additions and 33 deletions

View File

@ -138,6 +138,18 @@ namespace detail {
return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
}
};
template<class KernelFunctor, class... ConstructorParameters>
// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
kernelFunctor(ConstructorParameters&&... constructorParameters) {
return {
&detail::wrap_kernel_functor<KernelFunctor>::call,
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
detail::FunctionSchemaInferer<KernelFunctor>()
};
}
}
/**
@ -181,11 +193,7 @@ template<class KernelFunctor, class... ConstructorParameters>
inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
kernel(ConstructorParameters&&... constructorParameters) {
return {
&detail::wrap_kernel_functor<KernelFunctor>::call,
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
detail::FunctionSchemaInferer<KernelFunctor>()
};
return detail::kernelFunctor<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...);
}
}

View File

@ -0,0 +1,60 @@
#pragma once
#include <ATen/core/op_registration/kernel_functor.h>
#include <c10/util/TypeTraits.h>
namespace c10 {
namespace detail {
// WrapRuntimeKernelFunctor: Wraps any runtime functor into a functor that
// inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
// This can, for example, be used for lamdas, functors or even function pointers.
// In the case of function pointers, since it is a runtime function pointer,
// there is an overhead for calling it whenever the kernel is invoked.
template<class FuncType, class ReturnType, class ParameterList> class WrapRuntimeKernelFunctor_ {};
template<class FuncType, class ReturnType, class... Parameters>
class WrapRuntimeKernelFunctor_<FuncType, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
public:
template<class FuncType_>
explicit WrapRuntimeKernelFunctor_(FuncType_&& kernel_func)
: kernel_func_(std::forward<FuncType_>(kernel_func)) {}
auto operator()(Parameters&&... args) -> decltype(std::declval<FuncType>()(std::forward<Parameters>(args)...)) {
return kernel_func_(std::forward<Parameters>(args)...);
}
private:
FuncType kernel_func_;
};
template<class FuncType>
using WrapRuntimeKernelFunctor = WrapRuntimeKernelFunctor_<
FuncType,
typename guts::infer_function_traits_t<FuncType>::return_type,
typename guts::infer_function_traits_t<FuncType>::parameter_types
>;
}
/**
* Use this to register an operator whose kernel is implemented as a stateless lambda.
*
* Example:
*
* > static auto registry = c10::RegisterOperators()
* > .op("my_op",
* > c10::kernel([] (Tensor a) -> Tensor{...}),
* > c10::dispatchKey(CPUTensorId()));
*/
template<class Lambda>
inline constexpr auto kernel(Lambda&& functor) ->
guts::enable_if_t<guts::is_stateless_lambda<guts::decay_t<Lambda>>::value,
decltype(detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor)))> {
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
// behavior would be nonobvious. A functor kernel with cache gets a new instance of
// its cache each time the kernel is looked up from the dispatch table.
// A lambda with a capture would be global and share its capture between all kernel lookups.
// So, instead of making users having to think about it (including the thread-safety
// issues this causes), let's just forbid stateful lambdas alltogether.
return detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor));
}
}

View File

@ -0,0 +1,803 @@
#include <gtest/gtest.h>
#include <ATen/core/op_registration/test_helpers.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
using c10::RegisterOperators;
using c10::FunctionSchema;
using c10::Argument;
using c10::IntType;
using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
using c10::KernelCache;
using c10::Stack;
using c10::guts::make_unique;
using c10::ivalue::TensorList;
using c10::ivalue::IntList;
using c10::intrusive_ptr;
using c10::ArrayRef;
using std::unique_ptr;
using at::Tensor;
namespace {
C10_DECLARE_TENSOR_TYPE(TensorType1);
C10_DEFINE_TENSOR_TYPE(TensorType1);
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
FunctionSchema errorOpSchema(
"_test::error",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
FunctionSchema opSchema(
"_test::my_op",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(type_id), 5);
EXPECT_EQ(1, result.size());
EXPECT_EQ(6, result[0].toInt());
}
void expectCallsDecrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(type_id), 5);
EXPECT_EQ(1, result.size());
EXPECT_EQ(4, result[0].toInt());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
.op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
.op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()))
.op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()))
.op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
auto registrar3 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()));
auto registrar4 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
{
auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2()));
// assert that schema and cpu kernel are present
expectCallsIncrement(TensorType1());
expectCallsDecrement(TensorType2());
}
// now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
expectCallsIncrement(TensorType1());
expectDoesntFindKernel("_test::my_op", TensorType2());
}
// now both registrars are destructed. Assert that the whole schema is gone
expectDoesntFindOperator("_test::my_op");
}
bool was_called = false;
FunctionSchema opWithoutOutputSchema(
"_test::no_return",
"",
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators().op(opWithoutOutputSchema,
kernel([] (const Tensor&) -> void {was_called = true;}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
ASSERT_TRUE(op.has_value());
was_called = false;
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_TRUE(was_called);
EXPECT_EQ(0, result.size());
}
FunctionSchema opWithZeroOutputsSchema(
"_test::zero_outputs",
"",
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators().op(opWithZeroOutputsSchema,
kernel([] (const Tensor&) -> std::tuple<> {was_called = true; return {};}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
ASSERT_TRUE(op.has_value());
was_called = false;
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_TRUE(was_called);
EXPECT_EQ(0, result.size());
}
FunctionSchema opWithIntOutputSchema(
"_test::int_output",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("a", IntType::get()),
Argument("b", IntType::get())}),
(std::vector<Argument>{Argument("sum", IntType::get())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntOutputSchema,
kernel([] (Tensor, int64_t a, int64_t b) {return a+b;}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6);
EXPECT_EQ(1, result.size());
EXPECT_EQ(9, result[0].toInt());
}
FunctionSchema opWithTensorOutput(
"_test::returning_tensor",
"",
(std::vector<Argument>{Argument("input")}),
(std::vector<Argument>{Argument("output")}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorOutput,
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType1()))
.op(opWithTensorOutput,
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
result = callOp(*op, dummyTensor(TensorType2()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}
FunctionSchema opWithTensorListOutputSchema(
"_test::list_output",
"",
(std::vector<Argument>{Argument("input1"),
Argument("input2"),
Argument("input3")}),
(std::vector<Argument>{Argument("output", ListType::ofTensors())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorListOutputSchema,
kernel([] (const Tensor& a, const Tensor& b, const Tensor& c) -> std::vector<Tensor> {return {a, b, c};}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(3, result[0].toTensorListRef().size());
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id());
EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id());
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}
FunctionSchema opWithIntListOutputSchema(
"_test::list_output",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input1", IntType::get()),
Argument("input2", IntType::get()),
Argument("input3", IntType::get())}),
(std::vector<Argument>{Argument("output", ListType::ofInts())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntListOutputSchema,
kernel([] (const Tensor&, int64_t a, int64_t b, int64_t c) -> std::vector<int64_t> {return {a,b,c};}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6);
EXPECT_EQ(1, result.size());
EXPECT_EQ(3, result[0].toIntListRef().size());
EXPECT_EQ(2, result[0].toIntListRef()[0]);
EXPECT_EQ(4, result[0].toIntListRef()[1]);
EXPECT_EQ(6, result[0].toIntListRef()[2]);
}
FunctionSchema opWithMultipleOutputsSchema(
"_test::multiple_outputs",
"",
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{Argument("output1"),
Argument("output2", IntType::get()),
Argument("output3", ListType::ofTensors())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithMultipleOutputsSchema,
kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
);
}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(3, result.size());
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
EXPECT_EQ(5, result[1].toInt());
EXPECT_EQ(2, result[2].toTensorListRef().size());
EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id());
EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
}
FunctionSchema opWithTensorInputWithOutput(
"_test::tensor_input",
"",
(std::vector<Argument>{Argument("input")}),
(std::vector<Argument>{Argument("output")}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorInputWithOutput,
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType1()))
.op(opWithTensorInputWithOutput,
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
result = callOp(*op, dummyTensor(TensorType2()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorInputWithOutput,
kernel([] (Tensor a) {return a;}),
dispatchKey(TensorType1()))
.op(opWithTensorInputWithOutput,
kernel([] (Tensor a) {return a;}),
dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
auto result = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
result = callOp(*op, dummyTensor(TensorType2()));
EXPECT_EQ(1, result.size());
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}
Tensor captured_input;
FunctionSchema opWithTensorInputWithoutOutput(
"_test::tensor_input",
"",
(std::vector<Argument>{Argument("input")}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorInputWithoutOutput,
kernel([] (const Tensor& a) -> void {captured_input = a;}),
dispatchKey(TensorType1()))
.op(opWithTensorInputWithoutOutput,
kernel([] (const Tensor& a) -> void {captured_input = a;}),
dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
auto outputs = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(TensorType1(), captured_input.type_id());
outputs = callOp(*op, dummyTensor(TensorType2()));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorInputWithoutOutput,
kernel([] (Tensor a) -> void {captured_input = a;}),
dispatchKey(TensorType1()))
.op(opWithTensorInputWithoutOutput,
kernel([] (Tensor a) -> void {captured_input = a;}),
dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
auto outputs = callOp(*op, dummyTensor(TensorType1()));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(TensorType1(), captured_input.type_id());
outputs = callOp(*op, dummyTensor(TensorType2()));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
int64_t captured_int_input = 0;
FunctionSchema opWithIntInputWithoutOutput(
"_test::int_input",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntInputWithoutOutput,
kernel([] (Tensor, int64_t a) -> void {captured_int_input = a;}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
captured_int_input = 0;
auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(3, captured_int_input);
}
FunctionSchema opWithIntInputWithOutput(
"_test::int_input",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntInputWithOutput,
kernel([] (Tensor, int64_t a) {return a + 1;}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(4, outputs[0].toInt());
}
int64_t captured_input_list_size = 0;
FunctionSchema opWithIntListInputWithoutOutput(
"_test::int_list_input",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", ListType::ofInts())}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntListInputWithoutOutput,
kernel([] (Tensor, ArrayRef<int64_t> a) {captured_input_list_size = a.size();}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = 0;
auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(3, captured_input_list_size);
}
FunctionSchema opWithIntListInputWithOutput(
"_test::int_list_input",
"",
(std::vector<Argument>{Argument("dummy"),
Argument("input", ListType::ofInts())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithIntListInputWithOutput,
kernel([] (Tensor, ArrayRef<int64_t> a) -> int64_t {return a.size();}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(3, outputs[0].toInt());
}
FunctionSchema opWithTensorListInputWithoutOutput(
"_test::tensor_list_input",
"",
(std::vector<Argument>{Argument("input", ListType::ofTensors())}),
(std::vector<Argument>{}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorListInputWithoutOutput,
kernel([] (ArrayRef<Tensor> a) -> void {captured_input_list_size = a.size();}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = 0;
auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
EXPECT_EQ(0, outputs.size());
EXPECT_EQ(2, captured_input_list_size);
}
FunctionSchema opWithTensorListInputWithOutput(
"_test::tensor_list_input",
"",
(std::vector<Argument>{Argument("input", ListType::ofTensors())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
.op(opWithTensorListInputWithOutput,
kernel([] (ArrayRef<Tensor> a) -> int64_t {return a.size();}),
dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(2, outputs[0].toInt());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{})
), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2")})
), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
c10::Error
);
}
}

View File

@ -10,6 +10,7 @@
#include <ATen/core/op_registration/kernel_stackbased.h>
#include <ATen/core/op_registration/kernel_functor.h>
#include <ATen/core/op_registration/kernel_function.h>
#include <ATen/core/op_registration/kernel_lambda.h>
#include <ATen/core/op_registration/infer_schema.h>
namespace c10 {

View File

@ -112,3 +112,44 @@ namespace test_is_type_condition {
static_assert(!is_type_condition<NotATypeCondition>::value, "");
}
}
namespace test_lambda_is_stateless {
template<class Result, class... Args>
struct MyStatelessFunctor final {
Result operator()(Args...) {}
};
template<class Result, class... Args>
struct MyStatelessConstFunctor final {
Result operator()(Args...) const {}
};
void func() {
auto stateless_lambda = [] (int a) {return a;};
static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");
int b = 4;
auto stateful_lambda_1 = [&] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");
auto stateful_lambda_2 = [=] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");
auto stateful_lambda_3 = [b] (int a) {return a + b;};
static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");
static_assert(!is_stateless_lambda<MyStatelessFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
static_assert(!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
class Dummy final {};
static_assert(!is_stateless_lambda<Dummy>::value, "A non-functor type is also not a lambda");
static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");
using Func = int(int);
static_assert(!is_stateless_lambda<Func>::value, "A function is not a lambda");
static_assert(!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
}
}

View File

@ -7,24 +7,6 @@
#include <c10/util/Array.h>
namespace c10 { namespace guts {
namespace detail {
/**
* strip_class: helper to remove the class type from pointers to `operator()`.
*/
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail
/**
* Access information about result type or arguments from a function type.
@ -43,16 +25,6 @@ struct function_traits<Result (Args...)> {
static constexpr auto number_of_parameters = sizeof...(Args);
};
/**
* Evaluates to true_type, iff the given class is a Functor
* (i.e. has a call operator with some set of arguments)
*/
template<class Functor, class Enable = void>
struct is_functor : std::false_type {};
template<class Functor>
struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};
/**
* infer_function_traits: creates a `function_traits` type for a simple
* function (pointer) or functor (lambda/struct). Currently does not support

View File

@ -49,6 +49,64 @@ template <template <class...> class Template, class... Args>
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
template<template<class...> class Template, class T> using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;
namespace detail {
/**
* strip_class: helper to remove the class type from pointers to `operator()`.
*/
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail
/**
* Evaluates to true_type, iff the given class is a Functor
* (i.e. has a call operator with some set of arguments)
*/
template<class Functor, class Enable = void>
struct is_functor : std::false_type {};
template<class Functor>
struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};
/**
* lambda_is_stateless<T> is true iff the lambda type T is stateless
* (i.e. does not have a closure).
* Example:
* auto stateless_lambda = [] (int a) {return a;};
* lambda_is_stateless<decltype(stateless_lambda)> // true
* auto stateful_lambda = [&] (int a) {return a;};
* lambda_is_stateless<decltype(stateful_lambda)> // false
*/
namespace detail {
template<class LambdaType, class FuncType> struct is_stateless_lambda__ final {
static_assert(!std::is_same<LambdaType, LambdaType>::value, "Base case shouldn't be hit");
};
// implementation idea: According to the C++ standard, stateless lambdas are convertible to function pointers
template<class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const> : std::is_convertible<LambdaType, Result(*)(Args...)> {};
template<class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)> : std::is_convertible<LambdaType, Result(*)(Args...)> {};
// case where LambdaType is not even a functor
template<class LambdaType, class Enable = void> struct is_stateless_lambda_ final : std::false_type {};
// case where LambdaType is a functor
template<class LambdaType> struct is_stateless_lambda_<LambdaType, guts::enable_if_t<is_functor<LambdaType>::value>>
: is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
}
template<class T>
using is_stateless_lambda = detail::is_stateless_lambda_<guts::decay_t<T>>;
/**