Catchall kernels instead of fallback kernels (#20773)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20773

This removes the feature to register fallback kernels that are called when no other kernel matches.
Instead, we introduce the concept of catchall kernels that are always called independent of inputs.
If you only have a fallback/catchall kernel and no kernels with concrete dispatch keys, then both concepts behave in the same way.
The difference is that we now disallow operators from having both a catchall kernel and kernels with concrete dispatch keys.
This was possible before, when they were fallback kernels.

The reason for this change is that we anticipate needing a method_missing feature in backends, i.e. a backend-wide fallback to call when the backend doesn't specify a kernel for an operator.
We are not clear on the precedence between this backend-wide fallback and an operator-level fallback. We disallow operator-level fallbacks for now so we are free to choose later without breaking backwards compatibility.

Reviewed By: dzhulgakov

Differential Revision: D15438977

fbshipit-source-id: cb3aa764a1659d909ee21a7bd8ec3d32438aafaa
This commit is contained in:
Sebastian Messmer
2019-05-23 23:38:20 -07:00
committed by Facebook Github Bot
parent c25e33789e
commit fc941d3bca
9 changed files with 267 additions and 191 deletions

View File

@ -4,6 +4,7 @@
#include <c10/util/LeftRight.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/either.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/dispatch/KernelFunction.h>
@ -61,8 +62,7 @@ class KernelTable_ final {
if (!emplaced.second) {
// Element already existed. Overwrite it.
emplaced.first->second = value;
AT_WARN("Registered a kernel that overwrote a previously registered kernel with same dispatch key '",
detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
AT_WARN("Registered a kernel for operator ", operator_name," with dispatch key ", detail::dispatch_key_to_string(key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
}
}
@ -113,33 +113,55 @@ class KernelTable_ final {
class DispatchTable final {
public:
DispatchTable(const FunctionSchema& schema)
: kernels_()
: kernels_(make_left<detail::KernelTable_, DispatchTableEntry>())
, dispatch_strategy_(get_dispatch_strategy_(schema))
, operator_name_(schema.name()) {}
/**
* Register a kernel in the table at some dispatch key.
* @param dispatch_key Dispatch key to define when this kernel is selected.
* If this is TensorTypeIds::undefined(), this registers a fallback
* kernel that is called whenever no other kernel matches.
* @param kernel Concrete kernel function implementation to register
*/
void setKernel(
TensorTypeId dispatch_key,
const DispatchTableEntry& kernel) {
const bool is_fallback_kernel = (dispatch_key == TensorTypeIds::undefined());
AT_ASSERTM(is_fallback_kernel || dispatch_strategy_.is_valid_, "Tried to register a kernel with a dispatch key for operator schema ", operator_name_, " that doesn't have tensor arguments.");
kernels_.set(dispatch_key, kernel, operator_name_);
AT_ASSERTM(dispatch_key != TensorTypeIds::undefined());
AT_ASSERTM(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", detail::dispatch_key_to_string(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
AT_ASSERTM(kernels_.is_left(), "Tried to register a kernel with dispatch key ", detail::dispatch_key_to_string(dispatch_key)," for operator ", operator_name_, ", which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
kernels_.left().set(dispatch_key, kernel, operator_name_);
}
/**
* Deregister the kernel for some dispatch key.
*
* @param dispatch_key Dispatch key to unregister. If this is
* TensorTypeIds::undefined(), it will deregister the fallback kernel.
* @param dispatch_key Dispatch key to unregister.
*/
void removeKernelIfExists(TensorTypeId dispatch_key) {
kernels_.removeIfExists(dispatch_key, operator_name_);
AT_ASSERTM(kernels_.is_left(), "Tried to remove the kernel for dispatch key ", detail::dispatch_key_to_string(dispatch_key), " for operator ", operator_name_, ", which only has a catch-all kernel.");
kernels_.left().removeIfExists(dispatch_key, operator_name_);
}
/**
* Register a catch-all kernel that is called for this operator
* independent of the inputs. An operator can have either
* a catch-all kernel or a set of kernels with concrete
* dispatch keys, not both.
*/
void setCatchallKernel(const DispatchTableEntry& kernel) {
if (kernels_.is_right()) {
AT_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator.");
} else {
AT_ASSERTM(0 == kernels_.left().size(), "Tried to register a catch-all kernel for operator ", operator_name_, " which already has kernels with dispatch keys. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
}
kernels_ = make_right<detail::KernelTable_, DispatchTableEntry>(kernel);
}
/**
* Remove the catch-all kernel.
*/
void removeCatchallKernel() {
AT_ASSERTM(kernels_.is_right(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
kernels_ = make_left<detail::KernelTable_, DispatchTableEntry>();
}
/**
@ -150,40 +172,40 @@ class DispatchTable final {
* @return Kernel function pointing to the right kernel for the given arguments.
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
if (C10_LIKELY(dispatch_strategy_.is_valid_)) {
TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack);
auto found = kernels_.lookup(dispatch_key);
if (nullptr != found) {
return *found;
}
return kernels_.map<const DispatchTableEntry&>(
[&] (const detail::KernelTable_& table) -> const DispatchTableEntry& {
// We have a dispatch table. Find the correct kernel for the inputs and return it.
// regular dispatch didn't find a kernel, let's check the fallback kernel.
const DispatchTableEntry* fallbackKernel = fallback_kernel();
if (nullptr != fallbackKernel) {
return *fallbackKernel;
}
AT_ASSERTM(dispatch_strategy_.is_valid_, "Operator ", operator_name_, " has an invalid dispatch key but kernels registered.");
// no kernel found and fallback kernel doesn't exist either
AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name_,
TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack, operator_name_);
auto found = table.lookup(dispatch_key);
AT_ASSERTM(nullptr != found, "Didn't find kernel to dispatch to for operator '", operator_name_,
"'. Tried to look up kernel for dispatch key '", detail::dispatch_key_to_string(dispatch_key),
"'. Registered dispatch keys are: ", list_all_dispatch_keys_());
} else {
// with an invalid dispatch key, only the fallback kernel is allowed.
const DispatchTableEntry* fallbackKernel = fallback_kernel();
"'. Registered dispatch keys are: ", listAllDispatchKeys());
AT_ASSERTM(kernels_.size() == ((nullptr == fallbackKernel)?0:1), "Cannot have an invalid dispatch key but registered kernels");
if (nullptr != fallbackKernel) {
return *fallbackKernel;
}
// no kernel registered and fallback kernel doesn't exist either
AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name_, "'");
return *found;
},
[] (const DispatchTableEntry& entry) -> const DispatchTableEntry& {
// We have a catch-all kernel. Just return it.
return entry;
}
);
}
bool isEmpty() const {
return 0 == kernels_.size();
return kernels_.map<bool>(
[] (const detail::KernelTable_& table) {return 0 == table.size();},
[] (const DispatchTableEntry&) {return false;}
);
}
std::string listAllDispatchKeys() const {
return kernels_.map<std::string>(
[] (const detail::KernelTable_& table) {return table.list_all_dispatch_keys();},
[] (const DispatchTableEntry&) {return "ANY";}
);
}
private:
@ -204,16 +226,16 @@ private:
// as long as they only have fallback kernels and no dispatched kernels.
bool is_valid_;
TensorTypeId get_dispatch_key(const Stack* stack) const {
TensorTypeId get_dispatch_key(const Stack* stack, const std::string& operator_name) const {
const IValue& first_tensor_arg = torch::jit::peek(
*stack,
0,
reverse_index_of_first_tensor_arg_
);
if (first_tensor_arg_is_tensor_list_) {
if (C10_UNLIKELY(first_tensor_arg_is_tensor_list_)) {
const auto& tensor_list = first_tensor_arg.toTensorListRef();
if (tensor_list.size() == 0) {
throw std::runtime_error("Tried to dispatch based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
throw std::runtime_error("Tried to dispatch operator " + operator_name + " based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
}
return tensor_list[0].type_id();
} else {
@ -222,10 +244,6 @@ private:
}
};
const DispatchTableEntry* fallback_kernel() const {
return kernels_.lookup(TensorTypeIds::undefined());
}
static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) {
for (size_t i = 0; i < schema.arguments().size(); ++i) {
const auto& type = schema.arguments()[i].type();
@ -242,15 +260,11 @@ private:
return {0, false, false};
}
std::string list_all_dispatch_keys_() const {
std::string result = kernels_.list_all_dispatch_keys();
if (fallback_kernel() != nullptr) {
result += ", FALLBACK";
}
return result;
}
detail::KernelTable_ kernels_;
// kernels_ either contains a dispatch table or
// a single catch-all kernel that is called for every backend
// The empty state (i.e. no kernels registered) is represented
// as an empty table.
either<detail::KernelTable_, DispatchTableEntry> kernels_;
DispatchStrategy dispatch_strategy_;
std::string operator_name_;
};

View File

@ -106,9 +106,9 @@ RegistrationHandleRAII Dispatcher::registerKernel(const OperatorHandle& op, Tens
return op.operatorIterator_->op.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, std::move(cache_creator_func)});
}
RegistrationHandleRAII Dispatcher::registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func) {
RegistrationHandleRAII Dispatcher::registerCatchallKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func) {
// note: this doesn't need the mutex to protect the iterator because write operations on the list keep iterators intact.
return op.operatorIterator_->op.registerFallbackKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func)});
return op.operatorIterator_->op.registerCatchallKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func)});
}
void Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {

View File

@ -126,7 +126,7 @@ public:
* @return A RAII object that manages the lifetime of the registration.
* Once that object is destructed, the kernel will be deregistered.
*/
RegistrationHandleRAII registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func);
RegistrationHandleRAII registerCatchallKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func);
/**
* Perform a dynamic dispatch and get the kernel for an operator.

View File

@ -3,28 +3,43 @@
namespace c10 {
namespace impl {
namespace {
std::string listAllDispatchKeys(const ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>& kernels) {
if (kernels.size() == 0) {
return "";
}
std::ostringstream str;
str << detail::dispatch_key_to_string(kernels.begin()->first);
for (auto iter = ++kernels.begin(); iter != kernels.end(); ++iter) {
str << ", " << detail::dispatch_key_to_string(iter->first);
}
return str.str();
}
}
OperatorEntry::OperatorEntry(FunctionSchema&& schema)
: schema_(std::move(schema))
, dispatchTable_(schema_)
, kernels_() {}
, kernels_(make_left<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>()) {}
void OperatorEntry::prepareForDeregistration() {
return dispatchTable_.read([&] (const DispatchTable& dispatchTable) {
if (!dispatchTable.isEmpty()) {
std::ostringstream str;
str << schema_;
AT_ERROR("Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", str.str());
AT_ERROR("Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", toString(schema_), ". Registered kernels for dispatch keys: ", dispatchTable.listAllDispatchKeys());
}
});
AT_ASSERTM(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels");
AT_ASSERTM(kernels_.is_left(), "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have a catch-all kernel. The operator schema is ", toString(schema_));
AT_ASSERTM(kernels_.left().size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". The operator schema is ", toString(schema_));
}
RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
AT_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", detail::dispatch_key_to_string(dispatch_key)," for an operator which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_));
// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
auto& k = kernels_[dispatch_key];
auto& k = kernels_.left()[dispatch_key];
k.push_front(kernel);
std::list<DispatchTableEntry>::iterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
@ -38,12 +53,72 @@ RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key,
});
}
RegistrationHandleRAII OperatorEntry::registerCatchallKernel(DispatchTableEntry kernel) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
if (kernels_.is_left()) {
AT_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_));
kernels_ = make_right<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>();
}
// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
auto& k = kernels_.right();
k.push_front(kernel);
std::list<DispatchTableEntry>::iterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
// that the dispatch table points to the newest kernel
updateCatchallDispatchTable_();
return RegistrationHandleRAII([this, inserted] {
// list iterators stay valid even if the list changes,
// so we can use the iterator to deregister the kernel from the list
deregisterCatchallKernel_(inserted);
});
}
void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list<DispatchTableEntry>::iterator kernel) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
AT_CHECK(kernels_.is_left(), "Tried deregister a kernel for dispatch key ", detail::dispatch_key_to_string(dispatch_key), " for an operator that only has a catch-all kernel. The operator schema is ", toString(schema_));
auto& kernels = kernels_.left();
auto found = kernels.find(dispatch_key);
AT_ASSERTM(found != kernels.end(), "Tried to deregister a kernel for dispatch key ", detail::dispatch_key_to_string(dispatch_key), " but there are no kernels registered for this dispatch key. The operator schema is ", toString(schema_));
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels.erase(found);
}
updateDispatchTable_(dispatch_key);
}
void OperatorEntry::deregisterCatchallKernel_(std::list<DispatchTableEntry>::iterator kernel) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
AT_CHECK(kernels_.is_right(), "Tried to deregister a catch-all kernel for an operator that doesn't have a catch-all kernel registered. The operator schema is ", toString(schema_));
auto& k = kernels_.right();
k.erase(kernel);
if (k.empty()) {
// the invariant says that the empty state is represented with is_left()
kernels_ = make_left<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>();
}
updateCatchallDispatchTable_();
}
void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) {
// precondition: kernelsMutex_ is locked
auto k = kernels_.find(dispatch_key);
AT_ASSERTM(kernels_.is_left(), "Can't update the dispatch table a dispatch key ", detail::dispatch_key_to_string(dispatch_key), " because the operator only has catch-all kernels. The operator schema is ", toString(schema_));
if (k == kernels_.end()) {
auto& kernels = kernels_.left();
auto k = kernels.find(dispatch_key);
if (k == kernels.end()) {
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
dispatchTable.removeKernelIfExists(dispatch_key);
});
@ -54,23 +129,18 @@ void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) {
}
}
void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list<DispatchTableEntry>::iterator kernel) {
std::unique_lock<std::mutex> lock(kernelsMutex_);
void OperatorEntry::updateCatchallDispatchTable_() {
// precondition: kernelsMutex_ is locked
auto found = kernels_.find(dispatch_key);
AT_ASSERTM(found != kernels_.end(), "Tried to deregister a kernel but there are no kernels registered for this dispatch key.");
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
if (kernels_.is_left()) {
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
dispatchTable.removeCatchallKernel();
});
} else {
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
dispatchTable.setCatchallKernel(kernels_.right().front());
});
}
updateDispatchTable_(dispatch_key);
}
RegistrationHandleRAII OperatorEntry::registerFallbackKernel(DispatchTableEntry kernel) {
return registerKernel(TensorTypeIds::undefined(), std::move(kernel));
}
}

View File

@ -31,19 +31,26 @@ public:
void prepareForDeregistration();
RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel);
RegistrationHandleRAII registerFallbackKernel(DispatchTableEntry kernel);
RegistrationHandleRAII registerCatchallKernel(DispatchTableEntry kernel);
private:
void deregisterKernel_(TensorTypeId dispatch_key, std::list<DispatchTableEntry>::iterator kernel);
void deregisterFallbackKernel_();
void deregisterCatchallKernel_(std::list<DispatchTableEntry>::iterator kernel);
FunctionSchema schema_;
// The dispatchTable stores the current kernel for each dispatch key
LeftRight<DispatchTable> dispatchTable_;
// The kernels map stores all registered kernels for a certain dispatch key.
// If an operator library gets loaded that overwrites already existing kernels,
// kernels_ is either:
// left: a kernel map listing mapping from a dispatch key to a list of all
// kernels for that operator, or it is
// right: a list of all catch-all kernels registered for this operator.
// An operator can only have either dispatched kernels or catch-all kernels,
// not both.
// In both cases, the list of kernels stores all registered kernels for the
// corresponding dispatch key (or for catch-all).
// If an operator library gets loaded that overwrites an already existing kernel,
// both kernels will be in that list but only the newer one will be in
// dispatchTable. If any of the kernels go away (say the library gets
// unloaded), we remove the kernel from this list and update the
@ -54,18 +61,25 @@ private:
// kernels is a larger data structure and accessed quite infrequently
// while dispatchTable is accessed often and should be kept small to fit
// into CPU caches.
// Invariants:
// - dispatchTable[dispatch_key] == kernels[dispatch_key].front()
// Invariants (assuming kernels_.is_left()):
// - dispatchTable[dispatch_key] == kernels_.left()[dispatch_key].front()
// - dispatchTable[dispatch_key] does not exist if and only if
// kernels[dispatch_key] does not exist
// - If kernels[dispatch_key] exists, then it has elements.
// kernels_.left()[dispatch_key] does not exist
// - If kernels_.left()[dispatch_key] exists, then it has elements.
// It is never an empty list.
ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>> kernels_;
std::mutex kernelsMutex_;
// Analogous invariants for kernels_.is_right().
// The empty state (i.e. no kernels registered) is represented as an empty
// map with kernels_.is_left().
c10::either<
ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, // dispatched kernels
std::list<DispatchTableEntry> // catch-all kernels
> kernels_;
std::mutex kernelsMutex_; // protects kernels_
// This function re-establishes the invariant that dispatchTable
// contains the front element from the kernels list for a given dispatch key.
void updateDispatchTable_(TensorTypeId dispatch_key);
void updateCatchallDispatchTable_();
};
}

View File

@ -21,7 +21,7 @@ public:
if (dispatch_key.has_value()) {
kernel_registration_handle_ = Dispatcher::singleton().registerKernel(op_.opHandle(), *dispatch_key, kernel, std::move(cache_creator));
} else {
kernel_registration_handle_ = Dispatcher::singleton().registerFallbackKernel(op_.opHandle(), kernel, std::move(cache_creator));
kernel_registration_handle_ = Dispatcher::singleton().registerCatchallKernel(op_.opHandle(), kernel, std::move(cache_creator));
}
}
}

View File

@ -37,7 +37,7 @@ struct MockKernel final : OperatorKernel {
private:
bool* called_;
};
TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) {
TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) {
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
@ -47,110 +47,61 @@ TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWro
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) {
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()));
{
auto inner_registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>());
// this registered a fallback kernel, but now that registration goes out of scope and deregisters it
}
TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCallsCatchallKernel) {
bool called = false;
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the catchall kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called);
callOp(*op, dummyTensor(TensorType2()));
EXPECT_TRUE(called);
}
TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) {
bool called = false;
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called));
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(TensorType2()));
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called).dispatchKey(TensorType1()));
}, "for an operator which already has a catch-all kernel registered");
}
TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) {
TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) {
bool called = false;
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called);
callOp(*op, dummyTensor(TensorType2()));
EXPECT_TRUE(called);
}
TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernelAndOtherKernelOutOfScope_whenCallingOp_thenCallsFallbackKernel) {
bool called = false;
bool other_called = false;
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
{
auto inner_registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&other_called).dispatchKey(TensorType2()));
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called).dispatchKey(TensorType1()));
}
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the catchall kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called);
callOp(*op, dummyTensor(TensorType2()));
EXPECT_TRUE(called);
EXPECT_FALSE(other_called);
}
TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) {
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel).dispatchKey(TensorType1()));
TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) {
bool called = false;
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called).dispatchKey(TensorType1()));
expectThrows<c10::Error>([&] {
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the catchall kernel
}, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys");
}
TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteringDispatchedKernelAndCallingOp_thenCallsCatchallKernel) {
bool called = false;
{
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called)); // note: no dispatch key means this is the catchall kernel
}
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called).dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called_kernel);
EXPECT_FALSE(called_fallback);
EXPECT_FALSE(called);
callOp(*op, dummyTensor(TensorType1()));
EXPECT_TRUE(called_kernel);
EXPECT_FALSE(called_fallback);
}
TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) {
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel).dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called_kernel);
EXPECT_FALSE(called_fallback);
callOp(*op, dummyTensor(TensorType2()));
EXPECT_FALSE(called_kernel);
EXPECT_TRUE(called_fallback);
}
TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) {
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel).dispatchKey(TensorType1()))
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called_kernel);
EXPECT_FALSE(called_fallback);
callOp(*op, dummyTensor(TensorType1()));
EXPECT_TRUE(called_kernel);
EXPECT_FALSE(called_fallback);
}
TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) {
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel).dispatchKey(TensorType1()))
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
EXPECT_FALSE(called_kernel);
EXPECT_FALSE(called_fallback);
callOp(*op, dummyTensor(TensorType2()));
EXPECT_FALSE(called_kernel);
EXPECT_TRUE(called_fallback);
EXPECT_TRUE(called);
}
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegistering_thenOnlyRegistersSchema) {
@ -199,15 +150,14 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw
}
TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) {
// as long as we don't register non-fallback kernels, ops without tensor arguments are fine
// as long as we don't register non-catchall kernels, ops without tensor arguments are fine
auto registrar = c10::RegisterOperators().op("_test::dummy() -> ()");
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value()); // assert schema is registered
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenRegistering_thenShowsWarning) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegistering_thenShowsWarning) {
auto registrar = c10::RegisterOperators()
.op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()));
@ -217,10 +167,10 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenRegistering_t
testing::internal::CaptureStderr();
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>().dispatchKey(TensorType1()));
std::string output = testing::internal::GetCapturedStderr();
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previously registered kernel with same dispatch key"));
EXPECT_THAT(output, testing::HasSubstr("that overwrote a previously registered kernel with the same dispatch key for the same operator"));
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1).dispatchKey(TensorType1()));
@ -234,7 +184,20 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCa
EXPECT_TRUE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenCalled_thenCallsNewerKernel) {
// Registering a second catch-all kernel for an operator that already has one
// must emit a warning saying the previous catch-all kernel was overwritten.
// NOTE(review): the test name says "whenCalled" but the body exercises
// registration only — consider renaming to "whenRegistering"; confirm first.
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenCalled_thenShowsWarning) {
  auto firstRegistrar = c10::RegisterOperators()
      .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>());

  auto foundOp = Dispatcher::singleton().findSchema("_test::dummy", "");
  ASSERT_TRUE(foundOp.has_value()); // assert schema is registered

  // Capture stderr around the second registration so we can inspect the warning text.
  testing::internal::CaptureStderr();
  c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>());
  std::string capturedWarning = testing::internal::GetCapturedStderr();
  EXPECT_THAT(capturedWarning, testing::HasSubstr("that overwrote a previously registered catch-all kernel for the same operator"));
}
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenCalled_thenCallsNewerKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1));
@ -248,7 +211,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenCalle
EXPECT_TRUE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1).dispatchKey(TensorType1()));
@ -264,7 +227,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenNewerKernelDe
EXPECT_FALSE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) {
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1));
@ -280,7 +243,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenNewer
EXPECT_FALSE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1).dispatchKey(TensorType1()));
@ -296,7 +259,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenOlderKernelDe
EXPECT_TRUE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) {
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(&called_kernel1));
@ -312,7 +275,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenOlder
EXPECT_TRUE(called_kernel2);
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
@ -330,7 +293,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenOlderAndThenN
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) {
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
@ -348,7 +311,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenOlder
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
@ -366,7 +329,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenNewerAndThenO
}, "Didn't find kernel to dispatch to for operator '_test::dummy'");
}
TEST(OperatorRegistrationTest, givenKernelsWithSameFallbackDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) {
bool called_kernel1 = false;
bool called_kernel2 = false;
auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
@ -581,6 +544,10 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
std::vector<std::string>(), [] (const std::vector<std::string>& v) {EXPECT_EQ(0, v.size());},
std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());},
"(str[] a) -> str[]");
testArgTypes<std::vector<Tensor>>::test(
std::vector<Tensor>({}), [] (const std::vector<Tensor>& v) {EXPECT_EQ(0, v.size());},
std::vector<Tensor>({}), [] (const IValue& v) {EXPECT_EQ(0, v.toTensorListRef().size());},
"(Tensor[] a) -> Tensor[]");
// list types (with non-empty list)

View File

@ -28,6 +28,7 @@ template<class T> using result_of_t = std::result_of_t<T>;
template<class T> using decay_t = std::decay_t<T>;
template<class T> using remove_const_t = std::remove_const_t<T>;
template<class T> using remove_pointer_t = std::remove_pointer_t<T>;
template<class... T> using common_type_t = std::common_type_t<T...>;
#else
template<bool B, class T, class F> using conditional_t = typename std::conditional<B, T, F>::type;
template<bool B, class T = void> using enable_if_t = typename std::enable_if<B, T>::type;
@ -38,6 +39,7 @@ template<class T> using result_of_t = typename std::result_of<T>::type;
template<class T> using decay_t = typename std::decay<T>::type;
template<class T> using remove_const_t = typename std::remove_const<T>::type;
template<class T> using remove_pointer_t = typename std::remove_pointer<T>::type;
template<class... T> using common_type_t = typename std::common_type<T...>::type;
#endif

View File

@ -130,6 +130,15 @@ class either final {
return std::move(right());
}
/// Applies one of the two mapping functions to the stored value, depending on
/// which side this either currently holds, and returns the mapped result.
/// @param leftMapFunc   invoked with the left value when a Left is stored
/// @param rightMapFunc  invoked with the right value when a Right is stored
/// @return Result produced by whichever mapper ran
template<class Result, class LeftMapFunc, class RightMapFunc>
Result map(LeftMapFunc&& leftMapFunc, RightMapFunc&& rightMapFunc) const {
  // Guard-style: handle the right side first, then fall through to the left.
  if (Side::left != _side) {
    return std::forward<RightMapFunc>(rightMapFunc)(_right);
  }
  return std::forward<LeftMapFunc>(leftMapFunc)(_left);
}
private:
union {
Left _left;