Revert D34034848: free up dispatch key space (in C++)

Test Plan: revert-hammer

Differential Revision: D34034848 (6690256021)

Original commit changeset: 9677ee2c0a1a

Original Phabricator Diff: D34034848 (6690256021)

fbshipit-source-id: fd50943d915ef813bb9f9ab278fb582429eea3b1
(cherry picked from commit 3acefee1cdb89bc051d1ef2e9deb5698d2bd85c3)
Brian Hirsh
2022-02-14 15:23:25 -08:00
committed by PyTorch MergeBot
parent 7f560fb3e0
commit 22ccf448e8
20 changed files with 514 additions and 1747 deletions

View File

@ -28,7 +28,8 @@ constexpr auto kFunctorchWrappedTensors = DispatchKeySet({
constexpr auto kTensorSubclassLike = kFunctorchWrappedTensors | DispatchKeySet({
DispatchKey::Batched,
DispatchKey::Sparse,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta,

View File

@ -43,6 +43,7 @@ inline bool variable_excluded_from_dispatch() {
// Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
return true;
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd));
return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}
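(Illustrative sketch, not part of this diff: the isSupersetOf call above is just a bitmask containment test. The ToyKeySet type and the bit positions below are made-up stand-ins for c10::DispatchKeySet, shown only to make the check concrete.)

#include <cassert>
#include <cstdint>

// Toy 64-bit keyset: each key occupies one bit.
struct ToyKeySet {
  uint64_t repr = 0;
  constexpr ToyKeySet add(uint64_t bit) const { return {repr | bit}; }
  // "this is a superset of other" <=> every bit of `other` is also set here.
  constexpr bool isSupersetOf(ToyKeySet other) const {
    return (repr & other.repr) == other.repr;
  }
};

int main() {
  constexpr uint64_t kAutogradCPU = 1ull << 3;   // made-up bit positions
  constexpr uint64_t kAutogradCUDA = 1ull << 4;
  ToyKeySet autograd = ToyKeySet{}.add(kAutogradCPU).add(kAutogradCUDA);

  ToyKeySet excluded = ToyKeySet{}.add(kAutogradCPU).add(kAutogradCUDA);
  assert(excluded.isSupersetOf(autograd));    // every autograd key is excluded

  ToyKeySet partial = ToyKeySet{}.add(kAutogradCPU);
  assert(!partial.isSupersetOf(autograd));    // only partially excluded
}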

View File

@ -6,52 +6,11 @@
namespace c10 {
void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
// (1) update nonFallthroughKeys_
if (has_fallthrough) {
nonFallthroughKeys_ = nonFallthroughKeys_.remove(k);
} else {
nonFallthroughKeys_ = nonFallthroughKeys_.add(k);
}
// (2) update nonFallthroughKeysPerBackend_
if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) {
// This is a per-backend functionality key.
// We need to figure out what the current backend is,
// and only update the bitset for that backend.
// subtracting 1 because the first backend should have index 0 (CPU),
// But the enum starts with BackendComponent::InvalidBit.
auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1;
TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size());
if (has_fallthrough) {
nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k);
} else {
nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k);
}
// Set requiresBitsetPerBackend_ accordingly
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) {
if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) {
requiresBitsetPerBackend_ = true;
return;
}
}
requiresBitsetPerBackend_ = false;
return;
} else {
// Otherwise, if a fallthrough is set for a functionality that isn't per backend,
// Then we update the fallthrough bitset for EVERY backend.
// TODO: we could probably optimize this by only lazily updating these values
// the first time that we see requiresBitsetPerBackend_ = true
// (which should almost never happen)
if (has_fallthrough) {
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k);
}
} else {
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k);
}
}
}
}
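(Illustrative sketch, not part of this diff: the backend_idx computation above subtracts 1 because the BackendComponent enum opens with an Invalid entry, so the first real backend occupies slot 0 of the per-backend array. The ToyBackend enum and slot names below are invented purely to demonstrate that indexing.)

#include <array>
#include <cstdint>
#include <iostream>

// Toy backend enum modelled on the "invalid entry comes first" layout.
enum class ToyBackend : uint8_t { Invalid = 0, CPU, CUDA, XLA, End };
constexpr uint8_t kToyNumBackends = static_cast<uint8_t>(ToyBackend::End) - 1;

int main() {
  std::array<const char*, kToyNumBackends> perBackendSlot{"cpu", "cuda", "xla"};
  // CUDA has enum value 2, but lands in slot 1 once the Invalid entry is skipped.
  auto backend_idx = static_cast<uint8_t>(ToyBackend::CUDA) - 1;
  std::cout << perBackendSlot[backend_idx] << "\n";  // prints "cuda"
}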
std::string DispatchKeyExtractor::dumpState() const {

View File

@ -156,24 +156,14 @@ public:
}
});
// Keys that are fallthrough should be skipped
if (requiresBitsetPerBackend_) {
auto backend_idx = ks.getBackendIndex();
return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
} else {
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
template<class... Args>
DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
auto ks = detail::multi_dispatch_key_set(args...);
// Keys that are fallthrough should be skipped
if (requiresBitsetPerBackend_) {
auto backend_idx = ks.getBackendIndex();
return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
} else {
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
}
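(Illustrative sketch, not part of this diff: a schematic of what "keys that are fallthrough should be skipped" means. The keys extracted from the arguments, plus the TLS include set and minus the TLS exclude set, are ANDed with the operator's non-fallthrough mask before the table lookup. This toy function is a stand-in, not the real c10::impl::computeDispatchKeySet.)

#include <cstdint>

// Toy bitmask version of the masking performed before dispatch-table lookup.
constexpr uint64_t toyComputeDispatchKeys(uint64_t arg_keys,
                                          uint64_t tls_included,
                                          uint64_t tls_excluded,
                                          uint64_t non_fallthrough_keys) {
  return ((arg_keys | tls_included) & ~tls_excluded) & non_fallthrough_keys;
}

static_assert(toyComputeDispatchKeys(0b0110, 0b0001, 0b0100, 0b0011) == 0b0011,
              "excluded and fallthrough bits never reach the dispatch table");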
void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
@ -203,12 +193,7 @@ private:
explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
: dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
, nonFallthroughKeys_(DispatchKeySet::FULL)
, requiresBitsetPerBackend_(false) {
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
}
}
, nonFallthroughKeys_(DispatchKeySet::FULL) {}
// this is a bitset that has ones for each argument index which has to be
// considered for dispatch. This avoids having to iterate over the stack
@ -220,14 +205,8 @@ private:
// fallthrough
c10::utils::bitset dispatch_arg_indices_reverse_;
// Set of functionality keys for which the operator does NOT have a fallthrough kernel.
// Set of keys for which the operator does NOT have a fallthrough kernel.
DispatchKeySet nonFallthroughKeys_;
// Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
// This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
// Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
// or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
bool requiresBitsetPerBackend_;
};
}

View File

@ -267,15 +267,14 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name)
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
std::lock_guard<std::mutex> lock(mutex_);
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(
!backendFallbackKernels_[idx].kernel.isValid(),
!backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[idx].debug, ", new registration ", debug
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].debug, ", new registration ", debug
);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unboxed
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);
@ -289,8 +288,7 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
std::lock_guard<std::mutex> lock(mutex_);
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
backendFallbackKernels_[idx] = {};
backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = {};
for (auto& op : operators_) {
op.op.updateFallback(*this, dispatchKey);

View File

@ -291,7 +291,7 @@ private:
// Map from namespace to debug string (saying, e.g., where the library was defined)
ska::flat_hash_map<std::string, std::string> libraries_;
std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;
std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
std::unique_ptr<detail::RegistrationListenerList> listeners_;
std::mutex mutex_;
@ -531,7 +531,8 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
@ -552,7 +553,7 @@ template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// do not use RecordFunction on redispatch
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
@ -560,7 +561,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorDef_->op;
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.lookup(dispatchKeySet);
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
@ -592,7 +593,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorDef_->op;
const auto& kernel = entry.lookup(dispatchKeySet);
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
return kernel.callBoxed(op, dispatchKeySet, stack);
}

View File

@ -283,7 +283,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
}
// 3. Backend fallback
auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
}
@ -299,7 +299,10 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
if (C10_UNLIKELY(dispatch_ix == -1)) {
return;
}
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}
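(Illustrative sketch, not part of this diff: on mobile builds the dispatch table is trimmed, so getDispatchTableIndexForDispatchKey can return -1 for keys that have no reserved slot, and the guard above simply skips the update. The ToyKey enum and the particular key list below are made up; they only demonstrate the -1 convention.)

#include <cstdint>

// Toy trimmed table: only a handful of keys get slots; everything else is -1.
enum class ToyKey : uint8_t { Undefined, CPU, AutogradCPU, CUDA };

constexpr int toyTableIndex(ToyKey k) {
  switch (k) {
    case ToyKey::Undefined:   return 0;
    case ToyKey::CPU:         return 1;
    case ToyKey::AutogradCPU: return 2;
    default:                  return -1;  // no slot reserved in this build
  }
}

static_assert(toyTableIndex(ToyKey::CPU) == 1, "CPU keeps a slot");
static_assert(toyTableIndex(ToyKey::CUDA) == -1, "CUDA is trimmed away");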
@ -326,12 +329,8 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
}
// Note [Refresh Runtime Autograd entries in dispatchTable_]
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
// In theory, we should only have to check if the given runtime key has "dense" functionality,
// e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
// However, there are some backends that should be included in this set that don't have the dense key set.
// E.g. DispatchKey::Meta, DispatchKey::ORT.
if (c10::isBackendDispatchKey(dispatch_key)) {
DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
}
}
@ -358,9 +357,8 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
// catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd
// or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
// should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
updateDispatchTable_(dispatcher, DispatchKey::Undefined);
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
updateDispatchTable_(dispatcher, k);
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
}
}
@ -373,10 +371,9 @@ void OperatorEntry::checkInvariants() const {
for (const auto& kv : kernels_) {
TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState());
}
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
auto idx = getDispatchTableIndexForDispatchKey(k);
TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast<DispatchKey>(iter));
TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]),
"Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
"Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
}
@ -387,8 +384,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
str << "[";
bool has_kernels = false;
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
auto iter = getDispatchTableIndexForDispatchKey(k);
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
if (!dispatchTable_[iter].isValid()) {
continue;
}
@ -447,12 +443,8 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const {
// updateDispatchTableFull_ would update the dispatch table to be)
std::string OperatorEntry::dumpComputedTable() const {
std::ostringstream oss;
// Need to handle Undefined separately, because it's a runtime key that can't be represented
// in a DispatchKeySet.
std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);
for (auto k : runtime_keys) {
for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
auto k = static_cast<DispatchKey>(i);
auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
if (kernel_prov.first.kernel.isValid()) {
oss << toString(k) << ": "

View File

@ -173,8 +173,11 @@ public:
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
const KernelFunction& lookup(DispatchKeySet ks) const {
const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
const KernelFunction& lookup(DispatchKey k) const {
const auto idx = getDispatchTableIndexForDispatchKey(k);
if (C10_UNLIKELY(idx == -1)) {
reportError(k);
}
const auto& kernel = dispatchTable_[idx];
// A valid kernel *always* has a boxed kernel and *may* have an
// unboxed kernel. However, we typically do unboxed calls in at::
@ -184,7 +187,7 @@ public:
// in the common case.
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
if (!kernel.isValid()) {
reportError(ks.highestPriorityTypeId());
reportError(k);
}
}
return kernel;
@ -208,7 +211,7 @@ private:
OperatorName name_;
c10::optional<AnnotatedSchema> schema_;
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
// kernels_ stores all registered kernels for the corresponding dispatch key

View File

@ -591,7 +591,7 @@ TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) {
void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) {
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
.kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(toBackendComponent(key)))
.kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(key))
.kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
@ -1791,22 +1791,22 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) {
TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool fpga_called, math_called = false;
bool sparsecpu_called, math_called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; }));
m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; }));
m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());
{
callOp(*op, dummyTensor(c10::DispatchKey::FPGA));
ASSERT_TRUE(fpga_called);
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
ASSERT_TRUE(sparsecpu_called);
}
{
expectThrows<c10::Error>([&] {
callOp(*op, dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true));
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
}, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
}
}
@ -1849,15 +1849,18 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
}
{
// TODO(#43908): currently this will fall through AutogradPrivateUse1 and then call the catchall
// kernel at AutogradCPU, while backend extenders actually expect the PrivateUse1 kernel to be
// called. This confusing behavior is caused by the fact that we register fallthrough kernels as
// the backend fallback for Autograd keys. Note that users can always work around this by
// registering the same kernel to AutogradPrivateUse1, as shown below, until we support it.
auto op = Dispatcher::singleton().findOp({"test::fn", ""});
ASSERT_TRUE(op.has_value());
catchall_called = false;
privateuse1_called = false;
callOp(*op,
dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
ASSERT_FALSE(catchall_called);
ASSERT_TRUE(privateuse1_called);
ASSERT_TRUE(catchall_called);
}
m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; });
@ -1873,27 +1876,6 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
}
}
TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) {
bool math_called = false;
bool cpu_called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn(Tensor dummy) -> Tensor");
m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());
{
math_called = cpu_called = false;
// Meta should redispatch to the AutogradOther backend,
// which the composite kernel should be registered to.
callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true));
ASSERT_TRUE(math_called);
ASSERT_FALSE(cpu_called);
}
}
TEST(NewOperatorRegistrationTest, dispatchMultiple) {
bool cpu_called = false;
bool cuda_called = false;

View File

@ -1,47 +1,14 @@
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <unordered_map>
namespace c10 {
const char* toString(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return "CPUBit";
case BackendComponent::CUDABit:
return "CUDABit";
case BackendComponent::HIPBit:
return "HIPBit";
case BackendComponent::XLABit:
return "XLABit";
case BackendComponent::LazyBit:
return "LazyBit";
case BackendComponent::XPUBit:
return "XPUBit";
case BackendComponent::MLCBit:
return "MLCBit";
case BackendComponent::HPUBit:
return "HPUBit";
case BackendComponent::VEBit:
return "VEBit";
case BackendComponent::PrivateUse1Bit:
return "PrivateUse1Bit";
case BackendComponent::PrivateUse2Bit:
return "PrivateUse2Bit";
case BackendComponent::PrivateUse3Bit:
return "PrivateUse3Bit";
case BackendComponent::InvalidBit:
return "InvalidBit";
default:
return "UNKNOWN_BACKEND_BIT";
}
}
const char* toString(DispatchKey t) {
switch (t) {
case DispatchKey::Undefined:
return "Undefined";
case DispatchKey::CPU:
return "CPU";
case DispatchKey::CUDA:
@ -134,6 +101,8 @@ const char* toString(DispatchKey t) {
return "AutogradMLC";
case DispatchKey::AutogradHPU:
return "AutogradHPU";
case DispatchKey::AutogradNestedTensor:
return "AutogradNestedTensor";
case DispatchKey::AutogradPrivateUse1:
return "AutogradPrivateUse1";
case DispatchKey::AutogradPrivateUse2:
@ -142,8 +111,6 @@ const char* toString(DispatchKey t) {
return "AutogradPrivateUse3";
case DispatchKey::AutogradOther:
return "AutogradOther";
case DispatchKey::AutogradNestedTensor:
return "AutogradNestedTensor";
case DispatchKey::ZeroTensor:
return "ZeroTensor";
@ -201,15 +168,6 @@ const char* toString(DispatchKey t) {
case DispatchKey::FuncTorchBatched:
return "FuncTorchBatched";
case DispatchKey::Dense:
return "Dense";
case DispatchKey::Quantized:
return "Quantized";
case DispatchKey::Sparse:
return "Sparse";
case DispatchKey::AutogradFunctionality:
return "AutogradFunctionality";
default:
return "UNKNOWN_TENSOR_TYPE_ID";
}
@ -218,37 +176,76 @@ const char* toString(DispatchKey t) {
std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
return str << toString(rhs);
}
std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
return str << toString(rhs);
}
DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
// We want this to return an autograd key. We're relying on the fact that
// getAutogradRelatedKeySetFromBackend returns an autograd key +
// ADInplaceOrView, and autograd has higher precedence. The core mapping from
// backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
// instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
// hotpath function, and we want to make sure that it doesn't have to
// construct any DispatchKeySets at runtime.
return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
// for a given backend key, return the associated autograd key.
// for non-backend keys, return AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
//
DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
switch (t) {
case DispatchKey::CPU:
return DispatchKey::AutogradCPU;
case DispatchKey::XPU:
return DispatchKey::AutogradXPU;
case DispatchKey::CUDA:
return DispatchKey::AutogradCUDA;
case DispatchKey::XLA:
return DispatchKey::AutogradXLA;
case DispatchKey::Lazy:
return DispatchKey::AutogradLazy;
case DispatchKey::MLC:
return DispatchKey::AutogradMLC;
case DispatchKey::HPU:
return DispatchKey::AutogradHPU;
case DispatchKey::NestedTensor:
return DispatchKey::AutogradNestedTensor;
case DispatchKey::PrivateUse1:
return DispatchKey::AutogradPrivateUse1;
case DispatchKey::PrivateUse2:
return DispatchKey::AutogradPrivateUse2;
case DispatchKey::PrivateUse3:
return DispatchKey::AutogradPrivateUse3;
default:
return DispatchKey::AutogradOther;
}
}
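(Hedged usage sketch, not part of this diff: assuming <c10/core/DispatchKey.h> is on the include path, the restored overload above maps backend keys to their Autograd counterparts and everything else to AutogradOther.)

#include <c10/core/DispatchKey.h>
#include <cassert>

// Expected behavior of the switch above: CPU gets a dedicated autograd key,
// while a key without its own case (e.g. SparseCPU) falls back to AutogradOther.
void checkAutogradMapping() {
  using c10::DispatchKey;
  assert(c10::getAutogradKeyFromBackend(DispatchKey::CPU) == DispatchKey::AutogradCPU);
  assert(c10::getAutogradKeyFromBackend(DispatchKey::SparseCPU) == DispatchKey::AutogradOther);
}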
c10::DispatchKey parseDispatchKey(const std::string& k) {
static std::unordered_map<std::string, c10::DispatchKey> key_map = {
{"Undefined", c10::DispatchKey::Undefined},
{"Dense", c10::DispatchKey::Dense},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},
{"HIP", c10::DispatchKey::HIP},
{"FPGA", c10::DispatchKey::FPGA},
{"ORT", c10::DispatchKey::ORT},
{"XLA", c10::DispatchKey::XLA},
{"MLC", c10::DispatchKey::MLC},
{"Vulkan", c10::DispatchKey::Vulkan},
{"Metal", c10::DispatchKey::Metal},
{"XPU", c10::DispatchKey::XPU},
{"HPU", c10::DispatchKey::HPU},
{"VE", c10::DispatchKey::VE},
{"Lazy", c10::DispatchKey::Lazy},
{"Meta", c10::DispatchKey::Meta},
{"Quantized", c10::DispatchKey::Quantized},
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
{"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
{"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
{"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
{"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
{"Sparse", c10::DispatchKey::Sparse},
{"SparseCPU", c10::DispatchKey::SparseCPU},
{"SparseCUDA", c10::DispatchKey::SparseCUDA},
{"SparseHIP", c10::DispatchKey::SparseHIP},
{"SparseXPU", c10::DispatchKey::SparseXPU},
{"SparseVE", c10::DispatchKey::SparseVE},
{"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
{"NestedTensor", c10::DispatchKey::NestedTensor},
{"PrivateUse1", c10::DispatchKey::PrivateUse1},
{"PrivateUse2", c10::DispatchKey::PrivateUse2},
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
{"Named", c10::DispatchKey::Named},
@ -259,8 +256,17 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
c10::DispatchKey::FuncTorchDynamicLayerBackMode},
{"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
{"AutogradOther", c10::DispatchKey::AutogradOther},
{"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},
{"AutogradLazy", c10::DispatchKey::AutogradLazy},
{"AutogradXPU", c10::DispatchKey::AutogradXPU},
{"AutogradMLC", c10::DispatchKey::AutogradMLC},
{"AutogradHPU", c10::DispatchKey::AutogradHPU},
{"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
{"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
{"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
{"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
{"Tracer", c10::DispatchKey::Tracer},
{"AutocastCPU", c10::DispatchKey::AutocastCPU},
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
@ -274,41 +280,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"TESTING_ONLY_GenericWrapper",
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},
{"HIP", c10::DispatchKey::HIP},
{"XLA", c10::DispatchKey::XLA},
{"MLC", c10::DispatchKey::MLC},
{"XPU", c10::DispatchKey::XPU},
{"HPU", c10::DispatchKey::HPU},
{"Lazy", c10::DispatchKey::Lazy},
{"NestedTensor", c10::DispatchKey::NestedTensor},
{"PrivateUse1", c10::DispatchKey::PrivateUse1},
{"PrivateUse2", c10::DispatchKey::PrivateUse2},
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
{"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
{"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
{"SparseCPU", c10::DispatchKey::SparseCPU},
{"SparseCUDA", c10::DispatchKey::SparseCUDA},
{"SparseHIP", c10::DispatchKey::SparseHIP},
{"SparseXPU", c10::DispatchKey::SparseXPU},
{"SparseVE", c10::DispatchKey::SparseVE},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},
{"AutogradLazy", c10::DispatchKey::AutogradLazy},
{"AutogradXPU", c10::DispatchKey::AutogradXPU},
{"AutogradMLC", c10::DispatchKey::AutogradMLC},
{"AutogradHPU", c10::DispatchKey::AutogradHPU},
{"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
{"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
{"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
{"Autograd", c10::DispatchKey::Autograd},
{"CompositeImplicitAutograd",
c10::DispatchKey::CompositeImplicitAutograd},

View File

@ -9,98 +9,20 @@
namespace c10 {
// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.
// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.
enum class BackendComponent : uint8_t {
// A "backend" is colloquially used to refer to handlers for dispatch
// which actually implement the numerics of an operation in question.
//
// Due to the nature of the enum, these backends are specified in
// an ordered way, but for most backends this order is not semantically
// meaningful (e.g., it's valid to reorder these backends without changing
// semantics). The only situation when backend ordering is meaningful
// is when the backend participates in multiple dispatch with another
// backend; e.g., CPU and CUDA (cuda must have higher priority).
// These keys don't correspond to individual kernels.
// Instead, they represent the backends that are allowed to override specific
// pieces of functionality:
// - dense kernels (e.g. DispatchKey::CPU)
// - sparse kernels (e.g. DispatchKey::SparseCPU)
// - quantized kernels (e.g. DispatchKey::QuantizedCPU)
// - autograd kernels (e.g. DispatchKey::AutogradCPU)
// We reserve space in the runtime operator table for this full cross product
// of
// [backends in this enum] x [keys below that are explicitly marked as having
// per-backend functionality]
InvalidBit = 0,
CPUBit,
CUDABit,
HIPBit,
XLABit,
MLCBit,
XPUBit,
HPUBit,
VEBit,
LazyBit,
PrivateUse1Bit,
PrivateUse2Bit,
PrivateUse3Bit,
// Define an alias to represent end of backend dispatch keys.
// If you add new backend keys after PrivateUse3, please also update it here.
// (But you shouldn't: private use keys should have higher precedence than
// all built-in keys)
EndOfBackendKeys = PrivateUse3Bit,
};
// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
// dispatch, for which a handler may be registered. Traditional
// backends like CPU and CUDA get dispatch keys; however, so do
// "wrapping" layers like Variable (for autograd handling).
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet. Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend (e.g. Dense, Sparse,
//     AutogradFunctionality)
// (4) per-backend instances of customizable functionalities (e.g. CPU,
//     SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
// ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.
// See Note [DispatchKeySet Internal Representation] for more details.
//
// NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
enum class DispatchKey : uint16_t {
enum class DispatchKey : uint8_t {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// This is not a "real" functionality, but it exists to give us a "nullopt"
// This is not a "real" tensor id, but it exists to give us a "nullopt"
// element we can return for cases when a DispatchKeySet contains no elements.
// You can think a more semantically accurate definition of DispatchKey is:
//
@ -116,31 +38,24 @@ enum class DispatchKey : uint16_t {
// this will get eliminated, but for now it's convenient)
CatchAll = Undefined,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
// Every value in the enum (up to EndOfFunctionalityKeys)
// corresponds to an individual "functionality" that can be dispatched to.
// This is represented in the DispatchKeySet by assigning each of these enum
// values
// to each of the remaining (64 - len(BackendComponent)) bits.
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ BACKENDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// A "backend" is colloquially used to refer to handlers for dispatch
// which actually implement the numerics of an operation in question.
//
// Most of these functionalities have a single handler assigned to them,
// making them "runtime keys" that map to a single slot in the runtime
// operator table.
//
// A few functionalities are allowed to be customizable per backend.
// See [Note: Per-Backend Functionality Dispatch Keys] for details.
// See [Note: Per-Backend Functionality Dispatch Keys]
Dense,
// Below are non-extensible backends.
// These are backends that currently don't have their own overrides for
// Autograd/Sparse/Quantized kernels,
// and we therefore don't waste space in the runtime operator table allocating
// space for them.
// If any of these backends ever need to customize, e.g., Autograd, then we'll
// need to add a DispatchKey::*Bit for them.
// Due to the nature of the enum, these backends are specified in
// an ordered way, but for most backends this order is not semantically
// meaningful (e.g., it's valid to reorder these backends without changing
// semantics). The only situation when backend ordering is meaningful
// is when the backend participates in multiple dispatch with another
// backend; e.g., CPU and SparseCPU (sparse must have
// higher priority).
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
// CUDA]
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
@ -152,8 +67,14 @@ enum class DispatchKey : uint16_t {
// - aten/src/ATen/test/extension_backend_test.cpp
ORT,
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
Metal,
XPU, // For out of tree Intel's heterogeneous computing plug-in
HPU, // For out of tree & closed source integration of HPU / Habana
VE, // For out of tree & closed source integration of SX-Aurora / NEC
Lazy, // For lazy tensor backends
// A meta tensor is a tensor without any data associated with it. (They
// have also colloquially been referred to as tensors on the "null" device).
@ -162,8 +83,11 @@ enum class DispatchKey : uint16_t {
// tensor with the output shape and dtype, but wouldn't actually add anything.
Meta,
// See [Note: Per-Backend Functionality Dispatch Keys]
Quantized,
// Here are backends which specify more specialized operators
// based on the dtype of the tensor.
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
// This backend is to support custom RNGs; it lets you go
// to a different kernel if you pass in a generator that is not a
@ -182,29 +106,31 @@ enum class DispatchKey : uint16_t {
// the corresponding dense tensors, and must be handled before them.
MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
// NB: not to be confused with MKLDNN, which is Caffe2 only
// See [Note: Per-Backend Functionality Dispatch Keys]
Sparse,
SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
SparseHIP, // TODO: I think this is not actually used, due to Note
// [Masquerading as CUDA]
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
SparseCsrCPU,
SparseCsrCUDA,
// Note [Non-Customizable Backend Keys]
// Every key above here is considered a "non-customizable backend".
// These are backends that will work correctly with autograd, but
// currently don't require separate implementations
// for autograd, sparse, or quantized kernels.
// Any new backends that don't need to be customized should go above here.
// If an existing backend needs to e.g. override autograd, then we can
// consider promoting it into the "BackendComponent" enum
//
// For all intents and purposes from the perspective of DispatchKeySet,
// "non-customizable backend" keys are treated the same way
// as other functionality keys
EndOfNonCustomizableBackends = SparseCsrCUDA,
NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]
// To see some example about how to use this, check out ORT
PrivateUse1,
PrivateUse2,
PrivateUse3,
// Define an alias key to represent end of backend dispatch keys.
// If you add new backend keys after PrivateUse3, please also update it here.
// (But you shouldn't: private use keys should have higher precedence than
// all built-in keys)
EndOfBackendKeys = PrivateUse3,
// In some situations, it is not immediately obvious what the correct
// backend for function is, because the function in question doesn't
// have any "tensor" arguments. In this case, a BackendSelect function
@ -307,18 +233,20 @@ enum class DispatchKey : uint16_t {
// AutogradOther key. We can add specific autograd key for those backends
// upon request.
AutogradOther,
// See [Note: Per-Backend Functionality Dispatch Keys]
AutogradFunctionality,
// NestedTensor is an example of something that isn't a "real backend"
// (because it mostly consists of redispatching kernels)
// but it would like to override autograd functionality in C++.
// We can handle cases like this by adding an extra functionality key
// exclusively for handling autograd for NestedTensor.
// lives out of tree at
AutogradCPU,
AutogradCUDA,
AutogradXLA,
AutogradLazy,
AutogradXPU,
AutogradMLC,
AutogradHPU,
AutogradNestedTensor, // lives out of tree at
// https://github.com/pytorch/nestedtensor
AutogradNestedTensor,
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,
AutogradPrivateUse2,
AutogradPrivateUse3,
Tracer,
@ -371,100 +299,9 @@ enum class DispatchKey : uint16_t {
TESTING_ONLY_GenericMode,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
EndOfFunctionalityKeys, // End of functionality keys.
// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfDenseBackends,
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
// CUDA]
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
XPU, // For out of tree Intel's heterogeneous computing plug-in
HPU, // For out of tree & closed source integration of HPU / Habana
VE, // For out of tree & closed source integration of SX-Aurora / NEC
Lazy, // For lazy tensor backends
// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]
// To see some example about how to use this, check out ORT
PrivateUse1,
PrivateUse2,
PrivateUse3,
EndOfDenseBackends = PrivateUse3,
// ~~~~~~~~~~~~~~ "Quantized" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfQuantizedBackends,
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
_QuantizedHIP,
_QuantizedXLA,
_QuantizedMLC,
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
_QuantizedHPU,
_QuantizedVE,
_QuantizedLazy,
_QuantizedPrivateUse1,
_QuantizedPrivateUse2,
_QuantizedPrivateUse3,
EndOfQuantizedBackends = _QuantizedPrivateUse3,
// ~~~~~~~~~~~~~~ "Sparse" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfSparseBackends,
SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
SparseHIP, // TODO: I think this is not actually used, due to Note
// [Masquerading as CUDA]
_SparseXLA,
_SparseMLC,
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
_SparseHPU,
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
_SparseLazy,
_SparsePrivateUse1,
_SparsePrivateUse2,
_SparsePrivateUse3,
EndOfSparseBackends = _SparsePrivateUse3,
// ~~~~~~~~~~~~~~ "Autograd" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfAutogradBackends,
AutogradCPU,
AutogradCUDA,
_AutogradHIP,
AutogradXLA,
AutogradMLC,
AutogradXPU,
AutogradHPU,
_AutogradVE,
AutogradLazy,
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,
AutogradPrivateUse2,
AutogradPrivateUse3,
EndOfAutogradBackends = AutogradPrivateUse3,
// If we add a new per-backend functionality key that has higher priority
// than Autograd, then this key should be updated.
EndOfRuntimeBackendKeys = EndOfAutogradBackends,
NumDispatchKeys, // Sentinel, end of runtime keys.
// ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// Note [Alias Dispatch Keys]
// Alias dispatch keys are synthetic dispatch keys which map to multiple
// runtime dispatch keys. Alias keys have precedence, but they are always
// lower precedence than runtime keys. You can register a kernel to an
@ -484,7 +321,6 @@ enum class DispatchKey : uint16_t {
// Define an alias key to represent end of alias dispatch keys.
// If you add new alias keys after Autograd, please also update it here.
StartOfAliasKeys = Autograd,
EndOfAliasKeys = CompositeExplicitAutograd, //
// ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
@ -524,83 +360,54 @@ enum class DispatchKey : uint16_t {
// built-in autograd formulas for operators are not appropriate.
static_assert(
(static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
"The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
" both map to backend and functionality bits"
" into a 64-bit bitmask; you must have less than 64 total entries between them");
// Check if a DispatchKey is an alias mapping to other runtime keys.
constexpr bool isAliasDispatchKey(DispatchKey k) {
return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
}
// [Note: Per-Backend Functionality Dispatch Keys]
// Check if a DispatchKey is a per-backend functionality key
// Any functionalities that can be customized per-backend should be added here.
// These keys correspond to functionalities that can be customized individually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.
constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality) {
return true;
} else {
return false;
}
}
// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);
// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
"BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in");
constexpr uint8_t numPerBackendFunctionalityKeys() {
uint8_t count = 0;
for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
++count;
}
return count;
}
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// See [Note: Trimmed Mobile Dispatch Keys]
constexpr uint8_t num_backends = 1; // Only CPU
constexpr uint16_t num_runtime_entries = 8;
/**
* The method below maps the dispatch key in the enum DispatchKey to an
* integer index in the dispatchTable_ array in OperatorEntry. The array
* is trimmed for mobile to reduce peak memory usage since it's
* unnecessary to reserve additional space for dispatch keys that will
* never be used on mobile.
*/
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
switch (dk) {
case DispatchKey::Undefined:
return 0;
case DispatchKey::CPU:
return 1;
case DispatchKey::QuantizedCPU:
return 2;
case DispatchKey::SparseCPU:
return 3;
case DispatchKey::BackendSelect:
return 4;
case DispatchKey::ADInplaceOrView:
return 5;
case DispatchKey::AutogradOther:
return 6;
case DispatchKey::AutogradCPU:
return 7;
case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys.
return 8;
default:
return -1;
}
}
#else
constexpr uint8_t num_backends =
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);
constexpr uint16_t num_runtime_entries = num_functionality_keys +
(numPerBackendFunctionalityKeys() * (num_backends - 1));
/**
* For the server use-case, make this a simple pass-through.
*/
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
return static_cast<int>(dk);
}
#endif
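(Illustrative sketch, not part of this diff: the num_runtime_entries arithmetic above in toy numbers. Every per-backend functionality already counts once among the functionality keys, so each one contributes (num_backends - 1) additional runtime slots. All constants below are made up; the real values come from the enums in this header.)

#include <cstdint>

constexpr uint8_t toy_num_backends = 13;                    // e.g. CPU .. PrivateUse3
constexpr uint8_t toy_num_functionality_keys = 40;
constexpr uint8_t toy_num_per_backend_functionalities = 4;  // Dense/Quantized/Sparse/Autograd

constexpr uint16_t toy_num_runtime_entries =
    toy_num_functionality_keys +
    toy_num_per_backend_functionalities * (toy_num_backends - 1);

static_assert(toy_num_runtime_entries == 88, "40 + 4 * 12 slots in this toy layout");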
// See Note [No More Than 16 Backends]
constexpr uint16_t full_backend_mask =
(static_cast<uint16_t>(1) << num_backends) - 1;
C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);
C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);
C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t);
// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
@ -613,86 +420,10 @@ C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
// torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd;
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
if (k >= DispatchKey::StartOfDenseBackends &&
k <= DispatchKey::EndOfDenseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
} else if (
k >= DispatchKey::StartOfQuantizedBackends &&
k <= DispatchKey::EndOfQuantizedBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
} else if (
k >= DispatchKey::StartOfSparseBackends &&
k <= DispatchKey::EndOfSparseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
} else if (
k >= DispatchKey::StartOfAutogradBackends &&
k <= DispatchKey::EndOfAutogradBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends));
} else {
return BackendComponent::InvalidBit;
}
// Check if a DispatchKey is an alias mapping to other runtime keys.
inline bool isAliasDispatchKey(DispatchKey k) {
return k > DispatchKey::NumDispatchKeys && k <= DispatchKey::EndOfAliasKeys;
}
constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
if (k <= DispatchKey::EndOfFunctionalityKeys) {
return k;
} else if (k <= DispatchKey::EndOfDenseBackends) {
return DispatchKey::Dense;
} else if (k <= DispatchKey::EndOfQuantizedBackends) {
return DispatchKey::Quantized;
} else if (k <= DispatchKey::EndOfSparseBackends) {
return DispatchKey::Sparse;
} else if (k <= DispatchKey::EndOfAutogradBackends) {
return DispatchKey::AutogradFunctionality;
} else {
return DispatchKey::Undefined;
}
}
// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns DispatchKey::CUDA
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
DispatchKey functionality_k,
BackendComponent backend_k) {
if (functionality_k == DispatchKey::Dense) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Sparse) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Quantized) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::AutogradFunctionality) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends) +
static_cast<uint8_t>(backend_k));
}
return DispatchKey::Undefined;
}
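(Illustrative worked example, not part of this diff: with the enum layouts shown above, BackendComponent::CUDABit has value 2 (InvalidBit = 0, CPUBit = 1, CUDABit = 2) and DispatchKey::CUDA sits two entries past StartOfDenseBackends, so toRuntimePerBackendFunctionalityKey(DispatchKey::Dense, BackendComponent::CUDABit) computes StartOfDenseBackends + 2 and returns DispatchKey::CUDA.)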
} // namespace c10
namespace torch {

View File

@ -1,29 +1,37 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
namespace c10 {
// backend_dispatch_keyset includes all dispatch keys that map to backends.
// backend_dispatch_keyset should include all runtime backend keys.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
autogradother_backends | DispatchKeySet(DispatchKey::Dense);
// NestedTensor has been explicitly removed from backend_dispatch_keyset due to
// incompatibility with some kernels, such as structured kernels, that use the
// DefaultBackend key.
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKeySet({
DispatchKey::CPU,
DispatchKey::CUDA,
DispatchKey::XLA,
DispatchKey::Lazy,
DispatchKey::XPU,
DispatchKey::PrivateUse1,
DispatchKey::PrivateUse2,
DispatchKey::PrivateUse3,
DispatchKey::MLC,
DispatchKey::HPU,
DispatchKey::ORT,
DispatchKey::Meta,
});
bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined
// See Note [No Alias Keys in DispatchKeySet]
&& !isAliasDispatchKey(t)
// Note [NestedTensor Not Included in Backend Keys]
// NestedTensor has been explicitly removed from the "backend keyset" due
// to incompatibility with some kernels, so we don't want it to be
// included in CompositeImplicitAutograd or CompositeExplicitAutograd
// kernels.
&& t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
&& !isAliasDispatchKey(t) && backend_dispatch_keyset.has(t);
}
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask]
// maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset =
backend_dispatch_keyset | autograd_dispatch_keyset;
@ -31,12 +39,7 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
case DispatchKey::Autograd:
// See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// That's why we OR it with a mask of the backend bits here.
// getRuntimeDispatchKeySet() expects to return a keyset of runtime
// dispatch keys, like AutogradCPU, but that requires having backend bits.
return autograd_dispatch_keyset |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
return autograd_dispatch_keyset;
case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset;
case DispatchKey::CompositeExplicitAutograd:
@ -50,13 +53,11 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
case DispatchKey::Autograd:
return autograd_dispatch_keyset.has(toFunctionalityKey(k));
return autograd_dispatch_keyset.has(k);
case DispatchKey::CompositeImplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && math_dispatch_keyset.has(k);
return math_dispatch_keyset.has(k);
case DispatchKey::CompositeExplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
return backend_dispatch_keyset.has(k);
default:
return t == k;
}
@ -78,6 +79,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradHPU:
return DispatchKeySet(DispatchKey::HPU);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor);
case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradPrivateUse1:
@ -93,6 +96,23 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
}
}
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
switch (t) {
case DispatchKey::CPU:
return DispatchKeySet(DispatchKey::AutocastCPU);
case DispatchKey::CUDA:
case DispatchKey::XLA:
return DispatchKeySet(DispatchKey::AutocastCUDA);
default:
return DispatchKeySet();
}
}
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
return DispatchKeySet(
{DispatchKey::ADInplaceOrView, getAutogradKeyFromBackend(t)});
}
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}
@ -109,167 +129,18 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
return os;
}
os << "DispatchKeySet(";
DispatchKey tid;
bool first = true;
for (auto k : ts) {
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
if (!first) {
os << ", ";
}
os << k;
os << tid;
ts = ts.remove(tid);
first = false;
}
os << ")";
return os;
}
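// Illustrative sketch (not part of this commit; the helper name is made up):
// exercising the streaming operator above via c10::toString(DispatchKeySet).
// The printed order differs between the two implementations in this diff: the
// iterator-based loop prints lowest-priority keys first, while the
// highestPriorityTypeId() loop prints highest-priority keys first.
std::string exampleKeysetToString() {
  // e.g. "DispatchKeySet(CPU, AutogradCPU)" with the iterator-based loop,
  // "DispatchKeySet(AutogradCPU, CPU)" with the highestPriorityTypeId() loop.
  return toString(DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCPU}));
}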
DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
TORCH_INTERNAL_ASSERT(next_functionality_ >= num_backends);
TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends);
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_functionality_bits =
llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
uint64_t masked_backend_bits =
llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
*data_ptr_;
uint64_t first_functionality_idx =
llvm::findFirstSet(masked_functionality_bits);
uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);
// If there are no keys, set to end iterator value
if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
next_functionality_ == iterator::end_iter_mask_val) {
// Set up state to be the same as end()
next_functionality_ = iterator::end_iter_mask_val;
current_dispatchkey_idx_ = iterator::end_iter_key_val;
next_backend_ = 0;
current_backendcomponent_idx_ = iterator::end_iter_key_val;
return *this;
}
// The +1 is because of DispatchKey::Undefined and
// BackendComponent::InvalidBit
auto new_next_functionality = first_functionality_idx + 1;
auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
// and the -num_backends is because the first <num_backends> bits in the
// keyset are not Dispatch Keys.
auto next_dispatchkey_idx = new_next_functionality - num_backends;
// If the current functionality bit is a per-backend bit, we need special
// handling
if (isPerBackendFunctionalityKey(
static_cast<DispatchKey>(next_dispatchkey_idx))) {
// case 1: if the current backend is undefined, then there is no valid
// backend instance of this functionality key so we can skip it.
if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// increment the functionality mask so we skip the current functionality
// bit on the next increment.
next_functionality_ = new_next_functionality;
++(*this);
return *this;
}
// Otherwise, at this point we know what the current backend and
// functionality bits are.
current_dispatchkey_idx_ = next_dispatchkey_idx;
current_backendcomponent_idx_ = new_backendcomponent_idx;
// Next, we need to set up the masks for the next increment.
uint64_t next_backendcomponent_bits =
llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
full_backend_mask & *data_ptr_;
uint64_t next_backendcomponent_idx =
llvm::findFirstSet(next_backendcomponent_bits);
if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// case 2: the current backend is valid, but there is not another backend
// in the keyset. In this case, we need to bump the functionality mask and
// reset the backend mask for the next increment
next_functionality_ = new_next_functionality;
next_backend_ = 0;
} else {
// case 3: we have another backend to iterate over. We want to iterate
// over the same functionality bit next time, but a different backend bit.
next_backend_ = first_backendcomponent_idx + 1;
}
} else {
// Functionality bits that aren't per backend are simpler to handle. We can
// ignore the backend bits.
TORCH_INTERNAL_ASSERT(next_backend_ == 0);
current_dispatchkey_idx_ = next_dispatchkey_idx;
next_functionality_ = new_next_functionality;
}
return *this;
}
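// Illustrative sketch (not part of this commit; the helper name is made up):
// with the increment logic above, a keyset holding per-backend functionality
// bits expands into one runtime key per backend bit during iteration. The set
// below therefore yields CPU, CUDA, AutogradCPU, AutogradCUDA, four keys in
// total.
int countRuntimeKeysExample() {
  auto ks =
      DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
      DispatchKeySet({DispatchKey::Dense, DispatchKey::AutogradFunctionality});
  int n = 0;
  for (auto k : ks) {
    (void)k; // CPU, CUDA, AutogradCPU, AutogradCUDA
    ++n;
  }
  return n; // 4
}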
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
offsets_and_masks;
// manually set the first entry, which corresponds to Undefined.
offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
// loop through every functionality key (aside from Undefined).
for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
// functionality_idx should be Dense -> 1, ...
auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
auto k = static_cast<DispatchKey>(functionality_idx);
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// [Note: Trimmed Mobile Dispatch Keys]
uint16_t mask = 0;
uint16_t offset = 0;
switch (k) {
case DispatchKey::Undefined:
offset = 0;
break;
case DispatchKey::CPU:
offset = 1;
break;
case DispatchKey::QuantizedCPU:
offset = 2;
break;
case DispatchKey::SparseCPU:
offset = 3;
break;
case DispatchKey::BackendSelect:
offset = 4;
break;
case DispatchKey::ADInplaceOrView:
offset = 5;
break;
case DispatchKey::AutogradOther:
offset = 6;
break;
case DispatchKey::AutogradCPU:
offset = 7;
break;
default:
// All other keys which are unsupported on mobile will get sent
// to the undefined kernel, causing them to error.
offset = 0;
}
offsets_and_masks[functionality_idx] =
FunctionalityOffsetAndMask(offset, 0);
}
#else
// If the previous functionality was not per-backend, then we can just
// increment the previous offset. Otherwise, the next offset =
// previous_offset + num_backends.
auto next_offset = prev_offset_and_mask.offset +
(prev_offset_and_mask.mask == 0 ? 1 : num_backends);
// the mask is used in the runtime index calculation to find the offset of
// the backend. For non-per-backend functionalities, this offset should
// always be 0. Otherwise, we need to get the index of the backend (which we
// can do using a backend mask).
auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
offsets_and_masks[functionality_idx] =
FunctionalityOffsetAndMask(next_offset, next_mask);
}
// Sanity check that the computed offset index of the last functionality key
// is correct. This assumes that the highest priority functionality key is not
// per backend.
TORCH_INTERNAL_ASSERT(
offsets_and_masks[num_functionality_keys - 1].offset ==
(num_runtime_entries - 1),
"num_runtime_entries: ",
num_runtime_entries,
"last_offset: ",
offsets_and_masks[num_functionality_keys - 1].offset);
#endif
return offsets_and_masks;
}
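// Worked example (a sketch with made-up numbers rather than the real enum
// layout): suppose num_backends == 3 and the functionality order after
// Undefined is Dense (per-backend), FPGA (not per-backend), Quantized
// (per-backend). Then Undefined gets offset 0, Dense gets offset 1 and
// reserves 3 slots (one per backend), FPGA gets offset 1 + 3 = 4 and reserves
// a single slot, and Quantized gets offset 4 + 1 = 5. Per-backend
// functionalities also record full_backend_mask as their mask so that
// getDispatchTableIndexForDispatchKeySet() can extract the backend index;
// every other functionality records a mask of 0.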
} // namespace c10

View File

@ -1,4 +1,5 @@
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
@ -7,147 +8,29 @@
namespace c10 {
struct FunctionalityOffsetAndMask {
// empty constructor shouldn't be used; only needed to initialize
// the array before populating it.
FunctionalityOffsetAndMask() {}
FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
: offset(offset), mask(mask) {}
// This needs to be big enough to cover the size of the operator table.
uint16_t offset;
// See Note [No More Than 16 Backends]
// This mask needs to be big enough to mask all of the backend bits.
// We probably don't ever want to have more than 16 backend bits, so uint16_t
// should be enough.
uint16_t mask;
};
static_assert(
c10::num_runtime_entries < 65536,
"The dispatcher currently only supports up to 2^16 runtime entries");
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();
C10_ALWAYS_INLINE static const std::
array<FunctionalityOffsetAndMask, num_functionality_keys>&
offsetsAndMasks() {
static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
return offsets_and_masks_;
}
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, oring them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).
// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// "what is the highest priority DispatchKey in the set"? (The set itself is
// not ordered; two sets with the same ids will always have the ids ordered in
// the same way.)
// Note that DispatchKeys are ordered; thus, we can ask questions like "what is
// the highest priority DispatchKey in the set"? (The set itself is not
// ordered; two sets with the same ids will always have the ids ordered in the
// same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
// At the moment, there are no nontrivial uses of this set; tensors are always
// singletons. In the near future, this set will represent variable? + tensor
// type id. In the far future, it will be requires grad? + profiling? +
// tracing? + lazy? + tensor type id.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse": SparseCPU, SparseCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
// The problem is that the total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.
// (The difference between variable and requires_grad is that
// there are currently three states a tensor can be in:
// 1. Not a variable
// 2. Variable with requires_grad=False
// 3. Variable with requires_grad=True
// Eventually, we want to kill state (1), and only dispatch to autograd
// handling code if one of the inputs requires grad.)
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
// CUDABit)
// (b) functionalities: (per-backend) functionality-bit DispatchKeys
// (e.g. AutogradFunctionality, Sparse, Dense)
// (2) "Runtime" keys
// (a) "non-customizable backends" (e.g. FPGA)
// (b) "non-customizable functionalities" (e.g. Functionalize)
// (c) "per-backend instances of customizable functionalities" (e.g. CPU,
// SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
// auto dense_cpu_ks = DispatchKeySet(DispatchKey::Dense) |
// DispatchKeySet(BackendComponent::CPUBit);
// // The keyset has the runtime dense-cpu key.
// dense_cpu_ks.has(DispatchKey::CPU);
// // And it contains the building block keys too.
// dense_cpu_ks.has_backend(BackendComponent::CPUBit);
// dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building blocks (CPU,
// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
// every dispatcher operator to be customized in up to 12*4 different ways. Each
// of those requires a slot in the operator table of every dispatcher operator.
// Not every piece of functionality necessarily needs to be customizable
// per-backend, and not every backend necessarily needs to be able to customize
// every type of functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
// customizable per backend.
// In order to do so, we'd need to promote it to a per-backend functionality
// "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
// functionality like Sparse, Autograd, etc.
// In order to do so, we'd need to promote it to a backend "building block"
// key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743
// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
public:
@ -158,146 +41,29 @@ class DispatchKeySet final {
// NB: default constructor representation as zero is MANDATORY as
// use of DispatchKeySet in TLS requires this.
constexpr DispatchKeySet() : repr_(0) {}
constexpr DispatchKeySet(Full)
: repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
: repr_(std::numeric_limits<decltype(repr_)>::max()) {}
constexpr DispatchKeySet(FullAfter, DispatchKey t)
// LSB after t are OK, but not t itself.
// "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
// Quantized > Dense). But backends don't really have an ordering.
// Therefore, we're enforcing that FullAfter can only be used on
// "functionality" keys.
: repr_(
(1ULL
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
1)) -
1) {}
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {}
// Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this!
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
constexpr explicit DispatchKeySet(BackendComponent k) {
if (k == BackendComponent::InvalidBit) {
repr_ = 0;
} else {
repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
}
}
constexpr explicit DispatchKeySet(DispatchKey k) {
if (k == DispatchKey::Undefined) {
// Case 1: handle Undefined specifically
repr_ = 0;
} else if (k <= DispatchKey::EndOfFunctionalityKeys) {
// Case 2: handle "functionality-only" keys
// These keys have a functionality bit set, but no backend bits
// These can technically be either:
// - valid runtime keys (e.g. DispatchKey::AutogradOther,
// DispatchKey::FuncTorchBatched, etc)
// - "building block" keys that aren't actual runtime keys (e.g.
// DispatchKey::Dense or Sparse)
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(k) - 1);
repr_ = functionality_val;
} else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
// Case 3: "runtime" keys that have a functionality bit AND a backend bit.
// First compute which bit to flip for the functionality.
auto functionality_k = toFunctionalityKey(k);
// The - 1 is because Undefined is technically a "functionality" that
// doesn't show up in the bitset. So e.g. Dense is technically the second
// functionality, but the lowest functionality bit.
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(functionality_k) - 1);
// then compute which bit to flip for the backend
// Case 4a: handle the runtime instances of "per-backend functionality"
// keys For example, given DispatchKey::CPU, we should set:
// - the Dense functionality bit
// - the CPUBit backend bit
// first compute which bit to flip for the backend
auto backend_k = toBackendComponent(k);
uint64_t backend_val = backend_k == BackendComponent::InvalidBit
? 0
: 1ULL << (static_cast<uint8_t>(backend_k) - 1);
repr_ = functionality_val + backend_val;
} else {
// At this point, we should have covered every case except for alias keys.
// Technically it would be possible to add alias dispatch keys to a
// DispatchKeySet, but the semantics are a little confusing and this
// currently isn't needed anywhere.
repr_ = 0;
}
}
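// Example (illustrative): DispatchKeySet(DispatchKey::CPU) sets exactly two
// bits, the Dense functionality bit and the CPUBit backend bit, so both
//   DispatchKeySet(DispatchKey::CPU).has(DispatchKey::Dense)
//   DispatchKeySet(DispatchKey::CPU).has_backend(BackendComponent::CPUBit)
// hold; the ShowSemantics test exercises exactly this.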
constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
uint64_t repr = 0;
for (auto k : ks) {
repr |= DispatchKeySet(k).repr_;
}
return repr;
}
constexpr uint64_t backend_bits_to_repr(
std::initializer_list<BackendComponent> ks) {
uint64_t repr = 0;
for (auto k : ks) {
repr |= DispatchKeySet(k).repr_;
}
return repr;
}
explicit constexpr DispatchKeySet(DispatchKey t)
: repr_(
t == DispatchKey::Undefined
? 0
: 1ULL << (static_cast<uint8_t>(t) - 1)) {}
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
: repr_(keys_to_repr(ks)) {}
explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
// Note: for some reason, putting this logic directly in the constructor
// appears to fail to compile on CUDA 10.1.
// See an example internal failure at
// https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
: repr_(backend_bits_to_repr(ks)) {}
: repr_(0) {
for (auto k : ks) {
repr_ |= DispatchKeySet(k).repr_;
}
}
// Test if a DispatchKey is in the set
inline bool has(DispatchKey t) const {
bool inline has(DispatchKey t) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
return has_all(DispatchKeySet(t));
}
constexpr bool has_backend(BackendComponent t) const {
return has_all(DispatchKeySet(t));
}
// Test if a DispatchKey is in the set
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if all of them are in the current set.
constexpr bool has_all(DispatchKeySet ks) const {
return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
}
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if any of them are in the current set. This could technically
// be pretty easily implemented using has(). It is strictly a perf
// optimization though. There are many places in the code base where we want
// to test for multiple functionality keys together. HOWEVER, runtime
// per-backend functionality keys aren't allowed to be used with this
// function, because you can end up with weird results. e.g.
// DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
// would return true.
inline bool has_any(DispatchKeySet ks) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
// Either there are no backend bits in the input keyset
((ks.repr_ & full_backend_mask) == 0) ||
// or there are no per-backend-functionality bits
// See [Note: Per-Backend Functionality Dispatch Keys]
((ks &
DispatchKeySet({
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::AutogradFunctionality,
})
.repr_) == 0));
return static_cast<bool>((repr_ & ks.repr_) != 0);
return static_cast<bool>(repr_ & DispatchKeySet(t).repr_);
}
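// Usage sketch (illustrative): the debug assert above permits
//   ks.has_any(DispatchKeySet({DispatchKey::Sparse, DispatchKey::AutogradFunctionality}))
// because the argument carries functionality bits only, but
//   ks.has_any(DispatchKeySet(DispatchKey::AutogradCPU))
// trips the assert in debug builds: that argument carries both a per-backend
// functionality bit and a backend bit, so it could spuriously match unrelated
// runtime keys such as DispatchKey::CPU.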
// Test if DispatchKeySet is a superset of ks.
bool isSupersetOf(DispatchKeySet ks) const {
@ -308,64 +74,31 @@ class DispatchKeySet final {
return DispatchKeySet(repr_ | other.repr_);
}
// Perform set intersection
constexpr DispatchKeySet operator&(DispatchKeySet other) const {
DispatchKeySet operator&(DispatchKeySet other) const {
return DispatchKeySet(repr_ & other.repr_);
}
// Compute the set difference self - other,
// but ONLY for the functionality keys.
// Any backend bits set on self will remain unchanged.
// See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys]
// Compute the set difference self - other
DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
return DispatchKeySet(repr_ & ~other.repr_);
}
// Compute self ^ other
constexpr DispatchKeySet operator^(DispatchKeySet other) const {
return DispatchKeySet(repr_ ^ other.repr_);
}
// Perform set equality
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
}
bool operator!=(DispatchKeySet other) const {
return repr_ != other.repr_;
}
// Add a DispatchKey to the DispatchKey set. Does NOT mutate,
// returns the extended DispatchKeySet!
C10_NODISCARD DispatchKeySet add(DispatchKey t) const {
return *this | DispatchKeySet(t);
}
C10_NODISCARD DispatchKeySet add(DispatchKeySet ks) const {
return *this | ks;
}
// Remove a DispatchKey from the DispatchKey set.
// This is generally not an operation you should be doing
// (it's used to implement the printing overload, operator<<)
//
// Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
// Only functionality bits are allowed to be removed from a keyset.
// For now, we're only allowing removal of "functionality bits" from the
// keyset, which is specifically needed by the fallthrough key calculation
// logic. Why is removing backend bits problematic? Consider this example:
//
// DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
// DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
// DispatchKeySet([DispatchKey.CPU,
// DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
//
// What do we want to happen?
// Technically, we'd like it to be true that after removal,
// the first keyset still has the CUDA dispatch key while the second doesn't.
// Unfortunately there's no way to represent that, because the two keysets are
// represented the same way internally: functionality bits: Autograd, Dense
// backend bits: CPU, CUDA
//
// Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
// bit from the bitset.
constexpr DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(
repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
// Remove a DispatchKey from the DispatchKey set. This is
// generally not an operation you should be doing (it's
// used to implement operator<<)
C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_);
}
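// Example (illustrative, for the pre-revert functionality-only removal
// described in the note above):
//   DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::CPU)
// clears only the Dense functionality bit, so the result equals
// DispatchKeySet(BackendComponent::CPUBit); the backend bit is untouched.
// The SingletonPerBackendFunctionalityKeys test checks this equivalence.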
// Is the set empty? (AKA undefined tensor)
bool empty() const {
@ -374,78 +107,22 @@ class DispatchKeySet final {
uint64_t raw_repr() {
return repr_;
}
DispatchKey highestFunctionalityKey() const {
auto functionality_idx = indexOfHighestBit();
// This means that none of the functionality bits were set.
if (functionality_idx < num_backends)
return DispatchKey::Undefined;
// The first num_backend bits in the keyset don't correspond to real
// dispatch keys.
return static_cast<DispatchKey>(functionality_idx - num_backends);
}
// This is similar to toBackendComponent(DispatchKey), but less restrictive.
// toBackendComponent() errors out if the key that it was passed has no
// backend bits, which is useful for error checking. We need a version of that
// here that can also handle "fake" backends like FPGA, because they need to
// map to the AutogradOther key. For those backends, we return
// BackendComponent::InvalidBit.
BackendComponent highestBackendKey() const {
// mask to mask out functionality bits
auto backend_idx =
DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
// all zeros across the backend bits means that no backend bits are set.
if (backend_idx == 0)
return BackendComponent::InvalidBit;
return static_cast<BackendComponent>(backend_idx);
}
// returns the DispatchKey of highest priority in the set.
// Return the type id in this set with the highest priority (i.e.,
// is the largest in the DispatchKey enum). Intuitively, this
// type id is the one that should handle dispatch (assuming there
// aren't any further exclusions or inclusions).
DispatchKey highestPriorityTypeId() const {
auto functionality_k = highestFunctionalityKey();
if (isPerBackendFunctionalityKey(functionality_k)) {
return toRuntimePerBackendFunctionalityKey(
functionality_k, highestBackendKey());
}
return functionality_k;
// TODO: If I put Undefined as entry 64 and then adjust the
// singleton constructor to shift from the right, we can get rid of the
// subtraction here. It's modestly more complicated to get right so I
// didn't do it for now.
return static_cast<DispatchKey>(64 - llvm::countLeadingZeros(repr_));
}
// Returns the index of the most-significant bit in the keyset.
// This is used as part of the calculation of the operator table index to get:
// - the highest "functionality" bit in the keyset.
// - the highest "backend" bit in the keyset.
uint8_t indexOfHighestBit() const {
return 64 - llvm::countLeadingZeros(repr_);
}
// returns the index in the operator table of the highest priority key in the
// keyset. Note that we could in theory implement this using
// highestPriorityTypeId(), but this code is on the hot path and we can do it
// faster without it.
uint64_t getDispatchTableIndexForDispatchKeySet() const {
auto functionality_idx =
DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
auto offset_and_mask = offsetsAndMasks()[functionality_idx];
// Mask the functionality bits out first, then right-shift by 1.
// right-shifting by 1 because everything is zero-indexed.
// E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
// give us an offset of 1, etc.
auto backend_idx =
DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
return offset_and_mask.offset + backend_idx;
}
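// Worked example (illustrative; it assumes Dense is the first functionality
// after Undefined, so its table offset is 1): for DispatchKey::CUDA the
// functionality bits reduce to Dense (offset 1, mask == full_backend_mask)
// and the backend bits reduce to CUDABit (backend index 1), giving a final
// table index of 1 + 1 = 2. DispatchKey::CPU lands at index 1, and index 0 is
// reserved for Undefined.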
// returns the "index" of the highest priority backend in the keyset.
// This is pretty similar to highestBackendKey(), but:
// - It's hot-path code (part of the runtime bitset calculation)
// - It returns an integer index, not an enum value
// - Everything is shifted to the right by 1.
// BackendComponent::InvalidBit is technically the lowest enum value,
// but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
// etc.
uint64_t getBackendIndex() const {
return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
DispatchKey highestPriorityBackendTypeId() const {
return (*this &
((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1))
.highestPriorityTypeId();
}
private:
@ -453,47 +130,42 @@ class DispatchKeySet final {
uint64_t repr_ = 0;
public:
// STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
// in the set. The iterator is only invalidated by the destruction of the
// underlying DispatchKeySet as the iterator stores a pointer to the raw
// representation of the DispatchKeySet. Note: When we encounter a per-backend
// functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
// in the keyset, for that functionality. For example, if the next
// functionality key to iterate over is Autograd, and the backend bits in the
// keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
// then the next two keys we return will be DispatchKey::AutogradCPU,
// DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
// CUDA in DispatchKey.h).
// STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the
// set. The iterator is only invalidated by the destruction of the underlying
// DispatchKeySet as the iterator stores a pointer to the raw representation
// of the DispatchKeySet.
class iterator {
public:
using self_type = iterator;
using iterator_category = std::input_iterator_tag;
using value_type = DispatchKey;
using difference_type = ptrdiff_t;
// final mask value should mask out the entire keyset
static const uint8_t end_iter_mask_val =
num_backends + num_functionality_keys;
// final key value should be the last DispatchKey
static const uint8_t end_iter_key_val = num_functionality_keys;
// current_dispatchkey_idx_ will iterate through all functionality bits.
// current_backendcomponent_idx_ will iterate through all backend bits.
explicit iterator(
const uint64_t* data_ptr,
uint8_t next_functionality = num_backends,
uint8_t next_backend = 0)
: data_ptr_(data_ptr),
next_functionality_(next_functionality),
next_backend_(next_backend),
// These are in an invalid state at construction time, and set by the
// first increment call
current_dispatchkey_idx_(end_iter_key_val),
current_backendcomponent_idx_(end_iter_key_val) {
explicit iterator(const uint64_t* data_ptr, uint8_t i = 0)
: data_ptr_(data_ptr), i_(i) {
// Go to the first key in the set
++(*this);
}
C10_API self_type& operator++();
self_type& operator++() {
TORCH_INTERNAL_ASSERT(
i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_data = llvm::maskTrailingZeros<uint64_t>(i_) & *data_ptr_;
uint64_t firstKeyIndex = llvm::findFirstSet(masked_data);
// If there are no keys, set to end iterator value
if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
return *this;
}
i_ = static_cast<uint8_t>(firstKeyIndex) + 1;
return *this;
}
self_type operator++(int) {
self_type previous_iterator = *this;
@ -502,50 +174,18 @@ class DispatchKeySet final {
}
bool operator==(const self_type& rhs) const {
return next_functionality_ == rhs.next_functionality_ &&
current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
next_backend_ == rhs.next_backend_ &&
current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
return i_ == rhs.i_;
}
bool operator!=(const self_type& rhs) const {
return next_functionality_ != rhs.next_functionality_ ||
current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
next_backend_ != rhs.next_backend_ ||
current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
return i_ != rhs.i_;
}
DispatchKey operator*() const {
auto functionality_key =
static_cast<DispatchKey>(current_dispatchkey_idx_);
if (isPerBackendFunctionalityKey(functionality_key)) {
auto next_key = toRuntimePerBackendFunctionalityKey(
functionality_key,
static_cast<BackendComponent>(current_backendcomponent_idx_));
// We expect all of the Dense, Sparse, Quantized, and Autograd keys to
// be ordered the same way with respect to their backends
TORCH_INTERNAL_ASSERT(
toBackendComponent(next_key) ==
static_cast<BackendComponent>(current_backendcomponent_idx_),
"Tried to map functionality key ",
toString(functionality_key),
" and backend bit ",
toString(
static_cast<BackendComponent>(current_backendcomponent_idx_)),
" to a runtime key, but ended up with ",
toString(next_key),
". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
" Please double check that enum for inconsistencies.");
return next_key;
} else {
return functionality_key;
}
return static_cast<DispatchKey>(i_);
}
private:
const uint64_t* data_ptr_;
uint8_t next_functionality_;
uint8_t next_backend_;
uint8_t current_dispatchkey_idx_;
uint8_t current_backendcomponent_idx_;
uint8_t i_;
};
public:
@ -555,35 +195,31 @@ class DispatchKeySet final {
return iterator(&repr_);
}
// We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
// this as the end iterator.
// We do not need to iterate beyond NumDispatchKeys so we will treat this as
// the end iterator. NumDispatchKeys will always be strictly less than 64.
iterator end() const {
return iterator(&repr_, iterator::end_iter_mask_val);
return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
}
};
C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
C10_API inline uint64_t getDispatchTableIndexForDispatchKey(DispatchKey k) {
return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}
// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// autograd_dispatch_keyset should include all runtime autograd keys.
// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset.
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
DispatchKey::AutogradFunctionality,
DispatchKey::AutogradCPU,
DispatchKey::AutogradCUDA,
DispatchKey::AutogradXLA,
DispatchKey::AutogradLazy,
DispatchKey::AutogradNestedTensor,
DispatchKey::AutogradMLC,
DispatchKey::AutogradHPU,
DispatchKey::AutogradXPU,
DispatchKey::AutogradPrivateUse1,
DispatchKey::AutogradPrivateUse2,
DispatchKey::AutogradPrivateUse3,
DispatchKey::AutogradOther,
});
@ -608,28 +244,25 @@ constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends =
DispatchKeySet(
// HIP and VE aren't in this list: they now have their own backend bits
// which means that they can now have their own Autograd keys.
// Technically, HIP will now redispatch to its own custom AutogradHIP
// slot in the runtime table.
{DispatchKey::FPGA,
DispatchKey::ORT,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::Meta,
// Sparse and Quantized backends also live here.
DispatchKey::Sparse,
DispatchKey::Quantized})
// Including the backend bits because this keyset is used during op
// registration, which requires looping over all runtime autogradother
// backend keys.
| DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
{DispatchKey::HIP,
DispatchKey::VE,
DispatchKey::FPGA,
DispatchKey::ORT,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::QuantizedCPU,
DispatchKey::QuantizedCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseHIP,
DispatchKey::SparseVE,
DispatchKey::SparseXPU,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta});
// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
@ -659,36 +292,6 @@ constexpr DispatchKeySet after_func_keyset =
// away with it by explicitly removing the key here.
c10::DispatchKey::ADInplaceOrView);
constexpr DispatchKeySet backend_bitset_mask =
DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
constexpr auto inplace_or_view_ks =
DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_mlc_ks = DispatchKeySet(DispatchKey::AutogradMLC);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
struct OpTableOffsetAndMask {
uint16_t offset;
uint16_t backend_mask;
};
static_assert(
num_backends <= 16,
"Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
" that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);
@ -704,53 +307,10 @@ C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// Returns a DispatchKeySet of autograd related keys mapped to backend.
// for a given backend key, use the associated autograd key.
// for non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return inplace_or_view_ks | autograd_cpu_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:
return inplace_or_view_ks | autograd_cuda_ks;
case BackendComponent::XLABit:
return inplace_or_view_ks | autograd_xla_ks;
case BackendComponent::LazyBit:
return inplace_or_view_ks | autograd_lazy_ks;
case BackendComponent::MLCBit:
return inplace_or_view_ks | autograd_mlc_ks;
case BackendComponent::HPUBit:
return inplace_or_view_ks | autograd_hpu_ks;
case BackendComponent::PrivateUse1Bit:
return inplace_or_view_ks | autograd_privateuse1_ks;
case BackendComponent::PrivateUse2Bit:
return inplace_or_view_ks | autograd_privateuse2_ks;
case BackendComponent::PrivateUse3Bit:
return inplace_or_view_ks | autograd_privateuse3_ks;
default:
return inplace_or_view_ks | autograd_other_ks;
}
}
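// Usage sketch (illustrative):
//   getAutogradRelatedKeySetFromBackend(BackendComponent::CUDABit)
//       == DispatchKeySet({DispatchKey::ADInplaceOrView, DispatchKey::AutogradCUDA})
// and any backend without its own autograd key takes the default branch,
// mapping to {ADInplaceOrView, AutogradOther}.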
C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t);
// Returns a DispatchKeySet of autocast related keys mapped to backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
case BackendComponent::CUDABit:
case BackendComponent::XLABit:
return autocast_cuda_ks;
default:
return DispatchKeySet();
}
}
C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t);
// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)

View File

@ -190,7 +190,7 @@ TensorImpl::TensorImpl(
// TODO: be more explicit about the full key set at call sites so we
// don't have to keep recomputing it here
auto k = key_set.highestBackendKey();
DispatchKey k = key_set.highestPriorityBackendTypeId();
key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

View File

@ -838,7 +838,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_sparse() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has(DispatchKey::Sparse);
return key_set_.has(DispatchKey::SparseCPU) ||
key_set_.has(DispatchKey::SparseCUDA) ||
key_set_.has(DispatchKey::SparseHIP) ||
key_set_.has(DispatchKey::SparseXPU);
}
// Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
@ -851,7 +854,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_quantized() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has(DispatchKey::Quantized);
return key_set_.has(DispatchKey::QuantizedCPU) ||
key_set_.has(DispatchKey::QuantizedCUDA) ||
key_set_.has(DispatchKey::QuantizedXPU);
}
bool is_meta() const {
@ -863,46 +868,53 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_cpu() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has_backend(BackendComponent::CPUBit) ||
return key_set_.has(DispatchKey::CPU) ||
key_set_.has(DispatchKey::SparseCPU) ||
key_set_.has(DispatchKey::SparseCsrCPU) ||
key_set_.has(DispatchKey::QuantizedCPU) ||
key_set_.has(DispatchKey::MkldnnCPU);
}
bool is_cuda() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has_backend(BackendComponent::CUDABit) ||
key_set_.has(DispatchKey::SparseCsrCUDA);
return key_set_.has(DispatchKey::CUDA) ||
key_set_.has(DispatchKey::SparseCUDA) ||
key_set_.has(DispatchKey::SparseCsrCUDA) ||
key_set_.has(DispatchKey::QuantizedCUDA);
}
bool is_xpu() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has_backend(BackendComponent::XPUBit);
return key_set_.has(DispatchKey::XPU) ||
key_set_.has(DispatchKey::SparseXPU) ||
key_set_.has(DispatchKey::QuantizedXPU);
}
bool is_xla() const {
return key_set_.has_backend(BackendComponent::XLABit);
return key_set_.has(DispatchKey::XLA);
}
bool is_hpu() const {
return key_set_.has_backend(BackendComponent::HPUBit);
return key_set_.has(DispatchKey::HPU);
}
bool is_lazy() const {
return key_set_.has_backend(BackendComponent::LazyBit);
return key_set_.has(DispatchKey::Lazy);
}
bool is_hip() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has_backend(BackendComponent::HIPBit);
return key_set_.has(DispatchKey::HIP) ||
key_set_.has(DispatchKey::SparseHIP);
}
bool is_ve() const {
// NB: This method is not virtual and avoids dispatches for performance
// reasons.
return key_set_.has_backend(BackendComponent::VEBit);
return key_set_.has(DispatchKey::VE) || key_set_.has(DispatchKey::SparseVE);
}
bool is_mkldnn() const {
@ -1536,22 +1548,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
inline bool has_compatible_shallow_copy_type(DispatchKeySet from) {
auto is_dense = [](DispatchKeySet ts) {
constexpr auto dense_backends = DispatchKeySet(
{BackendComponent::CPUBit,
BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense);
return ts.has_any(dense_k) && ts.has_any(dense_backends);
return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) ||
ts.has(DispatchKey::HIP) || ts.has(DispatchKey::XPU);
};
auto is_sparse = [](DispatchKeySet ts) {
constexpr auto sparse_backends = DispatchKeySet(
{BackendComponent::CPUBit,
BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
return ts.has(DispatchKey::SparseCPU) ||
ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) ||
ts.has(DispatchKey::SparseXPU);
};
return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
(is_sparse(key_set_) && is_sparse(from));

View File

@ -3,163 +3,25 @@
#include <unordered_set>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
using namespace c10;
// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
// the "CPU" dispatch key is an instance of a per-backend-functionality key.
// It corresponds to "dense" functionality, "CPU" backend.
// This means that it gets a dense functionality bit, and a cpu backend bit
// set.
auto undefined_set = DispatchKeySet();
auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));
// You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
// dispatch keys. You are allowed to directly create keysets out of them!
auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
DispatchKeySet(BackendComponent::CPUBit);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);
// Similarly, the AutogradCUDA key gets 2 bits in the keyset:
// The "Autograd" functionality bit, and the "CUDA" backend bit
auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));
// Because DispatchKeySet uses a condensed internal representation, you cannot
// use it to represent the FULL cross product of backends and functionalities.
// For example:
auto autograd_dense_cpu_cuda = DispatchKeySet(
{DispatchKey::AutogradFunctionality,
DispatchKey::Dense,
DispatchKey::CUDA,
DispatchKey::CPU});
auto fpga = DispatchKeySet(DispatchKey::FPGA);
auto fpga_and_cpu = DispatchKeySet({DispatchKey::FPGA, DispatchKey::CPU});
// this keyset has all of the building block keys:
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));
// and it also has the "runtime" keys that correspond to the full
// cross-product of functionality
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCUDA));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));
// This means that there's no way to represent a keyset with, say, only
// Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
// inheriting the full set of functionalities + backends of its keys. This
// means that the below keysets are all indistinguishable from each other.
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet(
{DispatchKey::AutogradCUDA,
DispatchKey::AutogradCPU,
DispatchKey::CUDA,
DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));
// ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~
// Iterators allow you to iterate individually through the DispatchKey's in a
// DispatchKeySet
auto empty_set = DispatchKeySet();
auto t1 = empty_set.begin();
auto t2 = empty_set.end();
ASSERT_EQ(*empty_set.begin(), *empty_set.end());
// However, only keys that correspond to actual runtime indices of kernels in
// the operator table show up when you iterate through a keyset. i.e.
// DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
// iterator.
auto dense_cpu_iter = dense_cpu_set.begin();
ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());
auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());
// But other "functionality bits" that are not defined per-backend DO get
// their own slots in the operator table.
auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
DispatchKeySet(
{DispatchKey::FPGA, // runtime key
DispatchKey::Functionalize, // runtime key
DispatchKey::Dense}); // NOT a runtime key
auto mixed_iter = mixed_keyset.begin();
ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
}
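// A small additional sketch (not part of the original test file; it only
// compiles against the pre-revert definitions above): set subtraction removes
// functionality bits but leaves backend bits alone, so the CUDA backend bit
// of AutogradCUDA survives subtracting the autograd keyset.
TEST(DispatchKeySet, SubtractAutogradKeepsBackendBitsSketch) {
  auto ks = DispatchKeySet(DispatchKey::AutogradCUDA);
  auto remaining = ks - autograd_dispatch_keyset;
  ASSERT_FALSE(remaining.has(DispatchKey::AutogradCUDA));
  ASSERT_TRUE(remaining.has_backend(BackendComponent::CUDABit));
}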
TEST(DispatchKeySet, Empty) {
DispatchKeySet empty_set;
for (uint8_t i = 0;
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
if (tid == DispatchKey::Undefined)
continue;
ASSERT_FALSE(empty_set.has(tid));
}
ASSERT_TRUE(empty_set.empty());
DispatchKeySet empty_set2;
ASSERT_TRUE(empty_set == empty_set2);
ASSERT_EQ(empty_set.highestPriorityTypeId(), DispatchKey::Undefined);
}
// This covers all keys that correspond to a single backend bit, e.g.
// BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
// allow adding them directly to a keyset
TEST(DispatchKeySet, SingletonBackendComponent) {
for (const auto i : c10::irange(1, num_backends)) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
}
}
// This covers all keys that correspond to a single functionality bit:
// - runtime, not-per-backend functionality keys, e.g.
// DispatchKey::FuncTorchBatched
// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
// - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense
// Even though it's not a runtime key, we still allow adding it directly to a
// keyset.
TEST(DispatchKeySet, SingletonFunctionalityKeys) {
for (const auto i : c10::irange(1, num_functionality_keys)) {
TEST(DispatchKeySet, Singleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
@ -168,145 +30,47 @@ TEST(DispatchKeySet, SingletonFunctionalityKeys) {
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
ASSERT_EQ(sing.highestPriorityTypeId(), tid);
ASSERT_EQ(sing.remove(tid), DispatchKeySet());
}
}
// This covers runtime keys that are per-backend,
// and take up more than one bit in a DispatchKeySet. They take up one
// functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA
TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradBackends) {
continue;
}
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
auto functionality_key = toFunctionalityKey(tid);
auto backend_key = toBackendComponent(tid);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU)
// DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
auto expected_ks =
DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
ASSERT_EQ(sing, expected_ks);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense)
// DispatchKeySet(BackendComponent::CPUBit)
expected_ks = DispatchKeySet(toBackendComponent(tid));
ASSERT_EQ(sing.remove(tid), expected_ks);
}
}
TEST(DispatchKeySet, DoubletonPerBackend) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
TEST(DispatchKeySet, Doubleton) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
for (uint8_t j = i + 1;
j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
j++) {
ASSERT_LT(i, j);
auto tid1 = static_cast<DispatchKey>(i);
auto tid2 = static_cast<DispatchKey>(j);
// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfAutogradBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfAutogradBackends)
continue;
auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
auto functionality1 = toFunctionalityKey(tid1);
auto functionality2 = toFunctionalityKey(tid2);
auto combined = DispatchKeySet({tid1, tid2});
// The combined set has the backend bits
ASSERT_TRUE(combined.has_backend(backend1));
ASSERT_TRUE(combined.has_backend(backend2));
// and it has the functionality bits
ASSERT_TRUE(combined.has(functionality1));
ASSERT_TRUE(combined.has(functionality2));
// and it has the original two runtime keys
ASSERT_TRUE(combined.has(tid1));
ASSERT_TRUE(combined.has(tid2));
// Add all of the keys in the keyset to a real set
std::unordered_set<DispatchKey> visited_keys;
auto iter = combined.begin();
while (*iter != *combined.end()) {
visited_keys.insert(*iter);
++iter;
}
std::unordered_set<DispatchKey> expected_keys;
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend2));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend2));
ASSERT_EQ(expected_keys, visited_keys);
if (backend1 == backend2 || functionality1 == functionality2) {
// We have two runtime keys, with either the same backend or the same
// per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
// {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in
// this set.
ASSERT_EQ(2, visited_keys.size());
} else {
// since i and j are different keys, they should not have the same
// functionality and backend
ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
// We have two runtime keys, that have different backends + per-backend
// functionalities. So we should expect the full cross product of
// runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
// then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
ASSERT_EQ(4, visited_keys.size());
}
auto doub = DispatchKeySet(tid1).add(tid2);
ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2));
ASSERT_TRUE(doub.has(tid1));
ASSERT_TRUE(doub.has(tid2));
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j
}
}
}
TEST(DispatchKeySet, Full) {
DispatchKeySet full(DispatchKeySet::FULL);
for (const auto i : c10::irange(1, num_functionality_keys)) {
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
ASSERT_TRUE(full.has(tid));
}
ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys));
}
TEST(DispatchKeySet, IteratorBasicOps) {
DispatchKeySet empty_set;
DispatchKeySet full_set(DispatchKeySet::FULL);
DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU);
DispatchKeySet mutated_set = empty_set.add(static_cast<DispatchKey>(1));
// Constructor + Comparison
ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU);
ASSERT_EQ(*empty_set.begin(), DispatchKey::NumDispatchKeys);
ASSERT_EQ(*empty_set.end(), DispatchKey::NumDispatchKeys);
ASSERT_EQ(*mutated_set.begin(), static_cast<DispatchKey>(1));
ASSERT_TRUE(empty_set.begin() == empty_set.end());
ASSERT_TRUE(full_set.begin() != full_set.end());
@ -326,37 +90,16 @@ TEST(DispatchKeySet, IteratorEmpty) {
ASSERT_EQ(i, 0);
}
TEST(DispatchKeySet, IteratorCrossProduct) {
// The iterator should return all runtime keys in the set,
// including the cross product of {backends} x {functionalities}
auto ks =
DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
DispatchKeySet(
{DispatchKey::Dense,
DispatchKey::FPGA,
DispatchKey::AutogradFunctionality});
auto iter = ks.begin();
// iterate through dense backends first.
ASSERT_EQ(DispatchKey::CPU, *(iter++));
ASSERT_EQ(DispatchKey::CUDA, *(iter++));
// FPGA doesn't have a backend bit, so it isn't included in the cross product.
ASSERT_EQ(DispatchKey::FPGA, *(iter++));
// iterate through the autograd keys last.
ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
}
TEST(DispatchKeySet, IteratorFull) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint8_t i = 0;
for (const auto& it : full_set) {
i++;
ASSERT_TRUE(it == static_cast<DispatchKey>(i));
ASSERT_TRUE(it != DispatchKey::NumDispatchKeys);
}
// Total # of runtime entries includes an entry for DispatchKey::Undefined,
// which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
}
TEST(DispatchKeySet, IteratorRangeFull) {
@ -365,61 +108,41 @@ TEST(DispatchKeySet, IteratorRangeFull) {
for (DispatchKey dispatch_key : full_set) {
i++;
ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
}
// Total # of runtime entries includes an entry for DispatchKey::Undefined,
// which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1);
}
TEST(DispatchKeySet, SpecificKeys) {
DispatchKeySet keyset({
static_cast<DispatchKey>(0), // Undefined should be ignored
static_cast<DispatchKey>(4),
static_cast<DispatchKey>(10),
static_cast<DispatchKey>(15),
});
std::unordered_set<DispatchKey> visited_keys;
for (DispatchKey key : keyset) {
visited_keys.insert(key);
}
ASSERT_EQ(visited_keys.size(), 3);
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
}
TEST(DispatchKeySet, FailAtEndIterator) {
DispatchKeySet full_set(DispatchKeySet::FULL);
uint64_t raw_repr = full_set.raw_repr();
// doesn't throw
DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW(
DispatchKeySet::iterator(
&raw_repr, num_backends + num_functionality_keys + 1),
&raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1),
c10::Error);
}
TEST(DispatchKeySet, TestKeyOrderingInvariants) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto k = static_cast<DispatchKey>(i);
// Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// The DispatchKey enum includes all of the runtime keys for
// Dense/Sparse/Quantized/Autograd (e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA, etc). And we expect the ordering of those keys
// to be the same as the ordering of the backends in the `BackendComponent`
// enum. This makes several utilities in `DispatchKey.h` and
// `DispatchKeySet.h` significantly easier to implement. The purpose of the
// test is to assert (through CI) that this invariant is maintained.
//
// The only way that we can really check this invariant is by
// comparing the string names of each enum.
// We only really care about the ordering for "real" keys that are actually
// used, which we expect to be able to print properly. This saves us from
// having to enumerate the full set of possible runtime keys in
// DispatchKey::toString(). It also relies on toString() being implemented
// correctly.
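// For example, toString(DispatchKey::SparseCUDA) is "SparseCUDA" and its
// computed backend is BackendComponent::CUDABit; stripping the "Bit" suffix
// leaves "CUDA", which appears as a substring of "SparseCUDA".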
auto functionality_str = std::string(toString(k));
if (functionality_str == "UNKNOWN_TENSOR_TYPE_ID")
continue;
auto computed_backend_k = toBackendComponent(k);
auto computed_backend_str = std::string(toString(computed_backend_k));
// Strip the trailing "Bit" suffix, e.g. the "Bit" from "CPUBit"
computed_backend_str =
computed_backend_str.substr(0, computed_backend_str.size() - 3);
ASSERT_TRUE(
functionality_str.find(computed_backend_str) != std::string::npos)
<< "DispatchKey invariant broken! Found a key that is not ordered correctly"
<< " with its backend bit. key = " << toString(k) << ", " << k
<< ", computed backend = " << toString(computed_backend_k);
}
}

View File

@ -532,8 +532,8 @@ AutogradXLA: fn_math [math kernel]
lambda m: m.def_("foo(Tensor x) -> Tensor"),
# m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"),
# m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"),
# m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"),
])
state, table = result.state, result.table
self.assertExpectedInline(state, '''\
@ -541,12 +541,12 @@ name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
FPGA: fn_fpga :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')
# computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
self.assertExpectedInline(extracted_table, '''\
Undefined: fn_math [math kernel]
@ -557,7 +557,7 @@ AutogradOther: ambiguous_autogradother [ambiguous autogradother]
AutogradCPU: fn_math [math kernel]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
FPGA: fn_fpga [kernel]
QuantizedCPU: fn_quantizedcpu [kernel]
''')
def test_computed_table_with_cpu_defaultbackend(self):
@ -616,7 +616,7 @@ CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0
''')
# computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
self.assertExpectedInline(extracted_table, '''\
Undefined: fn_defaultbackend [default backend kernel]
@ -627,7 +627,7 @@ AutogradOther: fn_autograd [autograd kernel]
AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_autograd [autograd kernel]
AutogradXLA: fn_autograd [autograd kernel]
FPGA: fn_defaultbackend [default backend kernel]
QuantizedCPU: fn_defaultbackend [default backend kernel]
''')
def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
@ -808,7 +808,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
FPGA fn_CompositeImplicitAutograd [math kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fallthrough [backend fallback]
AutogradXLA fallthrough [backend fallback]
@ -829,7 +829,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
FPGA fn_CompositeImplicitAutograd [math kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback]
@ -864,7 +864,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel]
FPGA fn_CompositeExplicitAutograd [default backend kernel]
QuantizedCPU fn_CompositeExplicitAutograd [default backend kernel]
AutogradOther fallthrough [backend fallback]
AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback]
@ -889,7 +889,7 @@ CompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd
def test_autogradother(self):
dispatcher = PythonDispatcher()
dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"])
dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"])
self.assertExpectedInline(
dispatcher.dispatchTable(),
'''\
@ -900,7 +900,7 @@ key kernel
CPU fn_CPU [kernel]
XLA fn_CompositeImplicitAutograd [math kernel]
Lazy fn_CompositeImplicitAutograd [math kernel]
FPGA fn_FPGA [kernel]
QuantizedCPU fn_QuantizedCPU [kernel]
AutogradOther ambiguous_autogradother [ambiguous autogradother]
AutogradCPU fallthrough [backend fallback]
AutogradXLA fn_CompositeImplicitAutograd [math kernel]
@ -915,8 +915,8 @@ AutogradLazy fn_CompositeImplicitAutograd [math kernel]
Registered Kernels
key kernel
---------------------------
FPGA fn_FPGA
CPU fn_CPU
QuantizedCPU fn_QuantizedCPU
CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
'''
)

View File

@ -3410,21 +3410,21 @@ class TestSparseOneOff(TestCase):
def test_cuda_from_cpu(self):
with self.assertRaisesRegex(
RuntimeError,
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4),
[3, 4, 4])
with self.assertRaisesRegex(
RuntimeError,
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4, 0),
[3, 4, 4, 0])
with self.assertRaisesRegex(
RuntimeError,
"Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"):
torch.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(),
torch.randn(0, 4, 4, 0),
[0, 4, 4, 0])
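# Standalone sketch (illustrative only, not part of this test file): the
# failures asserted above come from indices and values living on different
# devices; with both on the same device construction succeeds. The shapes are
# hypothetical, and the modern torch.sparse_coo_tensor constructor is used.
import torch

indices = torch.zeros(1, 4, dtype=torch.long, device="cuda")  # sparse_dim=1, nnz=4
values = torch.randn(4, 4, 4, device="cuda")                  # nnz=4, dense dims (4, 4)
t = torch.sparse_coo_tensor(indices, values, (3, 4, 4))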

View File

@ -48,66 +48,58 @@ class DispatchKey(Enum):
Undefined = 0
CatchAll = Undefined
Dense = auto()
CPU = auto()
CUDA = auto()
HIP = auto()
FPGA = auto()
ORT = auto()
XLA = auto()
Lazy = auto()
Vulkan = auto()
Metal = auto()
XPU = auto()
MKLDNN = auto()
OpenGL = auto()
OpenCL = auto()
IDEEP = auto()
Quantized = auto()
QuantizedCPU = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()
CustomRNGKeyId = auto()
MkldnnCPU = auto()
Sparse = auto()
SparseCPU = auto()
SparseCUDA = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()
EndOfBackendKeys = PrivateUse3
ZeroTensor = auto()
Meta = auto()
BackendSelect = auto()
Named = auto()
AutogradOther = auto()
AutogradFunctionality = auto()
AutogradCPU = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradNestedTensor = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()
Tracer = auto()
Autocast = auto()
Batched = auto()
VmapMode = auto()
TESTING_ONLY_GenericWrapper = auto()
TESTING_ONLY_GenericMode = auto()
EndOfFunctionalityKeys = TESTING_ONLY_GenericMode
CPU = auto()
CUDA = auto()
HIP = auto()
XLA = auto()
Lazy = auto()
XPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()
QuantizedCPU = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()
SparseCPU = auto()
SparseCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
AutogradCPU = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()
NumDispatchKeys = auto()
Autograd = auto()
CompositeImplicitAutograd = auto()
CompositeExplicitAutograd = auto()

View File

@ -15,9 +15,9 @@ keys for a single example of each use case. These use cases are listed below:
- CPU/AutogradCPU: represents in-tree backends for which we usually have dedicated inference &
autograd kernels in the pytorch core library.
E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends for which we usually have backend-specific
- QuantizedCPU/AutogradOther: represents in-tree backends for which we usually have backend-specific
inference kernels, but they share the same autograd kernel specified in AutogradOther.
E.g. FPGA, SparseCsrCPU
E.g. QuantizedCPU, QuantizedCUDA
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither inference nor autograd
kernels defined in the pytorch core library. The backend owner is responsible for registering both
inference & autograd kernels in their extensions (e.g. torch-xla) for the operators they support.
@ -53,7 +53,7 @@ class PythonDispatcher:
name = "foo"
runtime_keys = [
"CPU", "AutogradCPU",
"FPGA", "AutogradOther",
"QuantizedCPU", "AutogradOther",
"XLA", "AutogradXLA",
"Lazy", "AutogradLazy",
]
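# Minimal usage sketch (illustrative only, not part of this file): it drives
# the dispatcher through the register()/dispatchTable() calls exercised by the
# Python dispatcher tests earlier in this diff; the import path is an
# assumption.
from torch._python_dispatcher import PythonDispatcher

dispatcher = PythonDispatcher()
dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"])
print(dispatcher.dispatchTable())  # computed per-key kernel table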