Throw error if multiple kernels registered (#20737)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20737

If someone tries to register multiple kernels in the same .op() call, we're now throwing an error.

Differential Revision: D15425660

fbshipit-source-id: 6d2f1444da3e16a6a98863d847965c2aa211e046
This commit is contained in:
Sebastian Messmer
2019-05-21 12:13:32 -07:00
committed by Facebook Github Bot
parent f3d827f311
commit cc02a1af61
2 changed files with 25 additions and 0 deletions

View File

@ -173,12 +173,21 @@ public:
* > .dispatchKey(CUDATensorId()));
*/
Options&& dispatchKey(TensorTypeId dispatch_key) && {
if (config.dispatch_key.has_value()) {
AT_ERROR("Operator registration: Cannot register multiple dispatch keys in the same op() call. Please call op() multiple times if you want to register multiple kernels.");
}
config.dispatch_key = dispatch_key;
return std::move(*this);
}
private:
Options&& kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction&& cache_creator, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
if (nullptr != config.kernel_func) {
AT_ERROR("Operator registration: Cannot register multiple kernels in the same op() call. Please call op() multiple times if you want to register multiple kernels.");
}
AT_ASSERTM(nullptr == config.cache_creator_func, "kernel_func was nullptr, so cache_creator_func must be too");
AT_ASSERTM(nullptr == config.inferred_function_schema, "kernel_func was nullptr, so inferred_function_schema must be too");
config.kernel_func = kernel_func;
config.cache_creator_func = std::move(cache_creator);
config.inferred_function_schema = std::move(inferred_function_schema);

View File

@ -384,7 +384,23 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenNewer
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, whenTryingToRegisterWithMultipleKernels_thenFails) {
expectThrows<c10::Error>([&] {
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().kernel<DummyKernel>());
}, "Cannot register multiple kernels in the same op() call");
}
TEST(OperatorRegistrationTest, whenTryingToRegisterWithMultipleDispatchKeys_thenFails) {
expectThrows<c10::Error>([&] {
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()).dispatchKey(TensorType2()));
}, "Cannot register multiple dispatch keys in the same op() call");
}
TEST(OperatorRegistrationTest, whenTryingToRegisterWithDispatchKeyWithoutKernel_thenFails) {
expectThrows<c10::Error>([&] {
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().dispatchKey(TensorType1()));
}, "Tried to register an operator with a dispatch key but without a kernel");
}
/**
* This is used to check that a given type works correctly when passed as input