Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Throw error if multiple kernels registered (#20737)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20737

If someone tries to register multiple kernels in the same .op() call, we now throw an error.

Differential Revision: D15425660

fbshipit-source-id: 6d2f1444da3e16a6a98863d847965c2aa211e046
Committed by: Facebook Github Bot
Parent: f3d827f311
Commit: cc02a1af61
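For context, here is a minimal sketch (not part of the diff) of the registration patterns this change allows and rejects, assuming the c10::RegisterOperators API as exercised by the tests below; the kernel functors MyKernelCPU/MyKernelCUDA and the "my::op" schema are made up for illustration.

// Hypothetical sketch, not from this PR: shows what the new checks allow and reject.
#include <ATen/core/op_registration/op_registration.h>

namespace {
// Minimal stand-in kernel functors (hypothetical); real kernels would do actual work.
struct MyKernelCPU final : c10::OperatorKernel {
  void operator()(at::Tensor) {}
};
struct MyKernelCUDA final : c10::OperatorKernel {
  void operator()(at::Tensor) {}
};
}

// Supported: one kernel per op() call, each with its own dispatch key.
static auto registry = c10::RegisterOperators()
    .op("my::op(Tensor t) -> ()",
        c10::RegisterOperators::options().kernel<MyKernelCPU>().dispatchKey(c10::CPUTensorId()))
    .op("my::op(Tensor t) -> ()",
        c10::RegisterOperators::options().kernel<MyKernelCUDA>().dispatchKey(c10::CUDATensorId()));

// Rejected after this change: two kernels (or two dispatch keys) in a single op() call.
// c10::RegisterOperators().op("my::op(Tensor t) -> ()",
//     c10::RegisterOperators::options().kernel<MyKernelCPU>().kernel<MyKernelCUDA>());
// -> throws "Cannot register multiple kernels in the same op() call"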
@@ -173,12 +173,21 @@ public:
   * >         .dispatchKey(CUDATensorId()));
   */
  Options&& dispatchKey(TensorTypeId dispatch_key) && {
    if (config.dispatch_key.has_value()) {
      AT_ERROR("Operator registration: Cannot register multiple dispatch keys in the same op() call. Please call op() multiple times if you want to register multiple kernels.");
    }
    config.dispatch_key = dispatch_key;
    return std::move(*this);
  }

private:
  Options&& kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction&& cache_creator, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
    if (nullptr != config.kernel_func) {
      AT_ERROR("Operator registration: Cannot register multiple kernels in the same op() call. Please call op() multiple times if you want to register multiple kernels.");
    }
    AT_ASSERTM(nullptr == config.cache_creator_func, "kernel_func was nullptr, so cache_creator_func must be too");
    AT_ASSERTM(nullptr == config.inferred_function_schema, "kernel_func was nullptr, so inferred_function_schema must be too");

    config.kernel_func = kernel_func;
    config.cache_creator_func = std::move(cache_creator);
    config.inferred_function_schema = std::move(inferred_function_schema);
@@ -384,7 +384,23 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenNewer
  }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}

TEST(OperatorRegistrationTest, whenTryingToRegisterWithMultipleKernels_thenFails) {
  expectThrows<c10::Error>([&] {
    c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().kernel<DummyKernel>());
  }, "Cannot register multiple kernels in the same op() call");
}

TEST(OperatorRegistrationTest, whenTryingToRegisterWithMultipleDispatchKeys_thenFails) {
  expectThrows<c10::Error>([&] {
    c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()).dispatchKey(TensorType2()));
  }, "Cannot register multiple dispatch keys in the same op() call");
}

TEST(OperatorRegistrationTest, whenTryingToRegisterWithDispatchKeyWithoutKernel_thenFails) {
  expectThrows<c10::Error>([&] {
    c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().dispatchKey(TensorType1()));
  }, "Tried to register an operator with a dispatch key but without a kernel");
}

/**
 * This is used to check that a given type works correctly when passed as input