Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
free up dispatch key space (in C++) (#72402)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72402
The original PR had an array-out-of-bounds access in `DispatchKeyExtractor.cpp`, that wasn't caught by ASAN and appeared to only manifest in a subset of android internal tests. After fixing the OOB access (and adding more asserts), I confirmed that the android internal test passes.
Reland of D33255193 (20b8653dfa)
ghstack-source-id: 148830728
Test Plan:
Steps to test:
(1) connect to a mobile OD
(2) run `one_world android emulator android-29` in a terminal to start the android emulator
(3) In a separate terminal, run the test: `buck test //fbandroid/instrumentation_tests/com/facebook/pytorch/bi_xray:instrumentation_test -c test.external_runner=tpx -- --regex 'testBIXRayModel.*PyTorchBIXRayInstrumentationTest' --force-remote-execution --run-disabled`
I also ran `buck test fbandroid/mode/dbg //fbandroid/instrumentation_tests/com/facebook/pytorch/bi_xray:instrumentation_test`, which failed before and passed after the PR.
Reviewed By: albanD
Differential Revision: D34034848
fbshipit-source-id: 9677ee2c0a1afd1183896f7055009445712523c5
(cherry picked from commit 9ab9b12d355540ad0923c6869ed088ff6c21490c)
Committed by: PyTorch MergeBot
Parent commit: 4f8b986e28
Commit: 6690256021
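The out-of-bounds pattern described in the summary, and the assert-style guard the reland adds, can be illustrated with a small standalone C++ sketch (hypothetical names; this is not code from the diff below):

#include <array>
#include <cassert>
#include <cstdint>

enum class BackendBit : uint8_t { Invalid = 0, CPU, CUDA, End = CUDA };

constexpr std::size_t kNumBackends = static_cast<std::size_t>(BackendBit::End);

std::array<uint64_t, kNumBackends> per_backend_masks{};

uint64_t mask_for(BackendBit b) {
  // "- 1" skips the Invalid slot; if b were Invalid this would wrap around to a
  // huge unsigned value and index out of bounds -- the bug class the summary describes.
  auto idx = static_cast<std::size_t>(b) - 1;
  assert(idx < per_backend_masks.size()); // the kind of guard the reland adds
  return per_backend_masks[idx];
}

int main() {
  return mask_for(BackendBit::CPU) == 0 ? 0 : 1;
}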
@@ -28,8 +28,7 @@ constexpr auto kFunctorchWrappedTensors = DispatchKeySet({
constexpr auto kTensorSubclassLike = kFunctorchWrappedTensors | DispatchKeySet({
    DispatchKey::Batched,
    DispatchKey::SparseCPU,
    DispatchKey::SparseCUDA,
    DispatchKey::Sparse,
    DispatchKey::SparseCsrCPU,
    DispatchKey::SparseCsrCUDA,
    DispatchKey::Meta,

@@ -43,7 +43,6 @@ inline bool variable_excluded_from_dispatch() {
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd));
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}
@@ -6,11 +6,52 @@
namespace c10 {

void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
  // (1) update nonFallthroughKeys_
  if (has_fallthrough) {
    nonFallthroughKeys_ = nonFallthroughKeys_.remove(k);
  } else {
    nonFallthroughKeys_ = nonFallthroughKeys_.add(k);
  }
  // (2) update nonFallthroughKeysPerBackend_
  if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) {
    // This is a per-backend functionality key.
    // We need to figure out what the current backend is,
    // and only update the bitset for that backend.
    // subtracting 1 because the first backend should have index 0 (CPU),
    // But the enum starts with BackendComponent::InvalidBit.
    auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1;
    TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size());
    if (has_fallthrough) {
      nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k);
    } else {
      nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k);
    }

    // Set requiresBitsetPerBackend_ accordingly
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) {
      if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) {
        requiresBitsetPerBackend_ = true;
        return;
      }
    }
    requiresBitsetPerBackend_ = false;
    return;
  } else {
    // Otherwise, if a fallthrough is set for a functionality that isn't per backend,
    // Then we update the fallthrough bitset for EVERY backend.
    // TODO: we could probably optimize this by only lazily updating these values
    // the first time that we see requiresBitsetPerBackend_ = true
    // (which should almost never happen)
    if (has_fallthrough) {
      for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
        nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k);
      }
    } else {
      for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
        nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k);
      }
    }
  }
}

std::string DispatchKeyExtractor::dumpState() const {
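For illustration, a hedged compile-time sketch of how the per-backend slot above gets chosen, assuming a post-PR c10 checkout (toFunctionalityKey and toBackendComponent are the constexpr helpers added further down in this diff):

#include <c10/core/DispatchKey.h>

using namespace c10;

// SparseCUDA is a per-backend runtime key: its functionality bit is Sparse and
// its backend component is CUDABit, so setOperatorHasFallthroughForKey would
// update nonFallthroughKeysPerBackend_[CUDABit - 1] for it.
static_assert(toFunctionalityKey(DispatchKey::SparseCUDA) == DispatchKey::Sparse, "functionality of SparseCUDA");
static_assert(toBackendComponent(DispatchKey::SparseCUDA) == BackendComponent::CUDABit, "backend of SparseCUDA");
static_assert(isPerBackendFunctionalityKey(DispatchKey::Sparse), "Sparse is customizable per backend");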
@@ -156,14 +156,24 @@ public:
      }
    });
    // Keys that are fallthrough should be skipped
    return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  template<class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

@@ -193,7 +203,12 @@ private:
  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
    : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
    , nonFallthroughKeys_(DispatchKeySet::FULL) {}
    , nonFallthroughKeys_(DispatchKeySet::FULL)
    , requiresBitsetPerBackend_(false) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // this is a bitset that has ones for each argument index which has to be
  // considered for dispatch. This avoids having to iterate over the stack

@@ -205,8 +220,14 @@ private:
  // fallthrough
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of keys for which the operator does NOT have fallthrough kernel.
  // Set of functionality keys for which the operator does NOT have fallthrough kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
  // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
  bool requiresBitsetPerBackend_;
};

}
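A standalone model (not c10's real implementation) of the fast-path/slow-path choice above: fallthrough kernels are skipped by masking the candidate key set with a "non-fallthrough" mask, and that mask may have to be looked up per backend:

#include <array>
#include <cstdint>
#include <iostream>

using KeyMask = uint64_t;

struct Extractor {
  KeyMask non_fallthrough = ~0ull;                  // fast path: one mask for all backends
  std::array<KeyMask, 2> per_backend{~0ull, ~0ull}; // slow path: one mask per backend
  bool per_backend_needed = false;

  KeyMask compute(KeyMask candidate_keys, int backend_idx) const {
    KeyMask mask = per_backend_needed ? per_backend[backend_idx] : non_fallthrough;
    return candidate_keys & mask; // keys whose kernels are fallthrough drop out
  }
};

int main() {
  Extractor e;
  e.per_backend_needed = true;
  e.per_backend[1] &= ~(1ull << 3); // pretend bit 3 is a fallthrough only on backend 1
  std::cout << std::hex << e.compute(0b1010, 0) << " vs " << e.compute(0b1010, 1) << "\n";
}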
@@ -267,14 +267,15 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name)
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
  std::lock_guard<std::mutex> lock(mutex_);

  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  TORCH_CHECK(
    !backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].kernel.isValid(),
    !backendFallbackKernels_[idx].kernel.isValid(),
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
    backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].debug, ", new registration ", debug
    backendFallbackKernels_[idx].debug, ", new registration ", debug
  );
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
  // cannot be unboxed
  backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);

@@ -288,7 +289,8 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
  std::lock_guard<std::mutex> lock(mutex_);

  backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = {};
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  backendFallbackKernels_[idx] = {};

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);

@@ -291,7 +291,7 @@ private:
  // Map from namespace to debug string (saying, e.g., where the library was defined)
  ska::flat_hash_map<std::string, std::string> libraries_;

  std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
  std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

  std::unique_ptr<detail::RegistrationListenerList> listeners_;
  std::mutex mutex_;

@@ -531,8 +531,7 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH
  detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  // By default, when there're no high-frequency or non-sampled callbacks,
  // RecordFunction is pre-sampled as a perf optimization;

@@ -553,7 +552,7 @@ template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

@@ -561,7 +560,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
  const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
  const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  bool pre_sampled = false;
  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {

@@ -593,7 +592,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
  const auto& kernel = entry.lookup(dispatchKeySet);
  return kernel.callBoxed(op, dispatchKeySet, stack);
}
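A standalone sketch (illustrative names, not c10's) of why the raw enum value stops being a usable table index once alias keys and per-backend instances exist, and why the accesses above now go through an explicit key-to-slot mapping:

#include <array>
#include <cstddef>
#include <cstdint>

enum class Key : uint16_t { Undefined, Dense, Sparse, EndOfFunctionality, CPU, CUDA, SomeAlias };

constexpr int table_index_for(Key k) {
  switch (k) {
    case Key::Undefined: return 0;
    case Key::Dense:     return 1;
    case Key::Sparse:    return 2;
    case Key::CPU:       return 3; // per-backend instances get their own slots
    case Key::CUDA:      return 4;
    default:             return -1; // alias/sentinel keys own no slot
  }
}

std::array<int, 5> fallbacks{}; // sized like num_runtime_entries, not like the enum

int lookup(Key k) {
  int idx = table_index_for(k);
  return idx < 0 ? 0 : fallbacks.at(static_cast<std::size_t>(idx));
}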
@@ -283,7 +283,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
  }

  // 3. Backend fallback
  auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
  auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
  if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
    return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
  }

@@ -299,10 +299,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
  const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
  if (C10_UNLIKELY(dispatch_ix == -1)) {
    return;
  }
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
  dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
  dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

@@ -329,8 +326,12 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
  }
  // Note [Refresh Runtime Autograd entries in dispatchTable_]
  // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
  // In theory, we should only have to check if the given runtime key has "dense" functionality,
  // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
  // However, there are some backends that should be included in this set that don't have the dense key set.
  // E.g. DispatchKey::Meta, DispatchKey::ORT.
  if (c10::isBackendDispatchKey(dispatch_key)) {
    DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
    DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
    updateDispatchTableEntry_(dispatcher, autograd_key);
  }
}

@@ -357,8 +358,9 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
  // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd
  // or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
  // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
    updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
  updateDispatchTable_(dispatcher, DispatchKey::Undefined);
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    updateDispatchTable_(dispatcher, k);
  }
}

@@ -371,9 +373,10 @@ void OperatorEntry::checkInvariants() const {
  for (const auto& kv : kernels_) {
    TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState());
  }
  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
    auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast<DispatchKey>(iter));
    TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]),
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
    auto idx = getDispatchTableIndexForDispatchKey(k);
    TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
      "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
      "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
  }

@@ -384,7 +387,8 @@ std::string OperatorEntry::listAllDispatchKeys() const {
  str << "[";

  bool has_kernels = false;
  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
    auto iter = getDispatchTableIndexForDispatchKey(k);
    if (!dispatchTable_[iter].isValid()) {
      continue;
    }

@@ -443,8 +447,12 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const {
// updateDispatchTableFull_ would update the dispatch table to be)
std::string OperatorEntry::dumpComputedTable() const {
  std::ostringstream oss;
  for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
    auto k = static_cast<DispatchKey>(i);
  // Need to handle Undefined separately, because it's a runtime key that can't be represented
  // in a DispatchKeySet.
  std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);

  for (auto k : runtime_keys) {
    auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
    if (kernel_prov.first.kernel.isValid()) {
      oss << toString(k) << ": "
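A hedged usage sketch, assuming a post-PR c10 build, of the iteration pattern the updated OperatorEntry code uses: Undefined is listed explicitly because it cannot be represented in a DispatchKeySet, and then the FULL set is iterated to visit every remaining runtime key:

#include <c10/core/DispatchKeySet.h>
#include <iostream>
#include <vector>

int main() {
  std::vector<c10::DispatchKey> runtime_keys = {c10::DispatchKey::Undefined};
  for (auto k : c10::DispatchKeySet(c10::DispatchKeySet::FULL)) {
    runtime_keys.push_back(k); // yields per-backend runtime keys, never alias keys
  }
  std::cout << runtime_keys.size() << " runtime keys\n";
}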
@@ -173,11 +173,8 @@ public:
  [[noreturn]] void reportError(DispatchKey dispatchKey) const;

  const KernelFunction& lookup(DispatchKey k) const {
    const auto idx = getDispatchTableIndexForDispatchKey(k);
    if (C10_UNLIKELY(idx == -1)) {
      reportError(k);
    }
  const KernelFunction& lookup(DispatchKeySet ks) const {
    const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
    const auto& kernel = dispatchTable_[idx];
    // A valid kernel *always* has a boxed kernel and *may* have an
    // unboxed kernel. However, we typically do unboxed calls in at::

@@ -187,7 +184,7 @@ public:
    // in the common case.
    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
      if (!kernel.isValid()) {
        reportError(k);
        reportError(ks.highestPriorityTypeId());
      }
    }
    return kernel;

@@ -211,7 +208,7 @@ private:
  OperatorName name_;
  c10::optional<AnnotatedSchema> schema_;

  std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
  DispatchKeyExtractor dispatchKeyExtractor_;

  // kernels_ stores all registered kernels for the corresponding dispatch key
@@ -591,7 +591,7 @@ TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) {
void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) {
  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(key))
    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(toBackendComponent(key)))
    .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});

@@ -1791,22 +1791,22 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) {
TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool sparsecpu_called, math_called = false;
  bool fpga_called, math_called = false;
  auto m = MAKE_TORCH_LIBRARY(test);
  m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; }));
  m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; }));
  m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });

  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
  ASSERT_TRUE(op.has_value());

  {
    callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
    ASSERT_TRUE(sparsecpu_called);
    callOp(*op, dummyTensor(c10::DispatchKey::FPGA));
    ASSERT_TRUE(fpga_called);
  }

  {
    expectThrows<c10::Error>([&] {
      callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
      callOp(*op, dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true));
    }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
  }
}

@@ -1849,18 +1849,15 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
  }

  {
    // TODO(#43908): currently this will fallthrough AutogradPrivateUse1 then call catchall kernel
    // at AutogradCPU, while backend extenders are indeed expecting to call PrivateUse1 kernel.
    // This confusing behavior is caused by us registering fallthrough as backend fallback for
    // Autograd keys. Note users could always work around this by registering the same kernel to
    // AutogradPrivateUse1 as shown below until we support it.
    auto op = Dispatcher::singleton().findOp({"test::fn", ""});
    ASSERT_TRUE(op.has_value());
    catchall_called = false;
    privateuse1_called = false;
    callOp(*op,
      dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
      dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
    ASSERT_TRUE(catchall_called);
    ASSERT_FALSE(catchall_called);
    ASSERT_TRUE(privateuse1_called);
  }

  m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; });

@@ -1876,6 +1873,27 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
  }
}

TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) {
  bool math_called = false;
  bool cpu_called = false;
  auto m = MAKE_TORCH_LIBRARY(test);
  m.def("fn(Tensor dummy) -> Tensor");
  m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
  m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });

  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
  ASSERT_TRUE(op.has_value());

  {
    math_called = cpu_called = false;
    // Meta should redispatch to the AutogradOther backend,
    // which the composite kernel should be registered to.
    callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true));
    ASSERT_TRUE(math_called);
    ASSERT_FALSE(cpu_called);
  }
}

TEST(NewOperatorRegistrationTest, dispatchMultiple) {
  bool cpu_called = false;
  bool cuda_called = false;
@@ -1,14 +1,47 @@
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>

#include <unordered_map>

namespace c10 {

const char* toString(BackendComponent t) {
  switch (t) {
    case BackendComponent::CPUBit:
      return "CPUBit";
    case BackendComponent::CUDABit:
      return "CUDABit";
    case BackendComponent::HIPBit:
      return "HIPBit";
    case BackendComponent::XLABit:
      return "XLABit";
    case BackendComponent::LazyBit:
      return "LazyBit";
    case BackendComponent::XPUBit:
      return "XPUBit";
    case BackendComponent::MLCBit:
      return "MLCBit";
    case BackendComponent::HPUBit:
      return "HPUBit";
    case BackendComponent::VEBit:
      return "VEBit";
    case BackendComponent::PrivateUse1Bit:
      return "PrivateUse1Bit";
    case BackendComponent::PrivateUse2Bit:
      return "PrivateUse2Bit";
    case BackendComponent::PrivateUse3Bit:
      return "PrivateUse3Bit";
    case BackendComponent::InvalidBit:
      return "InvalidBit";
    default:
      return "UNKNOWN_BACKEND_BIT";
  }
}

const char* toString(DispatchKey t) {
  switch (t) {
    case DispatchKey::Undefined:
      return "Undefined";

    case DispatchKey::CPU:
      return "CPU";
    case DispatchKey::CUDA:

@@ -101,8 +134,6 @@ const char* toString(DispatchKey t) {
      return "AutogradMLC";
    case DispatchKey::AutogradHPU:
      return "AutogradHPU";
    case DispatchKey::AutogradNestedTensor:
      return "AutogradNestedTensor";
    case DispatchKey::AutogradPrivateUse1:
      return "AutogradPrivateUse1";
    case DispatchKey::AutogradPrivateUse2:

@@ -111,6 +142,8 @@ const char* toString(DispatchKey t) {
      return "AutogradPrivateUse3";
    case DispatchKey::AutogradOther:
      return "AutogradOther";
    case DispatchKey::AutogradNestedTensor:
      return "AutogradNestedTensor";

    case DispatchKey::ZeroTensor:
      return "ZeroTensor";

@@ -168,6 +201,15 @@ const char* toString(DispatchKey t) {
    case DispatchKey::FuncTorchBatched:
      return "FuncTorchBatched";

    case DispatchKey::Dense:
      return "Dense";
    case DispatchKey::Quantized:
      return "Quantized";
    case DispatchKey::Sparse:
      return "Sparse";
    case DispatchKey::AutogradFunctionality:
      return "AutogradFunctionality";

    default:
      return "UNKNOWN_TENSOR_TYPE_ID";
  }
@@ -176,76 +218,37 @@ const char* toString(DispatchKey t) {
std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
  return str << toString(rhs);
}
std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
  return str << toString(rhs);
}

// for a given backend key, return the associated autograd key.
// for non-backend keys, return AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
//
DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
  switch (t) {
    case DispatchKey::CPU:
      return DispatchKey::AutogradCPU;
    case DispatchKey::XPU:
      return DispatchKey::AutogradXPU;
    case DispatchKey::CUDA:
      return DispatchKey::AutogradCUDA;
    case DispatchKey::XLA:
      return DispatchKey::AutogradXLA;
    case DispatchKey::Lazy:
      return DispatchKey::AutogradLazy;
    case DispatchKey::MLC:
      return DispatchKey::AutogradMLC;
    case DispatchKey::HPU:
      return DispatchKey::AutogradHPU;
    case DispatchKey::NestedTensor:
      return DispatchKey::AutogradNestedTensor;
    case DispatchKey::PrivateUse1:
      return DispatchKey::AutogradPrivateUse1;
    case DispatchKey::PrivateUse2:
      return DispatchKey::AutogradPrivateUse2;
    case DispatchKey::PrivateUse3:
      return DispatchKey::AutogradPrivateUse3;
    default:
      return DispatchKey::AutogradOther;
  }
DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
  // We want this to return an autograd key. We're relying on the fact that
  // getAutogradRelatedKeySetFromBackend returns an autograd key +
  // ADInplaceOrView, and autograd has higher precedence. The core mapping from
  // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
  // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
  // hotpath function, and we want to make sure that it doesn't have to
  // construct any DispatchKeySets at runtime.
  return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
}

c10::DispatchKey parseDispatchKey(const std::string& k) {
  static std::unordered_map<std::string, c10::DispatchKey> key_map = {
      {"Undefined", c10::DispatchKey::Undefined},
      {"CPU", c10::DispatchKey::CPU},
      {"CUDA", c10::DispatchKey::CUDA},
      {"HIP", c10::DispatchKey::HIP},
      {"Dense", c10::DispatchKey::Dense},
      {"FPGA", c10::DispatchKey::FPGA},
      {"ORT", c10::DispatchKey::ORT},
      {"XLA", c10::DispatchKey::XLA},
      {"MLC", c10::DispatchKey::MLC},
      {"Vulkan", c10::DispatchKey::Vulkan},
      {"Metal", c10::DispatchKey::Metal},
      {"XPU", c10::DispatchKey::XPU},
      {"HPU", c10::DispatchKey::HPU},
      {"VE", c10::DispatchKey::VE},
      {"Lazy", c10::DispatchKey::Lazy},
      {"Meta", c10::DispatchKey::Meta},
      {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
      {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
      {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
      {"Quantized", c10::DispatchKey::Quantized},
      {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
      {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
      {"SparseCPU", c10::DispatchKey::SparseCPU},
      {"SparseCUDA", c10::DispatchKey::SparseCUDA},
      {"SparseHIP", c10::DispatchKey::SparseHIP},
      {"SparseXPU", c10::DispatchKey::SparseXPU},
      {"SparseVE", c10::DispatchKey::SparseVE},
      {"Sparse", c10::DispatchKey::Sparse},
      {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
      {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
      {"NestedTensor", c10::DispatchKey::NestedTensor},
      {"PrivateUse1", c10::DispatchKey::PrivateUse1},
      {"PrivateUse2", c10::DispatchKey::PrivateUse2},
      {"PrivateUse3", c10::DispatchKey::PrivateUse3},
      {"BackendSelect", c10::DispatchKey::BackendSelect},
      {"Python", c10::DispatchKey::Python},
      {"Named", c10::DispatchKey::Named},
@@ -256,17 +259,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       c10::DispatchKey::FuncTorchDynamicLayerBackMode},
      {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
      {"AutogradOther", c10::DispatchKey::AutogradOther},
      {"AutogradCPU", c10::DispatchKey::AutogradCPU},
      {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
      {"AutogradXLA", c10::DispatchKey::AutogradXLA},
      {"AutogradLazy", c10::DispatchKey::AutogradLazy},
      {"AutogradXPU", c10::DispatchKey::AutogradXPU},
      {"AutogradMLC", c10::DispatchKey::AutogradMLC},
      {"AutogradHPU", c10::DispatchKey::AutogradHPU},
      {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
      {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
      {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
      {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
      {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
      {"Tracer", c10::DispatchKey::Tracer},
      {"AutocastCPU", c10::DispatchKey::AutocastCPU},
      {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},

@@ -280,6 +274,41 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
      {"TESTING_ONLY_GenericWrapper",
       c10::DispatchKey::TESTING_ONLY_GenericWrapper},
      {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},

      {"CPU", c10::DispatchKey::CPU},
      {"CUDA", c10::DispatchKey::CUDA},
      {"HIP", c10::DispatchKey::HIP},
      {"XLA", c10::DispatchKey::XLA},
      {"MLC", c10::DispatchKey::MLC},
      {"XPU", c10::DispatchKey::XPU},
      {"HPU", c10::DispatchKey::HPU},
      {"Lazy", c10::DispatchKey::Lazy},
      {"NestedTensor", c10::DispatchKey::NestedTensor},
      {"PrivateUse1", c10::DispatchKey::PrivateUse1},
      {"PrivateUse2", c10::DispatchKey::PrivateUse2},
      {"PrivateUse3", c10::DispatchKey::PrivateUse3},

      {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
      {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
      {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},

      {"SparseCPU", c10::DispatchKey::SparseCPU},
      {"SparseCUDA", c10::DispatchKey::SparseCUDA},
      {"SparseHIP", c10::DispatchKey::SparseHIP},
      {"SparseXPU", c10::DispatchKey::SparseXPU},
      {"SparseVE", c10::DispatchKey::SparseVE},

      {"AutogradCPU", c10::DispatchKey::AutogradCPU},
      {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
      {"AutogradXLA", c10::DispatchKey::AutogradXLA},
      {"AutogradLazy", c10::DispatchKey::AutogradLazy},
      {"AutogradXPU", c10::DispatchKey::AutogradXPU},
      {"AutogradMLC", c10::DispatchKey::AutogradMLC},
      {"AutogradHPU", c10::DispatchKey::AutogradHPU},
      {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
      {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
      {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},

      {"Autograd", c10::DispatchKey::Autograd},
      {"CompositeImplicitAutograd",
       c10::DispatchKey::CompositeImplicitAutograd},
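A hedged usage sketch, assuming a build linked against post-PR c10: getAutogradKeyFromBackend now takes a BackendComponent, so call sites holding a runtime DispatchKey decompose it first (as the updated registration test earlier in this diff does):

#include <c10/core/DispatchKey.h>
#include <cassert>

int main() {
  using namespace c10;
  DispatchKey dense_key = DispatchKey::CPU;
  // CPU decomposes to (Dense, CPUBit); its autograd counterpart is expected to be AutogradCPU.
  DispatchKey ag = getAutogradKeyFromBackend(toBackendComponent(dense_key));
  assert(ag == DispatchKey::AutogradCPU);
}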
@@ -9,20 +9,98 @@
namespace c10 {

// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.

// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.

enum class BackendComponent : uint8_t {

  // A "backend" is colloquially used to refer to handlers for dispatch
  // which actually implement the numerics of an operation in question.
  //
  // Due to the nature of the enum, these backends are specified in
  // an ordered way, but for most backends this order is not semantically
  // meaningful (e.g., it's valid to reorder these backends without changing
  // semantics). The only situation when backend ordering is meaningful
  // is when the backend participates in multiple dispatch with another
  // backend; e.g., CPU and CUDA (cuda must have higher priority).

  // These keys don't correspond to individual kernels.
  // Instead, they represent the backends that are allowed to override specific
  // pieces of functionality:
  // - dense kernels (e.g. DispatchKey::CPU)
  // - sparse kernels (e.g. DispatchKey::SparseCPU)
  // - quantized kernels (e.g. DispatchKey::QuantizedCPU)
  // - autograd kernels (e.g. DispatchKey::AutogradCPU)
  // We reserve space in the runtime operator table for this full cross product
  // of
  // [backends in this enum] x [keys below that are explicitly marked as having
  // per-backend functionality]

  InvalidBit = 0,
  CPUBit,
  CUDABit,
  HIPBit,
  XLABit,
  MLCBit,
  XPUBit,
  HPUBit,
  VEBit,
  LazyBit,
  PrivateUse1Bit,
  PrivateUse2Bit,
  PrivateUse3Bit,
  // Define an alias to represent end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  // (But you shouldn't: private use keys should have higher precedence than
  // all built-in keys)
  EndOfBackendKeys = PrivateUse3Bit,
};

// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Traditional
// backends like CPU and CUDA get dispatch keys; however, so do
// "wrapping" layers like Variable (for autograd handling).
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet. Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend (e.g. Dense, Sparse,
//     AutogradFunctionality)
// (4) per-backend instances of customizable functionalities (e.g. CPU,
//     SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
//     ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.

// See Note [DispatchKeySet Internal Representation] for more details.
//
// NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
enum class DispatchKey : uint8_t {
enum class DispatchKey : uint16_t {

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // This is not a "real" tensor id, but it exists to give us a "nullopt"
  // This is not a "real" functionality, but it exists to give us a "nullopt"
  // element we can return for cases when a DispatchKeySet contains no elements.
  // You can think a more semantically accurate definition of DispatchKey is:
  //
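A hedged compile-time illustration of the classification note above, assuming a post-PR c10 checkout; the predicates used are the constexpr helpers defined later in this same header:

#include <c10/core/DispatchKey.h>

using namespace c10;

static_assert(!isPerBackendFunctionalityKey(DispatchKey::FPGA), "category (1): plain backend, one bit and one slot");
static_assert(isPerBackendFunctionalityKey(DispatchKey::Sparse), "category (3): one bit, but customizable per backend");
static_assert(toFunctionalityKey(DispatchKey::QuantizedXPU) == DispatchKey::Quantized, "category (4): per-backend instance");
static_assert(toBackendComponent(DispatchKey::QuantizedXPU) == BackendComponent::XPUBit, "its backend component");
static_assert(isAliasDispatchKey(DispatchKey::CompositeImplicitAutograd), "category (5): alias key, no bit or slot of its own");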
@ -38,24 +116,31 @@ enum class DispatchKey : uint8_t {
|
||||
// this will get eliminated, but for now it's convenient)
|
||||
CatchAll = Undefined,
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ BACKENDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
||||
// A "backend" is colloquially used to refer to handlers for dispatch
|
||||
// which actually implement the numerics of an operation in question.
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
|
||||
// Every value in the enum (up to EndOfFunctionalityKeys)
|
||||
// corresponds to an individual "functionality" that can be dispatched to.
|
||||
// This is represented in the DispatchKeySet by assigning each of these enum
|
||||
// values
|
||||
// to each of the remaining (64 - len(BackendComponent)) bits.
|
||||
//
|
||||
// Due to the nature of the enum, these backends are specified in
|
||||
// an ordered way, but for most backends this order is not semantically
|
||||
// meaningful (e.g., it's valid to reorder these backends without changing
|
||||
// semantics). The only situation when backend ordering is meaningful
|
||||
// is when the backend participates in multiple dispatch with another
|
||||
// backend; e.g., CPU and SparseCPU (sparse must have
|
||||
// higher priority).
|
||||
// Most of these functionalities have a single handler assigned to them,
|
||||
// making them "runtime keys".
|
||||
// That map to a single slot in the runtime operator table.
|
||||
//
|
||||
// A few functionalities are allowed to be customizable per backend.
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys] for details.
|
||||
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||||
Dense,
|
||||
|
||||
// Below are non-extensible backends.
|
||||
// These are backends that currently don't have their own overrides for
|
||||
// Autograd/Sparse/Quantized kernels,
|
||||
// and we therefore don't waste space in the runtime operator table allocating
|
||||
// space for them.
|
||||
// If any of these backends ever need to customize, e.g., Autograd, then we'll
|
||||
// need to add a DispatchKey::*Bit for them.
|
||||
|
||||
// Here are backends which you think of as traditionally specifying
|
||||
// how to implement operations on some device.
|
||||
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
|
||||
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
|
||||
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
|
||||
// CUDA]
|
||||
FPGA, // Xilinx support lives out of tree at
|
||||
// https://gitlab.com/pytorch-complex/vitis_kernels
|
||||
|
||||
@ -67,14 +152,8 @@ enum class DispatchKey : uint8_t {
|
||||
// - aten/src/ATen/test/extension_backend_test.cpp
|
||||
ORT,
|
||||
|
||||
XLA, // lives out of tree at https://github.com/pytorch/xla
|
||||
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
|
||||
Vulkan,
|
||||
Metal,
|
||||
XPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
HPU, // For out of tree & closed source integration of HPU / Habana
|
||||
VE, // For out of tree & closed source integration of SX-Aurora / NEC
|
||||
Lazy, // For lazy tensor backends
|
||||
|
||||
// A meta tensor is a tensor without any data associated with it. (They
|
||||
// have also colloquially been referred to as tensors on the "null" device).
|
||||
@ -83,11 +162,8 @@ enum class DispatchKey : uint8_t {
|
||||
// tensor with the output shape and dtype, but wouldn't actually add anything.
|
||||
Meta,
|
||||
|
||||
// Here are backends which specify more specialized operators
|
||||
// based on the dtype of the tensor.
|
||||
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
|
||||
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
|
||||
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||||
Quantized,
|
||||
|
||||
// This backend is to support custom RNGs; it lets you go
|
||||
// to a different kernel if you pass in a generator that is not a
|
||||
@ -106,31 +182,29 @@ enum class DispatchKey : uint8_t {
|
||||
// the corresponding dense tensors, and must be handled before them.
|
||||
MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
|
||||
// NB: not to be confused with MKLDNN, which is Caffe2 only
|
||||
SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
|
||||
SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
|
||||
SparseHIP, // TODO: I think this is not actually used, due to Note
|
||||
// [Masquerading as CUDA]
|
||||
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
|
||||
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||||
Sparse,
|
||||
|
||||
SparseCsrCPU,
|
||||
SparseCsrCUDA,
|
||||
|
||||
// Note [Non-Customizable Backend Keys]
|
||||
// Every key above here is considered a "non-customizable backend".
|
||||
// These are backends that will work correctly with autograd, but
|
||||
// but currently don't require separate implementations
|
||||
// for autograd sparse or quantized kernels.
|
||||
// Any new backends that don't need to be customized should go above here.
|
||||
// If an existing backend needs to e.g. override autograd, then we can
|
||||
// consider promoting it into the "BackendComponent" enum
|
||||
//
|
||||
// For all intents and purposes from the perspective of DispatchKeySet,
|
||||
// "non-customizable backend" keys are treated the same way
|
||||
// as other functionality keys
|
||||
EndOfNonCustomizableBackends = SparseCsrCUDA,
|
||||
|
||||
NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
|
||||
|
||||
// Here are reserved backends for user-defined backends, see Note [Private use
|
||||
// DispatchKey]
|
||||
// To see some example about how to use this, check out ORT
|
||||
PrivateUse1,
|
||||
PrivateUse2,
|
||||
PrivateUse3,
|
||||
|
||||
// Define an alias key to represent end of backend dispatch keys.
|
||||
// If you add new backend keys after PrivateUse3, please also update it here.
|
||||
// (But you shouldn't: private use keys should have higher precedence than
|
||||
// all built-in keys)
|
||||
EndOfBackendKeys = PrivateUse3,
|
||||
|
||||
// In some situations, it is not immediately obvious what the correct
|
||||
// backend for function is, because the function in question doesn't
|
||||
// have any "tensor" arguments. In this case, a BackendSelect function
|
||||
@ -233,20 +307,18 @@ enum class DispatchKey : uint8_t {
|
||||
// AutogradOther key. We can add specific autograd key for those backends
|
||||
// upon request.
|
||||
AutogradOther,
|
||||
AutogradCPU,
|
||||
AutogradCUDA,
|
||||
AutogradXLA,
|
||||
AutogradLazy,
|
||||
AutogradXPU,
|
||||
AutogradMLC,
|
||||
AutogradHPU,
|
||||
AutogradNestedTensor, // lives out of tree at
|
||||
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||||
AutogradFunctionality,
|
||||
|
||||
// NestedTensor is an example of something that isn't a "real backend"
|
||||
// (because it mostly consists of redispatching kernels)
|
||||
// but it would like to override autograd functionality in C++.
|
||||
// We can handle cases like this by adding an extra functionality key
|
||||
// exclusively for handling autograd for NestedTensor.
|
||||
// lives out of tree at
|
||||
// https://github.com/pytorch/nestedtensor
|
||||
// Here are some reserved pre-autograd keys for user-defined backends, see
|
||||
// Note [Private use DispatchKey]
|
||||
AutogradPrivateUse1,
|
||||
AutogradPrivateUse2,
|
||||
AutogradPrivateUse3,
|
||||
AutogradNestedTensor,
|
||||
|
||||
Tracer,
|
||||
|
||||
@ -299,9 +371,100 @@ enum class DispatchKey : uint8_t {
|
||||
TESTING_ONLY_GenericMode,
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
||||
NumDispatchKeys, // Sentinel, end of runtime keys.
|
||||
EndOfFunctionalityKeys, // End of functionality keys.
|
||||
|
||||
// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
|
||||
// Here are backends which you think of as traditionally specifying
|
||||
// how to implement operations on some device.
|
||||
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
StartOfDenseBackends,
|
||||
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
|
||||
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
|
||||
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
|
||||
// CUDA]
|
||||
XLA, // lives out of tree at https://github.com/pytorch/xla
|
||||
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
|
||||
XPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
HPU, // For out of tree & closed source integration of HPU / Habana
|
||||
VE, // For out of tree & closed source integration of SX-Aurora / NEC
|
||||
Lazy, // For lazy tensor backends
|
||||
// Here are reserved backends for user-defined backends, see Note [Private use
|
||||
// DispatchKey]
|
||||
// To see some example about how to use this, check out ORT
|
||||
PrivateUse1,
|
||||
PrivateUse2,
|
||||
PrivateUse3,
|
||||
EndOfDenseBackends = PrivateUse3,
|
||||
|
||||
// ~~~~~~~~~~~~~~ "Quantized" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~ //
|
||||
// keys starting with an _ are not currently used,
|
||||
// but are needed to ensure that every backend is indexed correctly.
|
||||
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
StartOfQuantizedBackends,
|
||||
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
|
||||
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
|
||||
_QuantizedHIP,
|
||||
_QuantizedXLA,
|
||||
_QuantizedMLC,
|
||||
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
_QuantizedHPU,
|
||||
_QuantizedVE,
|
||||
_QuantizedLazy,
|
||||
_QuantizedPrivateUse1,
|
||||
_QuantizedPrivateUse2,
|
||||
_QuantizedPrivateUse3,
|
||||
EndOfQuantizedBackends = _QuantizedPrivateUse3,
|
||||
|
||||
// ~~~~~~~~~~~~~~ "Sparse" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ //
|
||||
// keys starting with an _ are not currently used,
|
||||
// but are needed to ensure that every backend is indexed correctly.
|
||||
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
StartOfSparseBackends,
|
||||
SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
|
||||
SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
|
||||
SparseHIP, // TODO: I think this is not actually used, due to Note
|
||||
// [Masquerading as CUDA]
|
||||
_SparseXLA,
|
||||
_SparseMLC,
|
||||
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
_SparseHPU,
|
||||
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
|
||||
_SparseLazy,
|
||||
_SparsePrivateUse1,
|
||||
_SparsePrivateUse2,
|
||||
_SparsePrivateUse3,
|
||||
EndOfSparseBackends = _SparsePrivateUse3,
|
||||
|
||||
// ~~~~~~~~~~~~~~ "Autograd" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~ //
|
||||
// keys starting with an _ are not currently used,
|
||||
// but are needed to ensure that every backend is indexed correctly.
|
||||
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
StartOfAutogradBackends,
|
||||
AutogradCPU,
|
||||
AutogradCUDA,
|
||||
_AutogradHIP,
|
||||
AutogradXLA,
|
||||
AutogradMLC,
|
||||
AutogradXPU,
|
||||
AutogradHPU,
|
||||
_AutogradVE,
|
||||
AutogradLazy,
|
||||
// Here are some reserved pre-autograd keys for user-defined backends, see
|
||||
// Note [Private use DispatchKey]
|
||||
AutogradPrivateUse1,
|
||||
AutogradPrivateUse2,
|
||||
AutogradPrivateUse3,
|
||||
EndOfAutogradBackends = AutogradPrivateUse3,
|
||||
// If we add a new per-backend functionality key that has higher priority
|
||||
// than Autograd, then this key should be updated.
|
||||
EndOfRuntimeBackendKeys = EndOfAutogradBackends,
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
||||
// Note [Alias Dispatch Keys]
|
||||
// Alias dispatch keys are synthetic dispatch keys which map to multiple
|
||||
// runtime dispatch keys. Alisa keys have precedence, but they are always
|
||||
// lower precedence than runtime keys. You can register a kernel to an
|
||||
@ -321,6 +484,7 @@ enum class DispatchKey : uint8_t {
|
||||
|
||||
// Define an alias key to represent end of alias dispatch keys.
|
||||
// If you add new alias keys after Autograd, please also update it here.
|
||||
StartOfAliasKeys = Autograd,
|
||||
EndOfAliasKeys = CompositeExplicitAutograd, //
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
||||
@ -360,54 +524,83 @@ enum class DispatchKey : uint8_t {
|
||||
// built-in autograd formulas for operators are not appropriate.
|
||||
|
||||
static_assert(
|
||||
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
|
||||
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
|
||||
(static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
|
||||
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
|
||||
"The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
|
||||
" both map to backend and functionality bits"
|
||||
" into a 64-bit bitmask; you must have less than 64 total entries between them");
|
||||
|
||||
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
|
||||
/**
|
||||
* The method below maps the dispatch key in the enum DispatchKey to an
|
||||
* integer index in the dispatchTable_ array in OperatorEntry. The array
|
||||
* is trimmed for mobile to reduce peak memory usage since it's
|
||||
* unnecessary to reserve additional space for dispatch keys that will
|
||||
* never be used on mobile.
|
||||
*/
|
||||
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
|
||||
switch (dk) {
|
||||
case DispatchKey::Undefined:
|
||||
return 0;
|
||||
case DispatchKey::CPU:
|
||||
return 1;
|
||||
case DispatchKey::QuantizedCPU:
|
||||
return 2;
|
||||
case DispatchKey::SparseCPU:
|
||||
return 3;
|
||||
case DispatchKey::BackendSelect:
|
||||
return 4;
|
||||
case DispatchKey::ADInplaceOrView:
|
||||
return 5;
|
||||
case DispatchKey::AutogradOther:
|
||||
return 6;
|
||||
case DispatchKey::AutogradCPU:
|
||||
return 7;
|
||||
case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys.
|
||||
return 8;
|
||||
default:
|
||||
return -1;
|
||||
// Check if a DispatchKey is an alias mapping to other runtime keys.
|
||||
constexpr bool isAliasDispatchKey(DispatchKey k) {
|
||||
return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
|
||||
}
|
||||
|
||||
// [Note: Per-Backend Functionality Dispatch Keys]
|
||||
// Check if a DispatchKey is a per-backend functionality key
|
||||
// Any functionalities that can be customized per-backend should be added here.
|
||||
// These keys correspond to functionalities that can be customized indivually
|
||||
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
|
||||
// they map to (# backends) slots in the operator table.
|
||||
// Each of these keys also has a separate set of "runtime keys" in the dispatch
|
||||
// key enum, per backend, which *do* map to the individual operator table slots.
|
||||
// For example, the "Sparse" key maps to an individual bit in the
|
||||
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
|
||||
// slots in the runtime operator table.
|
||||
|
||||
constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
|
||||
if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
|
||||
k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#else
|
||||
/**
|
||||
* For the server use-case, make this a simple pass-through.
|
||||
*/
|
||||
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
|
||||
return static_cast<int>(dk);
|
||||
|
||||
// Note that this includes Undefined in the total count.
|
||||
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
|
||||
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
|
||||
// In the above example, there are 3 total functionality keys.
|
||||
constexpr uint8_t num_functionality_keys =
|
||||
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);
|
||||
|
||||
// Note [No More Than 16 Backends]
|
||||
// Search for this note to find places in the code where the "no more than 16
|
||||
// backends" invariant is baked in.
|
||||
static_assert(
|
||||
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
|
||||
"BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
|
||||
there are a few places where this invariant is baked in");
|
||||
|
||||
constexpr uint8_t numPerBackendFunctionalityKeys() {
|
||||
uint8_t count = 0;
|
||||
for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
|
||||
if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
|
||||
++count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
|
||||
// See [Note: Trimmed Mobile Dispatch Keys]
|
||||
constexpr uint8_t num_backends = 1; // Only CPU
|
||||
constexpr uint16_t num_runtime_entries = 8;
|
||||
#else
|
||||
constexpr uint8_t num_backends =
|
||||
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);
|
||||
constexpr uint16_t num_runtime_entries = num_functionality_keys +
|
||||
(numPerBackendFunctionalityKeys() * (num_backends - 1));
|
||||
#endif
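// A worked example of the non-mobile sizing above (with illustrative counts;
// the real values come from the DispatchKey and BackendComponent enums): if
// there were 40 functionality keys, 13 backend bits, and 4 per-backend
// functionalities (Dense, Quantized, Sparse, AutogradFunctionality), then
//   num_runtime_entries = 40 + 4 * (13 - 1) = 88
// i.e. each per-backend functionality reserves one operator-table slot per
// backend, while every other functionality contributes exactly one slot.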
|
||||
|
||||
C10_API const char* toString(DispatchKey);
|
||||
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
|
||||
// See Note [No More Than 16 Backends]
|
||||
constexpr uint16_t full_backend_mask =
|
||||
(static_cast<uint16_t>(1) << num_backends) - 1;
|
||||
|
||||
C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t);
|
||||
C10_API const char* toString(DispatchKey);
|
||||
C10_API const char* toString(BackendComponent);
|
||||
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
|
||||
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);
|
||||
|
||||
C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);
|
||||
|
||||
// Parses a string into a dispatch key.
|
||||
// If the string cannot be correctly parsed, throws an exception.
|
||||
@ -420,10 +613,86 @@ C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
|
||||
// torch::dispatch(torch::kCPU, ...) is also valid.
|
||||
constexpr DispatchKey kAutograd = DispatchKey::Autograd;
|
||||
|
||||
// Check if a DispatchKey is an alias mapping to other runtime keys.
|
||||
inline bool isAliasDispatchKey(DispatchKey k) {
|
||||
return k > DispatchKey::NumDispatchKeys && k <= DispatchKey::EndOfAliasKeys;
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
// This function relies on the invariant that the dispatch keys between
|
||||
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
|
||||
// in the same order as `BackendComponent`.
|
||||
constexpr BackendComponent toBackendComponent(DispatchKey k) {
|
||||
if (k >= DispatchKey::StartOfDenseBackends &&
|
||||
k <= DispatchKey::EndOfDenseBackends) {
|
||||
return static_cast<BackendComponent>(
|
||||
static_cast<uint8_t>(k) -
|
||||
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
|
||||
} else if (
|
||||
k >= DispatchKey::StartOfQuantizedBackends &&
|
||||
k <= DispatchKey::EndOfQuantizedBackends) {
|
||||
return static_cast<BackendComponent>(
|
||||
static_cast<uint8_t>(k) -
|
||||
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
|
||||
} else if (
|
||||
k >= DispatchKey::StartOfSparseBackends &&
|
||||
k <= DispatchKey::EndOfSparseBackends) {
|
||||
return static_cast<BackendComponent>(
|
||||
static_cast<uint8_t>(k) -
|
||||
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
|
||||
} else if (
|
||||
k >= DispatchKey::StartOfAutogradBackends &&
|
||||
k <= DispatchKey::EndOfAutogradBackends) {
|
||||
return static_cast<BackendComponent>(
|
||||
static_cast<uint8_t>(k) -
|
||||
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends));
|
||||
} else {
|
||||
return BackendComponent::InvalidBit;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
|
||||
if (k <= DispatchKey::EndOfFunctionalityKeys) {
|
||||
return k;
|
||||
} else if (k <= DispatchKey::EndOfDenseBackends) {
|
||||
return DispatchKey::Dense;
|
||||
} else if (k <= DispatchKey::EndOfQuantizedBackends) {
|
||||
return DispatchKey::Quantized;
|
||||
} else if (k <= DispatchKey::EndOfSparseBackends) {
|
||||
return DispatchKey::Sparse;
|
||||
} else if (k <= DispatchKey::EndOfAutogradBackends) {
|
||||
return DispatchKey::AutogradFunctionality;
|
||||
} else {
|
||||
return DispatchKey::Undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// Given (DispatchKey::Dense, DispatchKey::CUDABit), returns DispatchKey::CUDA
|
||||
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
|
||||
// This function relies on the invariant that the dispatch keys between
|
||||
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
|
||||
// in the same order as `BackendComponent`.
|
||||
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
|
||||
DispatchKey functionality_k,
|
||||
BackendComponent backend_k) {
|
||||
if (functionality_k == DispatchKey::Dense) {
|
||||
return static_cast<DispatchKey>(
|
||||
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
|
||||
static_cast<uint8_t>(backend_k));
|
||||
}
|
||||
if (functionality_k == DispatchKey::Sparse) {
|
||||
return static_cast<DispatchKey>(
|
||||
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
|
||||
static_cast<uint8_t>(backend_k));
|
||||
}
|
||||
if (functionality_k == DispatchKey::Quantized) {
|
||||
return static_cast<DispatchKey>(
|
||||
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
|
||||
static_cast<uint8_t>(backend_k));
|
||||
}
|
||||
if (functionality_k == DispatchKey::AutogradFunctionality) {
|
||||
return static_cast<DispatchKey>(
|
||||
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends) +
|
||||
static_cast<uint8_t>(backend_k));
|
||||
}
|
||||
return DispatchKey::Undefined;
|
||||
}
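// Taken together, the helpers above let the dispatcher round-trip a per-backend
// runtime key through its (functionality, backend) decomposition. A small
// sketch, assuming the ordering invariant noted above holds:
//
//   auto functionality = toFunctionalityKey(DispatchKey::SparseCPU);  // DispatchKey::Sparse
//   auto backend = toBackendComponent(DispatchKey::SparseCPU);        // BackendComponent::CPUBit
//   auto runtime_key =
//       toRuntimePerBackendFunctionalityKey(functionality, backend);  // DispatchKey::SparseCPU again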
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch {
|
||||
|
@ -1,37 +1,29 @@
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// backend_dispatch_keyset should include all runtime backend keys.
|
||||
// backend_dispatch_keyset includes all dispatch keys that map to backends.
|
||||
// Alias key DispatchKey::CompositeExplicitAutograd maps to
|
||||
// backend_dispatch_keyset NestedTensor has been explicitly removed due to
|
||||
// incompatibility with some kernels, such as structured kernels, that use the
|
||||
// DefaultBackend key.
|
||||
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
|
||||
DispatchKeySet({
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::XLA,
|
||||
DispatchKey::Lazy,
|
||||
DispatchKey::XPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
DispatchKey::PrivateUse2,
|
||||
DispatchKey::PrivateUse3,
|
||||
DispatchKey::MLC,
|
||||
DispatchKey::HPU,
|
||||
DispatchKey::ORT,
|
||||
DispatchKey::Meta,
|
||||
});
|
||||
// backend_dispatch_keyset
|
||||
constexpr DispatchKeySet backend_dispatch_keyset =
|
||||
autogradother_backends | DispatchKeySet(DispatchKey::Dense);
|
||||
|
||||
bool isBackendDispatchKey(DispatchKey t) {
|
||||
return t != DispatchKey::Undefined
|
||||
// See Note [No Alias Keys in DispatchKeySet]
|
||||
&& !isAliasDispatchKey(t) && backend_dispatch_keyset.has(t);
|
||||
&& !isAliasDispatchKey(t)
|
||||
// Note [NestedTensor Not Included in Backend Keys]
|
||||
// NestedTensor has been explicitly removed from the "backend keyset" due
|
||||
// to incompatibility with some kernels, so we don't want it to be
|
||||
// included in CompositeImplicitAutograd or CompositeExplicitAutograd
|
||||
// kernels.
|
||||
&& t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
|
||||
}
|
||||
|
||||
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
|
||||
// autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
|
||||
// maps to math_dispatch_keyset.
|
||||
// maps to [math_dispatch_keyset x full_backend_mask]
|
||||
constexpr DispatchKeySet math_dispatch_keyset =
|
||||
backend_dispatch_keyset | autograd_dispatch_keyset;
|
||||
|
||||
@ -39,7 +31,12 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
|
||||
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
|
||||
switch (t) {
|
||||
case DispatchKey::Autograd:
|
||||
return autograd_dispatch_keyset;
|
||||
// See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
|
||||
// That's why we OR it with a mask of the backend bits here.
|
||||
// getRuntimeDispatchKeySet() expects to return a keyset of runtime
|
||||
// dispatch keys, like AutogradCPU, but that requires having backend bits.
|
||||
return autograd_dispatch_keyset |
|
||||
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
|
||||
case DispatchKey::CompositeImplicitAutograd:
|
||||
return math_dispatch_keyset;
|
||||
case DispatchKey::CompositeExplicitAutograd:
|
||||
@ -53,11 +50,13 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
|
||||
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
|
||||
switch (t) {
|
||||
case DispatchKey::Autograd:
|
||||
return autograd_dispatch_keyset.has(k);
|
||||
return autograd_dispatch_keyset.has(toFunctionalityKey(k));
|
||||
case DispatchKey::CompositeImplicitAutograd:
|
||||
return math_dispatch_keyset.has(k);
|
||||
// See Note [NestedTensor Not Included in Backend Keys]
|
||||
return k != DispatchKey::NestedTensor && math_dispatch_keyset.has(k);
|
||||
case DispatchKey::CompositeExplicitAutograd:
|
||||
return backend_dispatch_keyset.has(k);
|
||||
// See Note [NestedTensor Not Included in Backend Keys]
|
||||
return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
|
||||
default:
|
||||
return t == k;
|
||||
}
|
||||
@ -79,8 +78,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
|
||||
return DispatchKeySet(DispatchKey::MLC);
|
||||
case DispatchKey::AutogradHPU:
|
||||
return DispatchKeySet(DispatchKey::HPU);
|
||||
case DispatchKey::AutogradNestedTensor:
|
||||
return DispatchKeySet(DispatchKey::NestedTensor);
|
||||
case DispatchKey::AutogradXPU:
|
||||
return DispatchKeySet(DispatchKey::XPU);
|
||||
case DispatchKey::AutogradPrivateUse1:
|
||||
@ -96,23 +93,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
|
||||
}
|
||||
}
|
||||
|
||||
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
|
||||
switch (t) {
|
||||
case DispatchKey::CPU:
|
||||
return DispatchKeySet(DispatchKey::AutocastCPU);
|
||||
case DispatchKey::CUDA:
|
||||
case DispatchKey::XLA:
|
||||
return DispatchKeySet(DispatchKey::AutocastCUDA);
|
||||
default:
|
||||
return DispatchKeySet();
|
||||
}
|
||||
}
|
||||
|
||||
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
|
||||
return DispatchKeySet(
|
||||
{DispatchKey::ADInplaceOrView, getAutogradKeyFromBackend(t)});
|
||||
}
|
||||
|
||||
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
|
||||
return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
|
||||
}
|
||||
@ -129,18 +109,167 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
|
||||
return os;
|
||||
}
|
||||
os << "DispatchKeySet(";
|
||||
DispatchKey tid;
|
||||
bool first = true;
|
||||
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
|
||||
for (auto k : ts) {
|
||||
if (!first) {
|
||||
os << ", ";
|
||||
}
|
||||
os << tid;
|
||||
ts = ts.remove(tid);
|
||||
os << k;
|
||||
first = false;
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
|
||||
TORCH_INTERNAL_ASSERT(next_functionality_ >= num_backends);
|
||||
TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
|
||||
TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends);
|
||||
|
||||
// Create a masked version of the set representation to ignore previous
|
||||
// keys that we've iterated through.
|
||||
uint64_t masked_functionality_bits =
|
||||
llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
|
||||
uint64_t masked_backend_bits =
|
||||
llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
|
||||
*data_ptr_;
|
||||
|
||||
uint64_t first_functionality_idx =
|
||||
llvm::findFirstSet(masked_functionality_bits);
|
||||
uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);
|
||||
|
||||
// If there are no keys, set to end iterator value
|
||||
if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
|
||||
next_functionality_ == iterator::end_iter_mask_val) {
|
||||
// Set up state to be the same as end()
|
||||
next_functionality_ = iterator::end_iter_mask_val;
|
||||
current_dispatchkey_idx_ = iterator::end_iter_key_val;
|
||||
next_backend_ = 0;
|
||||
current_backendcomponent_idx_ = iterator::end_iter_key_val;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// The +1 is because of DispatchKey::Undefined and
|
||||
// BackendComponent::InvalidBit
|
||||
auto new_next_functionality = first_functionality_idx + 1;
|
||||
auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
|
||||
// and the -num_backends is because the first <num_backends> bits in the
|
||||
// keyset are not Dispatch Keys.
|
||||
auto next_dispatchkey_idx = new_next_functionality - num_backends;
|
||||
|
||||
// If the current functionality bit is a per-backend bit, we need special
|
||||
// handling
|
||||
if (isPerBackendFunctionalityKey(
|
||||
static_cast<DispatchKey>(next_dispatchkey_idx))) {
|
||||
// case 1: if the current backend is undefined, then there is no valid
|
||||
// backend instance of this functionality key so we can skip it.
|
||||
if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
|
||||
// increment the functionality mask so we skip the current functionality
|
||||
// bit on the next increment.
|
||||
next_functionality_ = new_next_functionality;
|
||||
++(*this);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Otherwise, at this point we know what the current backend and
|
||||
// functionality bits are.
|
||||
current_dispatchkey_idx_ = next_dispatchkey_idx;
|
||||
current_backendcomponent_idx_ = new_backendcomponent_idx;
|
||||
|
||||
// Next, we need to set up the masks for the next increment.
|
||||
uint64_t next_backendcomponent_bits =
|
||||
llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
|
||||
full_backend_mask & *data_ptr_;
|
||||
uint64_t next_backendcomponent_idx =
|
||||
llvm::findFirstSet(next_backendcomponent_bits);
|
||||
if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
|
||||
// case 2: the current backend is valid, but there is not another backend
|
||||
// in the keyset. In this case, we need to bump the functionality mask and
|
||||
// reset the backend mask for the next increment
|
||||
next_functionality_ = new_next_functionality;
|
||||
next_backend_ = 0;
|
||||
} else {
|
||||
// case 3: we have another backend to iterate over. We want to iterate
|
||||
// over the same functionality bit next time, but a different backend bit.
|
||||
next_backend_ = first_backendcomponent_idx + 1;
|
||||
}
|
||||
} else {
|
||||
// Functionality bits that aren't per backend are simpler to handle. We can
|
||||
// ignore the backend bits.
|
||||
TORCH_INTERNAL_ASSERT(next_backend_ == 0);
|
||||
current_dispatchkey_idx_ = next_dispatchkey_idx;
|
||||
next_functionality_ = new_next_functionality;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
|
||||
initializeFunctionalityOffsetsAndMasks() {
|
||||
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
|
||||
offsets_and_masks;
|
||||
// manually set the first entry, which corresponds to Undefined.
|
||||
offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
|
||||
// loop through every functionality key (aside from Undefined).
|
||||
for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
|
||||
// functionality_idx should be Dense -> 1, ...
|
||||
auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
|
||||
auto k = static_cast<DispatchKey>(functionality_idx);
|
||||
|
||||
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
|
||||
// [Note: Trimmed Mobile Dispatch Keys]
|
||||
uint16_t mask = 0;
|
||||
uint16_t offset = 0;
|
||||
switch (k) {
  case DispatchKey::Undefined:
    offset = 0;
    break;
  case DispatchKey::CPU:
    offset = 1;
    break;
  case DispatchKey::QuantizedCPU:
    offset = 2;
    break;
  case DispatchKey::SparseCPU:
    offset = 3;
    break;
  case DispatchKey::BackendSelect:
    offset = 4;
    break;
  case DispatchKey::ADInplaceOrView:
    offset = 5;
    break;
  case DispatchKey::AutogradOther:
    offset = 6;
    break;
  case DispatchKey::AutogradCPU:
    offset = 7;
    break;
  default:
    // All other keys which are unsupported on mobile will get sent
    // to the undefined kernel, causing them to error.
    offset = 0;
    break;
}
|
||||
offsets_and_masks[functionality_idx] =
|
||||
FunctionalityOffsetAndMask(offset, 0);
|
||||
}
|
||||
#else
|
||||
// If the previous functionality was not per-backend, then we can just
|
||||
// increment the previous offset. Otherwise, the next offset =
|
||||
// previous_offset + num_backends.
|
||||
auto next_offset = prev_offset_and_mask.offset +
|
||||
(prev_offset_and_mask.mask == 0 ? 1 : num_backends);
|
||||
// the mask is used in the runtime index calculation to find the offset of
|
||||
// the backend. For non-per-backend functionalities, this offset should
|
||||
// always be 0. Otherwise, we need to get the index of the backend (which we
|
||||
// can do using a backend mask).
|
||||
auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
|
||||
offsets_and_masks[functionality_idx] =
|
||||
FunctionalityOffsetAndMask(next_offset, next_mask);
|
||||
}
|
||||
// Sanity check that the computed offset index of the last functionality key
|
||||
// is correct. This assumes that the highest priority functionality key is not
|
||||
// per backend.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offsets_and_masks[num_functionality_keys - 1].offset ==
|
||||
(num_runtime_entries - 1),
|
||||
"num_runtime_entries: ",
|
||||
num_runtime_entries,
|
||||
"last_offset: ",
|
||||
offsets_and_masks[num_functionality_keys - 1].offset);
|
||||
#endif
|
||||
return offsets_and_masks;
|
||||
}
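// To make the table above concrete, the first few (non-mobile) entries come out
// roughly like this, assuming Undefined = 0 and Dense = 1 as in the loop comment:
//   Undefined -> offset 0,                 mask 0
//   Dense     -> offset 1,                 mask full_backend_mask (per-backend)
//   next key  -> offset 1 + num_backends,  mask 0 or full_backend_mask
// i.e. a per-backend functionality reserves num_backends consecutive slots, and
// getDispatchTableIndexForDispatchKeySet() later adds the backend's index into
// that reserved block.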
|
||||
|
||||
} // namespace c10
|
||||
|
@ -1,5 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
@ -8,29 +7,147 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
struct FunctionalityOffsetAndMask {
|
||||
// empty constructor shouldn't be used; only needed to initialize
|
||||
// the array before populating it.
|
||||
FunctionalityOffsetAndMask() {}
|
||||
FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
|
||||
: offset(offset), mask(mask) {}
|
||||
// This needs to be big enough to cover the size of the operator table.
|
||||
uint16_t offset;
|
||||
// See Note [No More Than 16 Backends]
|
||||
// This mask needs to be big enough to mask all of the backend bits.
|
||||
// We probably don't ever want to have more than 16 backend bits, so uint16_t
|
||||
// should be enough.
|
||||
uint16_t mask;
|
||||
};
|
||||
static_assert(
|
||||
c10::num_runtime_entries < 65536,
|
||||
"The dispatcher currently only supports up to 2^16 runtime entries");
|
||||
|
||||
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
|
||||
initializeFunctionalityOffsetsAndMasks();
|
||||
|
||||
C10_ALWAYS_INLINE static const std::
|
||||
array<FunctionalityOffsetAndMask, num_functionality_keys>&
|
||||
offsetsAndMasks() {
|
||||
static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
|
||||
return offsets_and_masks_;
|
||||
}
|
||||
|
||||
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
|
||||
// "functionality" bits and "backend bits", and every tensor holds its own
|
||||
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
|
||||
// keyset on every input tensor, or’ing them together, and dispatching to a
|
||||
// specific piece of functionality. The functionality bits are *ordered*. When
|
||||
// multiple functionality bits are set, we use the highest priority
|
||||
// functionality. Similarly, multiple backend bits can theoretically be set if
|
||||
// you call an operator with multiple tensors from different devices (e.g. CPU
|
||||
// and CUDA), although support for mixed device dispatch is limited (the only
|
||||
// kernels that gracefully handle mixed device inputs for now are cuda kernels
|
||||
// that take in a scalar cpu tensor).
|
||||
|
||||
// A representation of a set of DispatchKeys. A tensor may have multiple
|
||||
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
|
||||
// DispatchKeySet specifies what type ids apply. The internal representation is
|
||||
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
|
||||
//
|
||||
// Note that DispatchKeys are ordered; thus, we can ask questions like "what is
|
||||
// the highest priority DispatchKey in the set"? (The set itself is not
|
||||
// ordered; two sets with the same ids will always have the ids ordered in the
|
||||
// same way.)
|
||||
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
|
||||
// "what is the highest priority DispatchKey in the set"? (The set itself is
|
||||
// not ordered; two sets with the same ids will always have the ids ordered in
|
||||
// the same way.)
|
||||
//
|
||||
// At the moment, there are no nontrivial uses of this set; tensors are always
|
||||
// singletons. In the near future, this set will represent variable? + tensor
|
||||
// type id. In the far future, it will be requires grad? + profiling? +
|
||||
// tracing? + lazy? + tensor type id.
|
||||
// Note [DispatchKeySet Internal Representation]
|
||||
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
|
||||
// that get passed around at runtime.
|
||||
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
|
||||
// and individual dispatch keys.
|
||||
//
|
||||
// (The difference between variable and requires grad, is that
|
||||
// there are currently three states a tensor can be:
|
||||
// 1. Not a variable
|
||||
// 2. Variable with requires_grad=False
|
||||
// 3. Variable with requires_grad=True
|
||||
// Eventually, we want to kill state (1), and only dispatch to autograd
|
||||
// handling code if one of the inputs requires grad.)
|
||||
// First: why do we have this distinction, and why not map every dispatch key
|
||||
// directly to a bit? This is mostly because we have several types of
|
||||
// functionalities that different backends would like to customize. For example,
|
||||
// we have:
|
||||
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
|
||||
// - "Sparse": SparseCPU, SparseCUDA, ...
|
||||
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
|
||||
// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
|
||||
// The problem is that total number of keys grows quadratically with [#
|
||||
// backends] x [# functionalities], making it very difficult to map each key
|
||||
// directly to a bit in a bitset without dramatically increasing the size of the
|
||||
// bitset over time.
|
||||
//
|
||||
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
|
||||
// 5 categories.
|
||||
//
|
||||
// (1) "Building block" keys
|
||||
// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//     CUDABit)
// (b) functionalities: (per-backend) functionality-bit DispatchKeys
//     (e.g. AutogradFunctionality, Sparse, Dense)
|
||||
// (2) "Runtime" keys
|
||||
// (a) "non-customizable backends" (e.g. FPGA)
|
||||
// (b) "non-customizable functionalities" (e.g. Functionalize)
|
||||
// (c) "per-backend instances of customizable functionalities" (e.g. CPU,
|
||||
// SparseCPU, AutogradCPU)
|
||||
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
|
||||
//
|
||||
// (1) Building block keys always correspond to individual bits in a
|
||||
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
|
||||
// runtime keys. e.g.
|
||||
// auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
|
||||
// DispatchKey::Dense});
|
||||
// // The keyset has the runtime dense-cpu key.
|
||||
// dense_cpu_ks.has(DispatchKey::CPU);
|
||||
// // And it contains the building block keys too.
|
||||
// dense_cpu_ks.has(DispatchKey::CPUBit);
|
||||
// dense_cpu_ks.has(DispatchKey::Dense);
|
||||
//
|
||||
// Not every backend and not every functionality counts as a "building block
|
||||
// key". This is mostly to give us more levers to pull in the design space.
|
||||
// Backend keys and functionality keys that count as "building blocks" will
|
||||
// contribute to a full cross product of functionality that can be overridden.
|
||||
//
|
||||
// For example, right now we have at least 12 "backend" building blocks (CPU,
|
||||
// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
|
||||
// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
|
||||
// every dispatcher operator to be customized in up to 12*4 different ways. Each
|
||||
// of those requires a slot in the operator table of every dispatcher operator.
|
||||
// Not every piece of functionality necessarily needs to be customizable
|
||||
// per-backend, and not every backend necessarily needs to be able to customize
|
||||
// every type of functionality.
|
||||
//
|
||||
//
|
||||
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
|
||||
// dispatch table, and you can directly register kernels to a runtime dispatch
|
||||
// key.
|
||||
//
|
||||
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
|
||||
// you can think of the corresponding runtime dispatch keys as "instances" of
|
||||
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
|
||||
// runtime instances of the "Dense" building block key.
|
||||
|
||||
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
|
||||
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
|
||||
// customizable per backend.
|
||||
// In order to do so, we'd need to promote it to a per-backend functionality
|
||||
// "building block" key.
|
||||
// - non-customizable backends (e.g. FPGA) can NOT customize existing
|
||||
// functionality like Sparse, Autograd, etc.
|
||||
// In order to do so, we'd need to promote it to a backend "building block"
|
||||
// key.
|
||||
//
|
||||
// In both cases, these keys directly correspond to runtime slots in the
|
||||
// operator table.
|
||||
//
|
||||
//
|
||||
// (3) "Alias" keys
|
||||
// See Note [Alias Dispatch Keys]
|
||||
//
|
||||
// Final note: for anyone making future changes to the Dispatcher +
|
||||
// DispatchKeySet internals, there's a closed PR with a basic
|
||||
// python-implementation of the Dispatcher that might be useful in quickly
|
||||
// testing out and validating changes. See it at
|
||||
// https://github.com/pytorch/pytorch/pull/68743
|
||||
|
||||
// An undefined tensor is one with an empty tensor type set.
|
||||
class DispatchKeySet final {
|
||||
public:
|
||||
@ -41,29 +158,146 @@ class DispatchKeySet final {
|
||||
// NB: default constructor representation as zero is MANDATORY as
|
||||
// use of DispatchKeySet in TLS requires this.
|
||||
constexpr DispatchKeySet() : repr_(0) {}
|
||||
|
||||
constexpr DispatchKeySet(Full)
|
||||
: repr_(std::numeric_limits<decltype(repr_)>::max()) {}
|
||||
: repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
|
||||
|
||||
constexpr DispatchKeySet(FullAfter, DispatchKey t)
|
||||
// LSB after t are OK, but not t itself.
|
||||
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
|
||||
// "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
|
||||
// Quantized > Dense). But backends don't really have an ordering.
|
||||
// Therefore, we're enforcing that FullAfter can only be used on
|
||||
// "functionality" keys.
|
||||
: repr_(
|
||||
(1ULL
|
||||
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
|
||||
1)) -
|
||||
1) {}
|
||||
|
||||
// Public version of DispatchKeySet(uint64_t) API; external users
|
||||
// must be explicit when they do this!
|
||||
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
|
||||
explicit constexpr DispatchKeySet(DispatchKey t)
|
||||
: repr_(
|
||||
t == DispatchKey::Undefined
|
||||
? 0
|
||||
: 1ULL << (static_cast<uint8_t>(t) - 1)) {}
|
||||
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
|
||||
: repr_(0) {
|
||||
for (auto k : ks) {
|
||||
repr_ |= DispatchKeySet(k).repr_;
|
||||
|
||||
constexpr explicit DispatchKeySet(BackendComponent k) {
|
||||
if (k == BackendComponent::InvalidBit) {
|
||||
repr_ = 0;
|
||||
} else {
|
||||
repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr explicit DispatchKeySet(DispatchKey k) {
|
||||
if (k == DispatchKey::Undefined) {
|
||||
// Case 1: handle Undefined specifically
|
||||
repr_ = 0;
|
||||
} else if (k <= DispatchKey::EndOfFunctionalityKeys) {
|
||||
// Case 2: handle "functionality-only" keys
|
||||
// These keys have a functionality bit set, but no backend bits
|
||||
// These can technically be either:
|
||||
// - valid runtime keys (e.g. DispatchKey::AutogradOther,
|
||||
// DispatchKey::FuncTorchBatched, etc)
|
||||
// - "building block" keys that aren't actual runtime keys (e.g.
|
||||
// DispatchKey::Dense or Sparse)
|
||||
uint64_t functionality_val = 1ULL
|
||||
<< (num_backends + static_cast<uint8_t>(k) - 1);
|
||||
repr_ = functionality_val;
|
||||
} else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
|
||||
// Case 3: "runtime" keys that have a functionality bit AND a backend bit.
|
||||
// First compute which bit to flip for the functionality.
|
||||
auto functionality_k = toFunctionalityKey(k);
|
||||
// The - 1 is because Undefined is technically a "functionality" that
|
||||
// doesn't show up in the bitset. So e.g. Dense is technically the second
|
||||
// functionality, but the lowest functionality bit.
|
||||
uint64_t functionality_val = 1ULL
|
||||
<< (num_backends + static_cast<uint8_t>(functionality_k) - 1);
|
||||
|
||||
// then compute which bit to flip for the backend
|
||||
// Case 4a: handle the runtime instances of "per-backend functionality"
|
||||
// keys For example, given DispatchKey::CPU, we should set:
|
||||
// - the Dense functionality bit
|
||||
// - the CPUBit backend bit
|
||||
// first compute which bit to flip for the backend
|
||||
auto backend_k = toBackendComponent(k);
|
||||
uint64_t backend_val = backend_k == BackendComponent::InvalidBit
|
||||
? 0
|
||||
: 1ULL << (static_cast<uint8_t>(backend_k) - 1);
|
||||
repr_ = functionality_val + backend_val;
|
||||
} else {
|
||||
// At this point, we should have covered every case except for alias keys.
|
||||
// Technically it would be possible to add alias dispatch keys to a
|
||||
// DispatchKeySet, but the semantics are a little confusing and this
|
||||
// currently isn't needed anywhere.
|
||||
repr_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
|
||||
uint64_t repr = 0;
|
||||
for (auto k : ks) {
|
||||
repr |= DispatchKeySet(k).repr_;
|
||||
}
|
||||
return repr;
|
||||
}
|
||||
|
||||
constexpr uint64_t backend_bits_to_repr(
|
||||
std::initializer_list<BackendComponent> ks) {
|
||||
uint64_t repr = 0;
|
||||
for (auto k : ks) {
|
||||
repr |= DispatchKeySet(k).repr_;
|
||||
}
|
||||
return repr;
|
||||
}
|
||||
|
||||
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
|
||||
: repr_(keys_to_repr(ks)) {}
|
||||
|
||||
explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
|
||||
// Note: for some reason, putting this logic directly in the constructor
|
||||
// appears to fail to compile on CUDA 10.1.
|
||||
// See an example internal failure at
|
||||
// https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
|
||||
: repr_(backend_bits_to_repr(ks)) {}
|
||||
|
||||
// Test if a DispatchKey is in the set
|
||||
bool inline has(DispatchKey t) const {
|
||||
inline bool has(DispatchKey t) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
|
||||
return static_cast<bool>(repr_ & DispatchKeySet(t).repr_);
|
||||
return has_all(DispatchKeySet(t));
|
||||
}
|
||||
constexpr bool has_backend(BackendComponent t) const {
|
||||
return has_all(DispatchKeySet(t));
|
||||
}
|
||||
|
||||
// Test if a DispatchKey is in the set
|
||||
// Given a DispatchKeySet of functionality keys and (potentially) backend
|
||||
// keys, tests if all of them are in the current set.
|
||||
constexpr bool has_all(DispatchKeySet ks) const {
|
||||
return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
|
||||
}
|
||||
|
||||
// Given a DispatchKeySet of functionality keys and (potentially) backend
|
||||
// keys, tests if any of them are in the current set. This could technically
|
||||
// be pretty easily implemented using has(). It is strictly a perf
|
||||
// optimization though. There are many places in the code base where we want
|
||||
// to test for multiple functionality keys together. HOWEVER, runtime
|
||||
// per-backend functionality keys aren't allowed to be used with this
|
||||
// function, because you can end up with weird results. e.g.
|
||||
// DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
|
||||
// would return true.
|
||||
inline bool has_any(DispatchKeySet ks) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
// Either there are no backend bits in the input keyset
|
||||
((ks.repr_ & full_backend_mask) == 0) ||
|
||||
// or there are no per-backend-functionality bits
|
||||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||||
((ks &
|
||||
DispatchKeySet({
|
||||
DispatchKey::Dense,
|
||||
DispatchKey::Quantized,
|
||||
DispatchKey::Sparse,
|
||||
DispatchKey::AutogradFunctionality,
|
||||
})
|
||||
.repr_) == 0));
|
||||
return static_cast<bool>((repr_ & ks.repr_) != 0);
|
||||
}
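// To see why the assert above bans runtime per-backend keys: under the
// representation described in Note [DispatchKeySet Internal Representation],
//   DispatchKeySet(DispatchKey::AutogradCPU)  // sets {AutogradFunctionality, CPUBit}
//   DispatchKeySet(DispatchKey::CPU)          // sets {Dense, CPUBit}
// share the CPUBit backend bit, so a naive bitwise-overlap check would report
// "true" even though the two sets have no runtime key in common.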
|
||||
// Test if DispatchKeySet is a superset of ks.
|
||||
bool isSupersetOf(DispatchKeySet ks) const {
|
||||
@ -74,31 +308,64 @@ class DispatchKeySet final {
|
||||
return DispatchKeySet(repr_ | other.repr_);
|
||||
}
|
||||
// Perform set intersection
|
||||
DispatchKeySet operator&(DispatchKeySet other) const {
|
||||
constexpr DispatchKeySet operator&(DispatchKeySet other) const {
|
||||
return DispatchKeySet(repr_ & other.repr_);
|
||||
}
|
||||
// Compute the set difference self - other
|
||||
// Compute the set difference self - other,
|
||||
// but ONLY for the functionality keys.
|
||||
// Any backend bits set on self will remain unchanged.
|
||||
// See Note [Removing keys from DispatchKeySet Only Affects Functionality
|
||||
// Keys]
|
||||
DispatchKeySet operator-(DispatchKeySet other) const {
|
||||
return DispatchKeySet(repr_ & ~other.repr_);
|
||||
return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
|
||||
}
|
||||
|
||||
// Compute self ^ other
|
||||
constexpr DispatchKeySet operator^(DispatchKeySet other) const {
|
||||
return DispatchKeySet(repr_ ^ other.repr_);
|
||||
}
|
||||
// Perform set equality
|
||||
bool operator==(DispatchKeySet other) const {
|
||||
return repr_ == other.repr_;
|
||||
}
|
||||
bool operator!=(DispatchKeySet other) const {
|
||||
return repr_ != other.repr_;
|
||||
}
|
||||
// Add a DispatchKey to the DispatchKey set. Does NOT mutate,
|
||||
// returns the extended DispatchKeySet!
|
||||
C10_NODISCARD DispatchKeySet add(DispatchKey t) const {
|
||||
return *this | DispatchKeySet(t);
|
||||
}
|
||||
// Remove a DispatchKey from the DispatchKey set. This is
|
||||
// generally not an operation you should be doing (it's
|
||||
// used to implement operator<<)
|
||||
C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
|
||||
return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_);
|
||||
C10_NODISCARD DispatchKeySet add(DispatchKeySet ks) const {
|
||||
return *this | ks;
|
||||
}
|
||||
|
||||
// Remove a DispatchKey from the DispatchKey set.
|
||||
// This is generally not an operation you should be doing
|
||||
// (it's used to implement the printing overload, operator<<)
|
||||
//
|
||||
// Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
|
||||
// Only functionality bits are allowed to be removed from a keyset.
|
||||
// For now, we're only allowing removal of "functionality bits" from the
|
||||
// keyset, which is specifically needed by the fallthrough key calculation
|
||||
// logic. Why is removing backend bits problematic? Consider this example:
|
||||
//
|
||||
// DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
|
||||
// DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
|
||||
// DispatchKeySet([DispatchKey.CPU,
|
||||
// DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
|
||||
//
|
||||
// What do we want to happen?
|
||||
// Technically, we'd like it to be true that after removal,
|
||||
// the first keyset still has the CUDA dispatch key while the second doesn't.
|
||||
// Unfortunately there's no way to represent that, because the two keysets are
|
||||
// represented the same way internally: functionality bits: Autograd, Dense
|
||||
// backend bits: CPU, CUDA
|
||||
//
|
||||
// Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
|
||||
// bit from the bitset.
|
||||
constexpr DispatchKeySet remove(DispatchKey t) const {
|
||||
return DispatchKeySet(
|
||||
repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
|
||||
}
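// A short sketch of the semantics described in the note above:
//   auto ks = DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCUDA});
//   auto pruned = ks.remove(DispatchKey::AutogradCUDA);
//   // pruned still contains the CPU and CUDA backend bits plus the Dense bit,
//   // so pruned.has(DispatchKey::CPU) and pruned.has(DispatchKey::CUDA) are
//   // both true; only the Autograd functionality bit was cleared.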
|
||||
// Is the set empty? (AKA undefined tensor)
|
||||
bool empty() const {
|
||||
@ -107,22 +374,78 @@ class DispatchKeySet final {
|
||||
uint64_t raw_repr() {
|
||||
return repr_;
|
||||
}
|
||||
// Return the type id in this set with the highest priority (i.e.,
|
||||
// is the largest in the DispatchKey enum). Intuitively, this
|
||||
// type id is the one that should handle dispatch (assuming there
|
||||
// aren't any further exclusions or inclusions).
|
||||
DispatchKey highestPriorityTypeId() const {
|
||||
// TODO: If I put Undefined as entry 64 and then adjust the
|
||||
// singleton constructor to shift from the right, we can get rid of the
|
||||
// subtraction here. It's modestly more complicated to get right so I
|
||||
// didn't do it for now.
|
||||
return static_cast<DispatchKey>(64 - llvm::countLeadingZeros(repr_));
|
||||
|
||||
DispatchKey highestFunctionalityKey() const {
|
||||
auto functionality_idx = indexOfHighestBit();
|
||||
// This means that none of the functionality bits were set.
|
||||
if (functionality_idx < num_backends)
|
||||
return DispatchKey::Undefined;
|
||||
// The first num_backend bits in the keyset don't correspond to real
|
||||
// dispatch keys.
|
||||
return static_cast<DispatchKey>(functionality_idx - num_backends);
|
||||
}
|
||||
|
||||
DispatchKey highestPriorityBackendTypeId() const {
|
||||
return (*this &
|
||||
((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
|
||||
.highestPriorityTypeId();
|
||||
// This is similar to toBackendComponent(DispatchKey), but less restrictive.
|
||||
// toBackendComponent() errors out if the key that it was passed has no
|
||||
// backend bits, which is useful for error checking. We need a version of that
|
||||
// here that can also handle "fake" backends like FPGA, because they need to
|
||||
// map to the AutogradOther key. For those backends, we return
|
||||
// BackendComponent::InvalidBit.
|
||||
BackendComponent highestBackendKey() const {
|
||||
// mask to mask out functionality bits
|
||||
auto backend_idx =
|
||||
DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
|
||||
// all zeros across the backend bits means that no backend bits are set.
|
||||
if (backend_idx == 0)
|
||||
return BackendComponent::InvalidBit;
|
||||
return static_cast<BackendComponent>(backend_idx);
|
||||
}
|
||||
|
||||
// returns the DispatchKey of highest priority in the set.
|
||||
DispatchKey highestPriorityTypeId() const {
|
||||
auto functionality_k = highestFunctionalityKey();
|
||||
if (isPerBackendFunctionalityKey(functionality_k)) {
|
||||
return toRuntimePerBackendFunctionalityKey(
|
||||
functionality_k, highestBackendKey());
|
||||
}
|
||||
return functionality_k;
|
||||
}
|
||||
|
||||
// Returns the index of the most-significant bit in the keyset.
|
||||
// This is used to as part of the calculation into the operator table to get:
|
||||
// - the highest "functionality" bit in the keyset.
|
||||
// - the highest "backend" bit in the keyset.
|
||||
uint8_t indexOfHighestBit() const {
|
||||
return 64 - llvm::countLeadingZeros(repr_);
|
||||
}
|
||||
|
||||
// returns the index in the operator table of the highest priority key in the
// keyset. Note that we could in theory implement this using
// highestPriorityTypeId(), but this code is on a very hot path and we can do it
|
||||
// faster without it.
|
||||
uint64_t getDispatchTableIndexForDispatchKeySet() const {
|
||||
auto functionality_idx =
|
||||
DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
|
||||
auto offset_and_mask = offsetsAndMasks()[functionality_idx];
|
||||
// Mask the functionality bits out first, then right-shift by 1.
|
||||
// right-shifting by 1 because everything is zero-indexed.
|
||||
// E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
|
||||
// give us an offset of 1, etc.
|
||||
auto backend_idx =
|
||||
DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
|
||||
return offset_and_mask.offset + backend_idx;
|
||||
}
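// Worked example (using the Undefined = 0 / Dense = 1 offsets sketched for
// initializeFunctionalityOffsetsAndMasks, and CPUBit = 1, CUDABit = 2):
//   DispatchKeySet(DispatchKey::CPU):  functionality_idx = 1 (Dense), offset 1,
//                                      backend_idx = 0  -> table slot 1
//   DispatchKeySet(DispatchKey::CUDA): functionality_idx = 1 (Dense), offset 1,
//                                      backend_idx = 1  -> table slot 2
//   The empty keyset maps to slot 0 (Undefined).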
|
||||
|
||||
// returns the "index" of the highest priority backend in the keyset.
|
||||
// This is pretty similar to getBackendKey(), but:
|
||||
// - It's hotpath code (part of the runtime bitset calculation)
|
||||
// - It returns an integer index, not an enum value
|
||||
// - Everything is shifted to the right by 1.
|
||||
// BackendComponent::InvalidBit is technically the lowest enum value,
|
||||
// but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
|
||||
// etc.
|
||||
uint64_t getBackendIndex() const {
|
||||
return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
|
||||
}
|
||||
|
||||
private:
|
||||
@ -130,42 +453,47 @@ class DispatchKeySet final {
|
||||
uint64_t repr_ = 0;
|
||||
|
||||
public:
|
||||
// STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
|
||||
// set. The iterator is only invalidated by the destruction of the underlying
|
||||
// DispatchKeySet as the iterator stores a pointer to the raw representation
|
||||
// of the DispatchKeySet.
|
||||
// STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
|
||||
// in the set. The iterator is only invalidated by the destruction of the
|
||||
// underlying DispatchKeySet as the iterator stores a pointer to the raw
|
||||
// representation of the DispatchKeySet. Note: When we encounter a per-backend
|
||||
// functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
|
||||
// in the keyset, for that functionality. For example, if the next
|
||||
// functionality key to iterate over is Autograd, and the backend bits in the
|
||||
// keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
|
||||
// then the next two keys we return will be DispatchKey::AutogradCPU,
|
||||
// DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
|
||||
// CUDA in DispatchKey.h).
|
||||
class iterator {
|
||||
public:
|
||||
using self_type = iterator;
|
||||
using iterator_category = std::input_iterator_tag;
|
||||
using value_type = DispatchKey;
|
||||
using difference_type = ptrdiff_t;
|
||||
// final mask value should mask out the entire keyset
|
||||
static const uint8_t end_iter_mask_val =
|
||||
num_backends + num_functionality_keys;
|
||||
// final key value should be the last DispatchKey
|
||||
static const uint8_t end_iter_key_val = num_functionality_keys;
|
||||
|
||||
explicit iterator(const uint64_t* data_ptr, uint8_t i = 0)
|
||||
: data_ptr_(data_ptr), i_(i) {
|
||||
// current_dispatchkey_idx_ will iterate through all functionality bits.
|
||||
// current_backendcomponent_idx_ will iterate through all backend bits.
|
||||
explicit iterator(
|
||||
const uint64_t* data_ptr,
|
||||
uint8_t next_functionality = num_backends,
|
||||
uint8_t next_backend = 0)
|
||||
: data_ptr_(data_ptr),
|
||||
next_functionality_(next_functionality),
|
||||
next_backend_(next_backend),
|
||||
// These are in an invalid state at construction time, and set by the
|
||||
// first increment call
|
||||
current_dispatchkey_idx_(end_iter_key_val),
|
||||
current_backendcomponent_idx_(end_iter_key_val) {
|
||||
// Go to the first key in the set
|
||||
++(*this);
|
||||
}
|
||||
|
||||
self_type& operator++() {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
|
||||
|
||||
// Create a masked version of the set representation to ignore previous
|
||||
// keys that we've iterated through.
|
||||
uint64_t masked_data = llvm::maskTrailingZeros<uint64_t>(i_) & *data_ptr_;
|
||||
uint64_t firstKeyIndex = llvm::findFirstSet(masked_data);
|
||||
|
||||
// If there are no keys, set to end iterator value
|
||||
if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
|
||||
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
|
||||
i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
|
||||
return *this;
|
||||
}
|
||||
|
||||
i_ = static_cast<uint8_t>(firstKeyIndex) + 1;
|
||||
return *this;
|
||||
}
|
||||
C10_API self_type& operator++();
|
||||
|
||||
self_type operator++(int) {
|
||||
self_type previous_iterator = *this;
|
||||
@ -174,18 +502,50 @@ class DispatchKeySet final {
|
||||
}
|
||||
|
||||
bool operator==(const self_type& rhs) const {
|
||||
return i_ == rhs.i_;
|
||||
return next_functionality_ == rhs.next_functionality_ &&
|
||||
current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
|
||||
next_backend_ == rhs.next_backend_ &&
|
||||
current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
|
||||
}
|
||||
bool operator!=(const self_type& rhs) const {
|
||||
return i_ != rhs.i_;
|
||||
return next_functionality_ != rhs.next_functionality_ ||
|
||||
current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
|
||||
next_backend_ != rhs.next_backend_ ||
|
||||
current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
|
||||
}
|
||||
DispatchKey operator*() const {
|
||||
return static_cast<DispatchKey>(i_);
|
||||
auto functionality_key =
|
||||
static_cast<DispatchKey>(current_dispatchkey_idx_);
|
||||
if (isPerBackendFunctionalityKey(functionality_key)) {
|
||||
auto next_key = toRuntimePerBackendFunctionalityKey(
|
||||
functionality_key,
|
||||
static_cast<BackendComponent>(current_backendcomponent_idx_));
|
||||
// We expect all of the Dense, Sparse, Quantized, and Autograd keys to
|
||||
// be ordered the same way with respect to their backends
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
toBackendComponent(next_key) ==
|
||||
static_cast<BackendComponent>(current_backendcomponent_idx_),
|
||||
"Tried to map functionality key ",
|
||||
toString(functionality_key),
|
||||
" and backend bit ",
|
||||
toString(
|
||||
static_cast<BackendComponent>(current_backendcomponent_idx_)),
|
||||
" to a runtime key, but ended up with ",
|
||||
toString(next_key),
|
||||
". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
|
||||
" Please double check that enum for inconsistencies.");
|
||||
return next_key;
|
||||
} else {
|
||||
return functionality_key;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const uint64_t* data_ptr_;
|
||||
uint8_t i_;
|
||||
uint8_t next_functionality_;
|
||||
uint8_t next_backend_;
|
||||
uint8_t current_dispatchkey_idx_;
|
||||
uint8_t current_backendcomponent_idx_;
|
||||
};
|
||||
|
||||
public:
|
||||
@ -195,31 +555,35 @@ class DispatchKeySet final {
|
||||
return iterator(&repr_);
|
||||
}
|
||||
|
||||
// We do not need to iterate beyond NumDispatchKeys so we will treat this as
|
||||
// the end iterator. NumDispatchKeys will always be strictly less than 64.
|
||||
// We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
|
||||
// this as the end iterator.
|
||||
iterator end() const {
|
||||
return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
|
||||
return iterator(&repr_, iterator::end_iter_mask_val);
|
||||
}
|
||||
};
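// Example of iterating a keyset (see the iterator comment above): assuming the
// usual CPU-before-CUDA ordering, the loop below visits the runtime keys
// CPU, CUDA, AutogradCPU, AutogradCUDA in that order.
//
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::AutogradCUDA});
//   for (DispatchKey k : ks) {
//     // Per-backend functionalities (Dense, Autograd) are expanded once per
//     // backend bit in the set (CPUBit and CUDABit here).
//   }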
|
||||
|
||||
C10_API std::string toString(DispatchKeySet);
|
||||
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
|
||||
|
||||
// autograd_dispatch_keyset should include all runtime autograd keys.
|
||||
// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset.
|
||||
C10_API inline uint64_t getDispatchTableIndexForDispatchKey(DispatchKey k) {
|
||||
return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
|
||||
}
|
||||
|
||||
// Alias key DispatchKey::Autograd maps to
|
||||
// (autograd_dispatch_keyset x full_backend_mask)
|
||||
// NB: keys in this set also get associated with CompositeImplicitAutograd
|
||||
//
|
||||
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
|
||||
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
|
||||
// directly in autograd_dispatch_keyset.
|
||||
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
|
||||
// autograd keys from a DispatchKeySet throughout the code base. However, you
|
||||
// are only allowed to remove functionality bits from a keyset, not backend
|
||||
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
|
||||
// Keys] for details. To be consistent and avoid confusion, we're explicitly
|
||||
// setting up autograd_dispatch_keyset to not have any backend bits.
|
||||
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
|
||||
DispatchKey::AutogradCPU,
|
||||
DispatchKey::AutogradCUDA,
|
||||
DispatchKey::AutogradXLA,
|
||||
DispatchKey::AutogradLazy,
|
||||
DispatchKey::AutogradNestedTensor,
|
||||
DispatchKey::AutogradMLC,
|
||||
DispatchKey::AutogradHPU,
|
||||
DispatchKey::AutogradXPU,
|
||||
DispatchKey::AutogradPrivateUse1,
|
||||
DispatchKey::AutogradPrivateUse2,
|
||||
DispatchKey::AutogradPrivateUse3,
|
||||
DispatchKey::AutogradFunctionality,
|
||||
DispatchKey::AutogradOther,
|
||||
});
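// Since operator- only clears functionality bits (see Note [Removing keys from
// DispatchKeySet Only Affects Functionality Keys]), subtracting this keyset
// from a tensor's keyset strips autograd handling while keeping the backend
// reachable. A small sketch:
//   auto ks = DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CUDA});
//   auto no_autograd = ks - autograd_dispatch_keyset;
//   // no_autograd.has(DispatchKey::CUDA) is still true; only the Autograd
//   // functionality bits were removed.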
|
||||
|
||||
@ -244,25 +608,28 @@ constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
|
||||
|
||||
// backend dispatch keys that map to DispatchKey::AutogradOther
|
||||
// NB: keys in this set also get associated with CompositeImplicitAutograd
|
||||
constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
|
||||
{DispatchKey::HIP,
|
||||
DispatchKey::VE,
|
||||
DispatchKey::FPGA,
|
||||
DispatchKey::ORT,
|
||||
DispatchKey::Vulkan,
|
||||
DispatchKey::Metal,
|
||||
DispatchKey::QuantizedCPU,
|
||||
DispatchKey::QuantizedCUDA,
|
||||
DispatchKey::CustomRNGKeyId,
|
||||
DispatchKey::MkldnnCPU,
|
||||
DispatchKey::SparseCPU,
|
||||
DispatchKey::SparseCUDA,
|
||||
DispatchKey::SparseHIP,
|
||||
DispatchKey::SparseVE,
|
||||
DispatchKey::SparseXPU,
|
||||
DispatchKey::SparseCsrCPU,
|
||||
DispatchKey::SparseCsrCUDA,
|
||||
DispatchKey::Meta});
|
||||
constexpr DispatchKeySet autogradother_backends =
|
||||
DispatchKeySet(
|
||||
// HIP and VE aren't in this list: they now have their own backend bits
|
||||
// which means that they can now have their own Autograd keys.
|
||||
// Technically, HIP will now redispatch to its own custom AutogradHIP
|
||||
// slot in the runtime table.
|
||||
{DispatchKey::FPGA,
|
||||
DispatchKey::ORT,
|
||||
DispatchKey::Vulkan,
|
||||
DispatchKey::Metal,
|
||||
DispatchKey::SparseCsrCPU,
|
||||
DispatchKey::SparseCsrCUDA,
|
||||
DispatchKey::CustomRNGKeyId,
|
||||
DispatchKey::MkldnnCPU,
|
||||
DispatchKey::Meta,
|
||||
// Sparse and Quantized backends also live here.
|
||||
DispatchKey::Sparse,
|
||||
DispatchKey::Quantized})
|
||||
// Including the backend bits because this keyset is used during op
|
||||
// registration, which requires looping over all runtime autogradother
|
||||
// backend keys.
|
||||
| DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
|
||||
|
||||
// The set of dispatch keys that come after autograd
|
||||
// n.b. this relies on the fact that AutogradOther is currently the lowest
|
||||
@ -292,6 +659,36 @@ constexpr DispatchKeySet after_func_keyset =
|
||||
// away with it by explicitly removing the key here.
|
||||
c10::DispatchKey::ADInplaceOrView);
|
||||
|
||||
constexpr DispatchKeySet backend_bitset_mask =
|
||||
DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
|
||||
|
||||
constexpr auto inplace_or_view_ks =
|
||||
DispatchKeySet(DispatchKey::ADInplaceOrView);
|
||||
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
|
||||
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
|
||||
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
|
||||
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
|
||||
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
|
||||
constexpr auto autograd_mlc_ks = DispatchKeySet(DispatchKey::AutogradMLC);
|
||||
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
|
||||
constexpr auto autograd_privateuse1_ks =
|
||||
DispatchKeySet(DispatchKey::AutogradPrivateUse1);
|
||||
constexpr auto autograd_privateuse2_ks =
|
||||
DispatchKeySet(DispatchKey::AutogradPrivateUse2);
|
||||
constexpr auto autograd_privateuse3_ks =
|
||||
DispatchKeySet(DispatchKey::AutogradPrivateUse3);
|
||||
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
|
||||
|
||||
struct OpTableOffsetAndMask {
|
||||
uint16_t offset;
|
||||
uint16_t backend_mask;
|
||||
};
|
||||
|
||||
static_assert(
|
||||
num_backends <= 16,
|
||||
"Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
|
||||
" that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
|
||||
|
||||
// true if t is a backend dispatch key
|
||||
C10_API bool isBackendDispatchKey(DispatchKey t);
|
||||
|
||||
@ -307,10 +704,53 @@ C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
|
||||
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
|
||||
|
||||
// Returns a DispatchKeySet of autograd related keys mapped to backend.
|
||||
C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t);
|
||||
// for a given backend key, use the associated autograd key.
|
||||
// for non-backend keys, use AutogradOther as a default.
|
||||
// Note: it's convenient and fast to return a default here rather than (say)
|
||||
// returning an optional<DispatchKey>, or throwing. But it makes callers
|
||||
// responsible for either a) enforcing the invariant that only backend keys
|
||||
// be passed as arguments, or b) interpreting our return value carefully.
|
||||
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
|
||||
switch (t) {
|
||||
case BackendComponent::CPUBit:
|
||||
return inplace_or_view_ks | autograd_cpu_ks;
|
||||
case BackendComponent::XPUBit:
|
||||
return inplace_or_view_ks | autograd_xpu_ks;
|
||||
case BackendComponent::CUDABit:
|
||||
return inplace_or_view_ks | autograd_cuda_ks;
|
||||
case BackendComponent::XLABit:
|
||||
return inplace_or_view_ks | autograd_xla_ks;
|
||||
case BackendComponent::LazyBit:
|
||||
return inplace_or_view_ks | autograd_lazy_ks;
|
||||
case BackendComponent::MLCBit:
|
||||
return inplace_or_view_ks | autograd_mlc_ks;
|
||||
case BackendComponent::HPUBit:
|
||||
return inplace_or_view_ks | autograd_hpu_ks;
|
||||
case BackendComponent::PrivateUse1Bit:
|
||||
return inplace_or_view_ks | autograd_privateuse1_ks;
|
||||
case BackendComponent::PrivateUse2Bit:
|
||||
return inplace_or_view_ks | autograd_privateuse2_ks;
|
||||
case BackendComponent::PrivateUse3Bit:
|
||||
return inplace_or_view_ks | autograd_privateuse3_ks;
|
||||
default:
|
||||
return inplace_or_view_ks | autograd_other_ks;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a DispatchKeySet of autocast related keys mapped to backend.
|
||||
C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t);
|
||||
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
|
||||
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
|
||||
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
|
||||
switch (t) {
|
||||
case BackendComponent::CPUBit:
|
||||
return autocast_cpu_ks;
|
||||
case BackendComponent::CUDABit:
|
||||
case BackendComponent::XLABit:
|
||||
return autocast_cuda_ks;
|
||||
default:
|
||||
return DispatchKeySet();
|
||||
}
|
||||
}
|
||||
|
||||
// This API exists because we have a use case for checking
|
||||
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
|
||||
|
@ -190,7 +190,7 @@ TensorImpl::TensorImpl(

// TODO: be more explicit about the full key set at call sites so we
// don't have to keep recomputing it here
DispatchKey k = key_set.highestPriorityBackendTypeId();
auto k = key_set.highestBackendKey();

key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

@ -838,10 +838,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_sparse() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::SparseCPU) ||
key_set_.has(DispatchKey::SparseCUDA) ||
key_set_.has(DispatchKey::SparseHIP) ||
key_set_.has(DispatchKey::SparseXPU);
return key_set_.has(DispatchKey::Sparse);
}

// Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
@ -854,9 +851,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_quantized() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::QuantizedCPU) ||
key_set_.has(DispatchKey::QuantizedCUDA) ||
key_set_.has(DispatchKey::QuantizedXPU);
return key_set_.has(DispatchKey::Quantized);
}

bool is_meta() const {
@ -868,53 +863,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_cpu() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::CPU) ||
key_set_.has(DispatchKey::SparseCPU) ||
return key_set_.has_backend(BackendComponent::CPUBit) ||
key_set_.has(DispatchKey::SparseCsrCPU) ||
key_set_.has(DispatchKey::QuantizedCPU) ||
key_set_.has(DispatchKey::MkldnnCPU);
}

bool is_cuda() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::CUDA) ||
key_set_.has(DispatchKey::SparseCUDA) ||
key_set_.has(DispatchKey::SparseCsrCUDA) ||
key_set_.has(DispatchKey::QuantizedCUDA);
return key_set_.has_backend(BackendComponent::CUDABit) ||
key_set_.has(DispatchKey::SparseCsrCUDA);
}

bool is_xpu() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::XPU) ||
key_set_.has(DispatchKey::SparseXPU) ||
key_set_.has(DispatchKey::QuantizedXPU);
return key_set_.has_backend(BackendComponent::XPUBit);
}

bool is_xla() const {
return key_set_.has(DispatchKey::XLA);
return key_set_.has_backend(BackendComponent::XLABit);
}

bool is_hpu() const {
return key_set_.has(DispatchKey::HPU);
return key_set_.has_backend(BackendComponent::HPUBit);
}

bool is_lazy() const {
return key_set_.has(DispatchKey::Lazy);
return key_set_.has_backend(BackendComponent::LazyBit);
}

bool is_hip() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::HIP) ||
key_set_.has(DispatchKey::SparseHIP);
return key_set_.has_backend(BackendComponent::HIPBit);
}

bool is_ve() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
return key_set_.has(DispatchKey::VE) || key_set_.has(DispatchKey::SparseVE);
return key_set_.has_backend(BackendComponent::VEBit);
}

bool is_mkldnn() const {
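
The accessors above now lean on the fact that every per-backend functionality key (dense, sparse, quantized, autograd) sets the same backend bit, while keys such as SparseCsrCPU and MkldnnCPU are still checked explicitly because they are not per-backend in this scheme. A minimal sketch of that property (illustrative only, mirroring the semantics the DispatchKeySet tests exercise further down):

// Any CPU flavor of a per-backend functionality key sets BackendComponent::CPUBit,
// which is what is_cpu() now keys off of.
void cpu_bit_examples() {
auto dense_cpu = DispatchKeySet(DispatchKey::CPU);
auto sparse_cpu = DispatchKeySet(DispatchKey::SparseCPU);
auto quantized_cpu = DispatchKeySet(DispatchKey::QuantizedCPU);
bool all_cpu = dense_cpu.has_backend(BackendComponent::CPUBit) &&
sparse_cpu.has_backend(BackendComponent::CPUBit) &&
quantized_cpu.has_backend(BackendComponent::CPUBit);
(void)all_cpu; // true in all three cases
}
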
@ -1548,13 +1536,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
inline bool has_compatible_shallow_copy_type(DispatchKeySet from) {
auto is_dense = [](DispatchKeySet ts) {
return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) ||
ts.has(DispatchKey::HIP) || ts.has(DispatchKey::XPU);
constexpr auto dense_backends = DispatchKeySet(
{BackendComponent::CPUBit,
BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense);
return ts.has_any(dense_k) && ts.has_any(dense_backends);
};
auto is_sparse = [](DispatchKeySet ts) {
return ts.has(DispatchKey::SparseCPU) ||
ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) ||
ts.has(DispatchKey::SparseXPU);
constexpr auto sparse_backends = DispatchKeySet(
{BackendComponent::CPUBit,
BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
};
return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
(is_sparse(key_set_) && is_sparse(from));
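
Concretely, the new lambdas treat any two dense tensors on the listed backends as shallow-copy compatible, just as the old per-key checks did. A minimal illustration (sketch, not code from this diff):

// Both sets carry the Dense functionality bit plus one of the listed backend
// bits, so is_dense() accepts them and the shallow-copy types are compatible.
bool dense_cpu_vs_cuda_compatible() {
auto cpu_ks = DispatchKeySet(DispatchKey::CPU);    // Dense + CPUBit
auto cuda_ks = DispatchKeySet(DispatchKey::CUDA);  // Dense + CUDABit
constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense);
constexpr auto dense_backends = DispatchKeySet(
{BackendComponent::CPUBit, BackendComponent::CUDABit});
return cpu_ks.has_any(dense_k) && cpu_ks.has_any(dense_backends) &&
cuda_ks.has_any(dense_k) && cuda_ks.has_any(dense_backends);
}
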
@ -3,25 +3,163 @@
#include <unordered_set>

#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>

using namespace c10;

// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
// the "CPU" dispatch key is an instance of a per-backend-functionality key.
// It corresponds to "dense" functionality, "CPU" backend.
// This means that it gets a dense functionality bit, and a cpu backend bit
// set.
auto undefined_set = DispatchKeySet();
auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));

auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));

// You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
// dispatch keys. You are allowed to directly create keysets out of them!
auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
DispatchKeySet(BackendComponent::CPUBit);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);

// Similarly, the AutogradCUDA key gets 2 bits in the keyset:
// The "Autograd" functionality bit, and the "CUDA" backend bit
auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));

// Because DispatchKeySet uses a condensed internal representation, you cannot
// use it to represent the FULL cross product of backends and functionalities
// for example:
auto autograd_dense_cpu_cuda = DispatchKeySet(
{DispatchKey::AutogradFunctionality,
DispatchKey::Dense,
DispatchKey::CUDA,
DispatchKey::CPU});
auto fpga = DispatchKeySet(DispatchKey::FPGA);
auto fpga_and_cpu = DispatchKeySet({DispatchKey::FPGA, DispatchKey::CPU});
// this keyset has all of the building block keys:
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));

// and it also has the "runtime" keys that correspond to the full
// cross-product of functionality
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCUDA));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));

// This means that there's no way to represent a keyset with, say, only
// Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
// inheriting the full set of functionalities + backends of its keys. This
// means that the below keysets are all indistinguishable from each other.
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet(
{DispatchKey::AutogradCUDA,
DispatchKey::AutogradCPU,
DispatchKey::CUDA,
DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));

// ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~

// Iterators allow you to iterate individually through the DispatchKey's in a
// DispatchKeySet
auto empty_set = DispatchKeySet();
auto t1 = empty_set.begin();
auto t2 = empty_set.end();
ASSERT_EQ(*empty_set.begin(), *empty_set.end());

// However, only keys that correspond to actual runtime indices of kernels in
// the operator table show up when you iterate through a keyset. i.e.
// DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
// iterator.
auto dense_cpu_iter = dense_cpu_set.begin();
ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());

auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());

// But other "functionality bits" that are not defined per-backend DO get
// their own slots in the operator table.
auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
DispatchKeySet(
{DispatchKey::FPGA, // runtime key
DispatchKey::Functionalize, // runtime key
DispatchKey::Dense}); // NOT a runtime key
auto mixed_iter = mixed_keyset.begin();
ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
}

TEST(DispatchKeySet, Empty) {
DispatchKeySet empty_set;
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
for (uint8_t i = 0;
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
if (tid == DispatchKey::Undefined)
continue;
ASSERT_FALSE(empty_set.has(tid));
}
ASSERT_TRUE(empty_set.empty());
DispatchKeySet empty_set2;
ASSERT_TRUE(empty_set == empty_set2);
ASSERT_EQ(empty_set.highestPriorityTypeId(), DispatchKey::Undefined);
}

TEST(DispatchKeySet, Singleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
// This covers all keys that correspond to a single backend bit, e.g.
// BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
// allow adding them directly to a keyset
TEST(DispatchKeySet, SingletonBackendComponent) {
for (const auto i : c10::irange(1, num_backends)) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
}
}

// This covers all keys that correspond to a single functionality bit:
// - runtime, not-per-backend functionality keys, e.g.
// DispatchKey::FuncTorchBatched
// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
// - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense
// Even though it's not a runtime key, we still allow adding it directly to a
// keyset.
// DispatchKey::
TEST(DispatchKeySet, SingletonFunctionalityKeys) {
for (const auto i : c10::irange(1, num_functionality_keys)) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
@ -30,47 +168,145 @@ TEST(DispatchKeySet, Singleton) {
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
ASSERT_EQ(sing.highestPriorityTypeId(), tid);
ASSERT_EQ(sing.remove(tid), DispatchKeySet());
}
}

TEST(DispatchKeySet, Doubleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
// This covers runtime keys that are per-backend,
// and take up more than one bit in a DispatchKeySet. They take up one
// functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA
TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradBackends) {
continue;
}
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));

auto functionality_key = toFunctionalityKey(tid);
auto backend_key = toBackendComponent(tid);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU)
// DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
auto expected_ks =
DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
ASSERT_EQ(sing, expected_ks);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense)
// DispatchKeySet(BackendComponent::CPUBit)
expected_ks = DispatchKeySet(toBackendComponent(tid));
ASSERT_EQ(sing.remove(tid), expected_ks);
}
}

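The test above verifies this decomposition generically for every per-backend runtime key; one concrete instance of the property it relies on (illustrative sketch, not part of the diff):

// A per-backend runtime key decomposes into one functionality bit and one
// backend bit, and recombining the pieces reproduces the original keyset.
void sparse_cuda_decomposition() {
auto sparse_cuda = DispatchKeySet(DispatchKey::SparseCUDA);
auto functionality = toFunctionalityKey(DispatchKey::SparseCUDA); // DispatchKey::Sparse
auto backend = toBackendComponent(DispatchKey::SparseCUDA);       // BackendComponent::CUDABit
bool same =
sparse_cuda == (DispatchKeySet(functionality) | DispatchKeySet(backend));
(void)same; // true
}
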
TEST(DispatchKeySet, DoubletonPerBackend) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
for (uint8_t j = i + 1;
j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
j++) {
ASSERT_LT(i, j);
auto tid1 = static_cast<DispatchKey>(i);
auto tid2 = static_cast<DispatchKey>(j);
auto doub = DispatchKeySet(tid1).add(tid2);
ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2));
ASSERT_TRUE(doub.has(tid1));
ASSERT_TRUE(doub.has(tid2));
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j

// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfAutogradBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfAutogradBackends)
continue;

auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
auto functionality1 = toFunctionalityKey(tid1);
auto functionality2 = toFunctionalityKey(tid2);

auto combined = DispatchKeySet({tid1, tid2});
// The combined set has the backend bits
ASSERT_TRUE(combined.has_backend(backend1));
ASSERT_TRUE(combined.has_backend(backend2));
// and it has the functionality bits
ASSERT_TRUE(combined.has(functionality1));
ASSERT_TRUE(combined.has(functionality2));
// and it has the original two runtime keys
ASSERT_TRUE(combined.has(tid1));
ASSERT_TRUE(combined.has(tid2));

// Add all of the keys in the keyset to a real set
std::unordered_set<DispatchKey> visited_keys;
auto iter = combined.begin();
while (*iter != *combined.end()) {
visited_keys.insert(*iter);
++iter;
}
std::unordered_set<DispatchKey> expected_keys;
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend2));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend2));
ASSERT_EQ(expected_keys, visited_keys);

if (backend1 == backend2 || functionality1 == functionality2) {
// We have two runtime keys, with either the same backend or the same
// per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
// {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in
// this set.
ASSERT_EQ(2, visited_keys.size());
} else {
// since i and j are different keys, they should not have the same
// functionality and backend
ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
// We have two runtime keys, that have different backends + per-backend
// functionalities. So we should expect the full cross product of
// runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
// then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
ASSERT_EQ(4, visited_keys.size());
}
}
}
}

TEST(DispatchKeySet, Full) {
DispatchKeySet full(DispatchKeySet::FULL);
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
for (const auto i : c10::irange(1, num_functionality_keys)) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_TRUE(full.has(tid));
}
ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys));
}

TEST(DispatchKeySet, IteratorBasicOps) {
DispatchKeySet empty_set;
DispatchKeySet full_set(DispatchKeySet::FULL);
DispatchKeySet mutated_set = empty_set.add(static_cast<DispatchKey>(1));
DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU);

// Constructor + Comparison
ASSERT_EQ(*empty_set.begin(), DispatchKey::NumDispatchKeys);
ASSERT_EQ(*empty_set.end(), DispatchKey::NumDispatchKeys);
ASSERT_EQ(*mutated_set.begin(), static_cast<DispatchKey>(1));
ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU);

ASSERT_TRUE(empty_set.begin() == empty_set.end());
ASSERT_TRUE(full_set.begin() != full_set.end());
@ -90,16 +326,37 @@ TEST(DispatchKeySet, IteratorEmpty) {
ASSERT_EQ(i, 0);
}

TEST(DispatchKeySet, IteratorCrossProduct) {
// The iterator should return all runtime keys in the set,
// including the cross product of {backends} x {functionalities}
auto ks =
DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
DispatchKeySet(
{DispatchKey::Dense,
DispatchKey::FPGA,
DispatchKey::AutogradFunctionality});

auto iter = ks.begin();
// iterate through dense backends first.
ASSERT_EQ(DispatchKey::CPU, *(iter++));
ASSERT_EQ(DispatchKey::CUDA, *(iter++));
// FPGA doesn't have a backend bit, so it isn't included in the cross product.
ASSERT_EQ(DispatchKey::FPGA, *(iter++));
// iterate through the autograd keys later.
ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
}

TEST(DispatchKeySet, IteratorFull) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint8_t i = 0;

for (const auto& it : full_set) {
i++;
ASSERT_TRUE(it == static_cast<DispatchKey>(i));
ASSERT_TRUE(it != DispatchKey::NumDispatchKeys);
}
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
// Total # of runtime entries includes an entry for DispatchKey::Undefined,
// which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
}

TEST(DispatchKeySet, IteratorRangeFull) {
@ -108,41 +365,61 @@ TEST(DispatchKeySet, IteratorRangeFull) {

for (DispatchKey dispatch_key : full_set) {
i++;
ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
}

ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
}

TEST(DispatchKeySet, SpecificKeys) {
DispatchKeySet keyset({
static_cast<DispatchKey>(0), // Undefined should be ignored
static_cast<DispatchKey>(4),
static_cast<DispatchKey>(10),
static_cast<DispatchKey>(15),
});
std::unordered_set<DispatchKey> visited_keys;

for (DispatchKey key : keyset) {
visited_keys.insert(key);
}

ASSERT_EQ(visited_keys.size(), 3);
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
// Total # of runtime entries includes an entry for DispatchKey::Undefined,
// which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
}

TEST(DispatchKeySet, FailAtEndIterator) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint64_t raw_repr = full_set.raw_repr();

// doesn't throw
DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
DispatchKeySet::iterator(
&raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1),
&raw_repr, num_backends + num_functionality_keys + 1),
c10::Error);
}

TEST(DispatchKeySet, TestKeyOrderingInvariants) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto k = static_cast<DispatchKey>(i);
// Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// The DispatchKey enum includes all of the runtime keys for
// Dense/Sparse/Quantized/Autograd, (e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA, etc). And we expect the ordering of those keys
// to be the same as the ordering of the backends in the `BackendComponent`
// enum. This makes several utilities in `DispatchKey.h` and
// `DispatchKeySet.h` significantly easier to implement. The purpose of the
// test is to assert (through CI) that this invariant is maintained.
//
// The only way that we can really check this invariant is by
// comparing the string names of each enum.
// We only really care about the ordering for "real" keys that are actually
// used, which we expect to be able to print properly. This saves us from
// having to enumerate the full set of possible runtime keys in
// DispatchKey::toString(). It also relies on toString() being implemented
// correctly.
auto functionality_str = std::string(toString(k));
if (functionality_str == "UNKNOWN_TENSOR_TYPE_ID")
continue;

auto computed_backend_k = toBackendComponent(k);
auto computed_backend_str = std::string(toString(computed_backend_k));
// Skip, e.g., the "Bit" from "CPUBit"
computed_backend_str =
computed_backend_str.substr(0, computed_backend_str.size() - 3);

ASSERT_TRUE(
functionality_str.find(computed_backend_str) != std::string::npos)
<< "DispatchKey invariant broken! Found a key that is not ordered correctly"
<< " with its backend bit. key = " << toString(k) << ", " << k
<< ", computed backend = " << toString(computed_backend_k);
}
}

@ -532,8 +532,8 @@ AutogradXLA: fn_math [math kernel]
lambda m: m.def_("foo(Tensor x) -> Tensor"),
# m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"),
# m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"),
# m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"),
])
state, table = result.state, result.table
self.assertExpectedInline(state, '''\
@ -541,12 +541,12 @@ name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
FPGA: fn_fpga :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

# computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))

self.assertExpectedInline(extracted_table, '''\
Undefined: fn_math [math kernel]
@ -557,7 +557,7 @@ AutogradOther: ambiguous_autogradother [ambiguous autogradother]
AutogradCPU: fn_math [math kernel]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
QuantizedCPU: fn_quantizedcpu [kernel]
FPGA: fn_fpga [kernel]
''')

def test_computed_table_with_cpu_defaultbackend(self):
@ -616,7 +616,7 @@ CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0
''')

# computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))

self.assertExpectedInline(extracted_table, '''\
Undefined: fn_defaultbackend [default backend kernel]
@ -627,7 +627,7 @@ AutogradOther: fn_autograd [autograd kernel]
AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_autograd [autograd kernel]
AutogradXLA: fn_autograd [autograd kernel]
QuantizedCPU: fn_defaultbackend [default backend kernel]
FPGA: fn_defaultbackend [default backend kernel]
''')

def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
@ -808,7 +808,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel]
FPGA fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fallthrough [backend fallback]
AutogradXLA fallthrough [backend fallback]
@ -829,7 +829,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel]
FPGA fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback]
@ -864,7 +864,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeExplicitAutograd [default backend kernel]
FPGA fn_CompositeExplicitAutograd [default backend kernel]
AutogradOther fallthrough [backend fallback]
AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback]
@ -889,7 +889,7 @@ CompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd

def test_autogradother(self):
dispatcher = PythonDispatcher()
dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"])
dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"])
self.assertExpectedInline(
dispatcher.dispatchTable(),
'''\
@ -900,7 +900,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_CompositeImplicitAutograd [math kernel]
Lazy fn_CompositeImplicitAutograd [math kernel]
QuantizedCPU fn_QuantizedCPU [kernel]
FPGA fn_FPGA [kernel]
AutogradOther ambiguous_autogradother [ambiguous autogradother]
AutogradCPU fallthrough [backend fallback]
AutogradXLA fn_CompositeImplicitAutograd [math kernel]
@ -915,8 +915,8 @@ AutogradLazy fn_CompositeImplicitAutograd [math kernel]
Registered Kernels
key kernel
---------------------------
FPGA fn_FPGA
CPU fn_CPU
QuantizedCPU fn_QuantizedCPU
CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
'''
)

@ -3410,21 +3410,21 @@ class TestSparseOneOff(TestCase):
def test_cuda_from_cpu(self):
with self.assertRaisesRegex(
RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4),
[3, 4, 4])

with self.assertRaisesRegex(
RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4, 0),
[3, 4, 4, 0])

with self.assertRaisesRegex(
RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(),
torch.randn(0, 4, 4, 0),
[0, 4, 4, 0])

@ -48,58 +48,66 @@ class DispatchKey(Enum):
Undefined = 0
CatchAll = Undefined

CPU = auto()
CUDA = auto()
HIP = auto()
Dense = auto()
FPGA = auto()
ORT = auto()
XLA = auto()
Lazy = auto()
Vulkan = auto()
Metal = auto()
XPU = auto()
MKLDNN = auto()
OpenGL = auto()
OpenCL = auto()
IDEEP = auto()
QuantizedCPU = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()
Quantized = auto()
CustomRNGKeyId = auto()
MkldnnCPU = auto()
SparseCPU = auto()
SparseCUDA = auto()
Sparse = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()
EndOfBackendKeys = PrivateUse3

ZeroTensor = auto()
Meta = auto()
BackendSelect = auto()
Named = auto()
AutogradOther = auto()
AutogradCPU = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradFunctionality = auto()
AutogradNestedTensor = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()
Tracer = auto()
Autocast = auto()
Batched = auto()
VmapMode = auto()
TESTING_ONLY_GenericWrapper = auto()
TESTING_ONLY_GenericMode = auto()
NumDispatchKeys = auto()
EndOfFunctionalityKeys = TESTING_ONLY_GenericMode

CPU = auto()
CUDA = auto()
HIP = auto()
XLA = auto()
Lazy = auto()
XPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()

QuantizedCPU = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()

SparseCPU = auto()
SparseCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()

AutogradCPU = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()

Autograd = auto()
CompositeImplicitAutograd = auto()
CompositeExplicitAutograd = auto()

@ -15,9 +15,9 @@ keys for a single example of each use case. These use cases are listed below:
- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
autograd kernel in pytorch core library.
E.g. CPU, CUDA
- QuantizedCPU/AutogradOther: represents in-tree backends which we usually have backend specific
- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
inference kernels, but they share the same autograd kernel specified in AutogradOther.
E.g. QuantizedCPU, QuantizedCUDA
E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
kernel defined in pytorch core library. Backend owner is responsible for registering both
inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
@ -53,7 +53,7 @@ class PythonDispatcher:
name = "foo"
runtime_keys = [
"CPU", "AutogradCPU",
"QuantizedCPU", "AutogradOther",
"FPGA", "AutogradOther",
"XLA", "AutogradXLA",
"Lazy", "AutogradLazy",
]