mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
There are lots of warnings in builds: ``` 2025-07-05T16:59:46.9208806Z C:\actions-runner\_work\pytorch\pytorch\build\aten\src\ATen\core\TensorBody.h(5043,29): warning: 'at::Tensor::less_' redeclared inline; 'dllimport' attribute ignored [-Wignored-attributes] 2025-07-05T16:59:46.9209030Z 5043 | inline at::Tensor & Tensor::less_(const at::Scalar & other) const { 2025-07-05T16:59:46.9209104Z | ^ 2025-07-05T16:59:46.9209671Z C:\actions-runner\_work\pytorch\pytorch\build\aten\src\ATen\core\TensorBody.h(5048,29): warning: 'at::Tensor::less_' redeclared inline; 'dllimport' attribute ignored [-Wignored-attributes] 2025-07-05T16:59:46.9209860Z 5048 | inline at::Tensor & Tensor::less_(const at::Tensor & other) const ``` This PR has fixed them and turned the warning into an error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157670 Approved by: https://github.com/albanD
969 lines
41 KiB
C++
969 lines
41 KiB
C++
#pragma once
|
||
#include <c10/core/DispatchKey.h>
|
||
#include <c10/macros/Export.h>
|
||
#include <c10/macros/Macros.h>
|
||
#include <c10/util/Exception.h>
|
||
#include <c10/util/Metaprogramming.h>
|
||
#include <c10/util/TypeList.h>
|
||
#include <c10/util/llvmMathExtras.h>
|
||
#include <array>
|
||
#include <cstddef>
|
||
#include <cstdint>
|
||
#include <initializer_list>
|
||
#include <iterator>
|
||
#include <ostream>
|
||
#include <string>
|
||
#include <type_traits>
|
||
|
||
namespace c10 {
|
||
|
||
// One entry of the per-functionality lookup table used to compute an index
// into the runtime operator table: `offset` is where this functionality's
// slots start, and `mask` selects the keyset bits that refine the index
// (see DispatchKeySet::getDispatchTableIndexForDispatchKeySet()).
struct FunctionalityOffsetAndMask {
  // empty constructor shouldn't be used; only needed to initialize
  // the array before populating it.
  constexpr FunctionalityOffsetAndMask() = default;
  constexpr FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset(offset), mask(mask) {}
  // This needs to be big enough to cover the size of the operator table.
  uint16_t offset{};
  // See Note [No More Than 16 Backends]
  // This mask needs to be big enough to mask all of the backend bits.
  // We probably don't ever want to have more than 16 backend bits, so uint16_t
  // should be enough.
  uint16_t mask{};
};
|
||
static_assert(
|
||
c10::num_runtime_entries < 65536,
|
||
"The dispatcher currently only supports up to 2^16 runtime entries");
|
||
|
||
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
|
||
initializeFunctionalityOffsetsAndMasks();
|
||
|
||
// Returns the table produced by initializeFunctionalityOffsetsAndMasks().
// The table is built on the first call and cached in a function-local static.
// Because this function is declared `static` in a header, every translation
// unit that calls it gets its own copy of that cache.
C10_ALWAYS_INLINE static const std::array<
    FunctionalityOffsetAndMask,
    num_functionality_keys>&
offsetsAndMasks() {
  static const std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      cached_table = initializeFunctionalityOffsetsAndMasks();
  return cached_table;
}
|
||
|
||
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
|
||
// "functionality" bits and "backend bits", and every tensor holds its own
|
||
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
|
||
// keyset on every input tensor, OR-ing them together, and dispatching to a
|
||
// specific piece of functionality. The functionality bits are *ordered*. When
|
||
// multiple functionality bits are set, we use the highest priority
|
||
// functionality. Similarly, multiple backend bits can theoretically be set if
|
||
// you call an operator with multiple tensors from different devices (e.g. CPU
|
||
// and CUDA), although support for mixed device dispatch is limited (the only
|
||
// kernels that gracefully handle mixed device inputs for now are cuda kernels
|
||
// that take in a scalar cpu tensor).
|
||
|
||
// A representation of a set of DispatchKeys. A tensor may have multiple
|
||
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
|
||
// DispatchKeySet specifies what type ids apply. The internal representation is
|
||
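// as a 64-bit bit set. Runtime keys are formed by combining functionality and
// backend bits, so there can be more runtime keys than bits (see
// Note [DispatchKeySet Internal Representation] below).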
|
||
//
|
||
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
|
||
// "what is the highest priority DispatchKey in the set"? (The set itself is
|
||
// not ordered; two sets with the same ids will always have the ids ordered in
|
||
// the same way.)
|
||
//
|
||
// Note [DispatchKeySet Internal Representation]
|
||
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
|
||
// that get passed around at runtime.
|
||
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
|
||
// and individual dispatch keys.
|
||
//
|
||
// First: why do we have this distinction, and why not map every dispatch key
|
||
// directly to a bit? This is mostly because we have several types of
|
||
// functionalities that different backends would like to customize. For example,
|
||
// we have:
|
||
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
|
||
// - "Sparse": SparseCPU, SparseCUDA, ...
|
||
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
|
||
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
|
||
// - "Autograd": AutogradCPU, AutogradCUDA, AutogradXLA, ...
|
||
// The problem is that total number of keys grows quadratically with [#
|
||
// backends] x [# functionalities], making it very difficult to map each key
|
||
// directly to a bit in a bitset without dramatically increasing the size of the
|
||
// bitset over time.
|
||
//
|
||
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
|
||
// 5 categories.
|
||
//
|
||
// (1) "Building block" keys
|
||
//   (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//       CUDABit)
//   (b) functionalities: (per-backend) functionality-bit DispatchKeys
//       (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
|
||
// (2) "Runtime" keys
|
||
// (a) "non-customizable backends" (e.g. FPGA)
|
||
// (b) "non-customizable functionalities" (e.g. Functionalize)
|
||
// (c) "per-backend instances of customizable functionalities" (e.g. CPU,
|
||
// SparseCPU, AutogradCPU)
|
||
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
|
||
//
|
||
// (1) Building block keys always correspond to individual bits in a
|
||
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
|
||
// runtime keys. e.g.
|
||
// auto dense_cpu_ks = DispatchKeySet(BackendComponent::CPUBit) |
//                     DispatchKeySet(DispatchKey::Dense);
|
||
// // The keyset has the runtime dense-cpu key.
|
||
// dense_cpu_ks.has(DispatchKey::CPU);
|
||
// // And it contains the building block keys too.
|
||
// dense_cpu_ks.has_backend(BackendComponent::CPUBit);
|
||
// dense_cpu_ks.has(DispatchKey::Dense);
|
||
//
|
||
// Not every backend and not every functionality counts as a "building block
|
||
// key". This is mostly to give us more levers to pull in the design space.
|
||
// Backend keys and functionality keys that count as "building blocks" will
|
||
// contribute to a full cross product of functionality that can be overridden.
|
||
//
|
||
// For example, right now we have at least 12 "backend" building
|
||
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
|
||
// building blocks (Dense, Sparse, SparseCsr, Quantized,
|
||
// AutogradFunctionality, ...). These keys together allow every
|
||
// dispatcher operator to be customized in at least 12*5 different
|
||
// ways. Each of those requires a slot in the operator table of every
|
||
// dispatcher operator. Not every piece of functionality necessarily
|
||
// needs to be customizable per-backend, and not every backend
|
||
// necessarily needs to be able to customize every type of
|
||
// functionality.
|
||
//
|
||
//
|
||
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
|
||
// dispatch table, and you can directly register kernels to a runtime dispatch
|
||
// key.
|
||
//
|
||
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
|
||
// you can think of the corresponding runtime dispatch keys as "instances" of
|
||
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
|
||
// runtime instances of the "Dense" building block key.
|
||
|
||
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
|
||
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
|
||
// customizable per backend.
|
||
// In order to do so, we'd need to promote it to a per-backend functionality
|
||
// "building block" key.
|
||
// - non-customizable backends (e.g. FPGA) can NOT customize existing
|
||
// functionality like Sparse, Autograd, etc.
|
||
// In order to do so, we'd need to promote it to a backend "building block"
|
||
// key.
|
||
//
|
||
// In both cases, these keys directly correspond to runtime slots in the
|
||
// operator table.
|
||
//
|
||
//
|
||
// (3) "Alias" keys
|
||
// See Note [Alias Dispatch Keys]
|
||
//
|
||
// Final note: for anyone making future changes to the Dispatcher +
|
||
// DispatchKeySet internals, there's a closed PR with a basic
|
||
// python-implementation of the Dispatcher that might be useful in quickly
|
||
// testing out and validating changes. See it at
|
||
// https://github.com/pytorch/pytorch/pull/68743
|
||
|
||
// An undefined tensor is one with an empty tensor type set.
|
||
class DispatchKeySet final {
|
||
public:
|
||
enum Full { FULL };
|
||
enum FullAfter { FULL_AFTER };
|
||
enum Raw { RAW };
|
||
|
||
// NB: default constructor representation as zero is MANDATORY as
|
||
// use of DispatchKeySet in TLS requires this.
|
||
constexpr DispatchKeySet() = default;
|
||
|
||
constexpr DispatchKeySet(Full)
|
||
: repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
|
||
|
||
constexpr DispatchKeySet(FullAfter, DispatchKey t)
|
||
// LSB after t are OK, but not t itself.
|
||
// "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
|
||
// Quantized > Dense). But backends don't really have an ordering.
|
||
// Therefore, we're enforcing that FullAfter can only be used on
|
||
// "functionality" keys.
|
||
: repr_(
|
||
(1ULL
|
||
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
|
||
1)) -
|
||
1) {
|
||
*this = add(DispatchKey::PythonDispatcher);
|
||
}
|
||
|
||
// Public version of DispatchKeySet(uint64_t) API; external users
|
||
// must be explicit when they do this!
|
||
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
|
||
|
||
constexpr explicit DispatchKeySet(BackendComponent k) {
|
||
if (k == BackendComponent::InvalidBit) {
|
||
repr_ = 0;
|
||
} else {
|
||
repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
|
||
}
|
||
}
|
||
|
||
constexpr explicit DispatchKeySet(DispatchKey k) {
|
||
// NOLINTNEXTLINE(bugprone-branch-clone)
|
||
if (k == DispatchKey::Undefined) {
|
||
// Case 1: handle Undefined specifically
|
||
repr_ = 0;
|
||
} else if (k <= DispatchKey::EndOfFunctionalityKeys) {
|
||
// Case 2: handle "functionality-only" keys
|
||
// These keys have a functionality bit set, but no backend bits
|
||
// These can technically be either:
|
||
// - valid runtime keys (e.g. DispatchKey::AutogradOther,
|
||
// DispatchKey::FuncTorchBatched, etc)
|
||
// - "building block" keys that aren't actual runtime keys (e.g.
|
||
// DispatchKey::Dense or Sparse)
|
||
uint64_t functionality_val = 1ULL
|
||
<< (num_backends + static_cast<uint8_t>(k) - 1);
|
||
repr_ = functionality_val;
|
||
} else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
|
||
// Case 3: "runtime" keys that have a functionality bit AND a backend bit.
|
||
// First compute which bit to flip for the functionality.
|
||
auto functionality_k = toFunctionalityKey(k);
|
||
// The - 1 is because Undefined is technically a "functionality" that
|
||
// doesn't show up in the bitset. So e.g. Dense is technically the second
|
||
// functionality, but the lowest functionality bit.
|
||
uint64_t functionality_val = 1ULL
|
||
<< (num_backends + static_cast<uint8_t>(functionality_k) - 1);
|
||
|
||
// then compute which bit to flip for the backend
|
||
// Case 4a: handle the runtime instances of "per-backend functionality"
|
||
// keys For example, given DispatchKey::CPU, we should set:
|
||
// - the Dense functionality bit
|
||
// - the CPUBit backend bit
|
||
// first compute which bit to flip for the backend
|
||
auto backend_k = toBackendComponent(k);
|
||
uint64_t backend_val = backend_k == BackendComponent::InvalidBit
|
||
? 0
|
||
: 1ULL << (static_cast<uint8_t>(backend_k) - 1);
|
||
repr_ = functionality_val + backend_val;
|
||
} else {
|
||
// At this point, we should have covered every case except for alias keys.
|
||
// Technically it would be possible to add alias dispatch keys to a
|
||
// DispatchKeySet, but the semantics are a little confusing and this
|
||
// currently isn't needed anywhere.
|
||
repr_ = 0;
|
||
}
|
||
}
|
||
|
||
constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
|
||
uint64_t repr = 0;
|
||
for (auto k : ks) {
|
||
repr |= DispatchKeySet(k).repr_;
|
||
}
|
||
return repr;
|
||
}
|
||
|
||
constexpr uint64_t backend_bits_to_repr(
|
||
std::initializer_list<BackendComponent> ks) {
|
||
uint64_t repr = 0;
|
||
for (auto k : ks) {
|
||
repr |= DispatchKeySet(k).repr_;
|
||
}
|
||
return repr;
|
||
}
|
||
|
||
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
|
||
: repr_(keys_to_repr(ks)) {}
|
||
|
||
explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
|
||
// Note: for some reason, putting this logic directly in the constructor
|
||
// appears to fail to compile on CUDA 10.1.
|
||
// See an example internal failure at
|
||
// https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
|
||
: repr_(backend_bits_to_repr(ks)) {}
|
||
|
||
// Test if a DispatchKey is in the set
|
||
inline bool has(DispatchKey t) const {
|
||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
|
||
return has_all(DispatchKeySet(t));
|
||
}
|
||
constexpr bool has_backend(BackendComponent t) const {
|
||
return has_all(DispatchKeySet(t));
|
||
}
|
||
|
||
// Test if a DispatchKey is in the set
|
||
// Given a DispatchKeySet of functionality keys and (potentially) backend
|
||
// keys, tests if all of them are in the current set.
|
||
constexpr bool has_all(DispatchKeySet ks) const {
|
||
return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
|
||
}
|
||
|
||
// Given a DispatchKeySet of functionality keys and (potentially) backend
|
||
// keys, tests if any of them are in the current set. This could technically
|
||
// be pretty easily implemented using has(). It is strictly a perf
|
||
// optimization though. There are many places in the code base where we want
|
||
// to test for multiple functionality keys together. HOWEVER, runtime
|
||
// per-backend functionality keys aren't allowed to be used with this
|
||
// function, because you can end up with weird results. e.g.
|
||
// DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
|
||
// would return true.
|
||
inline bool has_any(DispatchKeySet ks) const {
|
||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||
// Either there are no backend bits in the input keyset
|
||
((ks.repr_ & full_backend_mask) == 0) ||
|
||
// or there are no per-backend-functionality bits
|
||
// See [Note: Per-Backend Functionality Dispatch Keys]
|
||
((ks &
|
||
DispatchKeySet({
|
||
DispatchKey::Dense,
|
||
DispatchKey::Quantized,
|
||
DispatchKey::Sparse,
|
||
DispatchKey::SparseCsr,
|
||
DispatchKey::AutogradFunctionality,
|
||
})
|
||
.repr_) == 0));
|
||
return static_cast<bool>((repr_ & ks.repr_) != 0);
|
||
}
|
||
// Test if DispatchKeySet is a superset of ks.
|
||
bool isSupersetOf(DispatchKeySet ks) const {
|
||
return (repr_ & ks.repr_) == ks.repr_;
|
||
}
|
||
// Perform set union
|
||
constexpr DispatchKeySet operator|(DispatchKeySet other) const {
|
||
return DispatchKeySet(repr_ | other.repr_);
|
||
}
|
||
// Perform set intersection
|
||
constexpr DispatchKeySet operator&(DispatchKeySet other) const {
|
||
return DispatchKeySet(repr_ & other.repr_);
|
||
}
|
||
// Compute the set difference self - other,
|
||
// but ONLY for the functionality keys.
|
||
// Any backend bits set on self will remain unchanged.
|
||
// See Note [Removing keys from DispatchKeySet Only Affects Functionality
|
||
// Keys]
|
||
constexpr DispatchKeySet operator-(DispatchKeySet other) const {
|
||
return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
|
||
}
|
||
|
||
// Compute self ^ other
|
||
constexpr DispatchKeySet operator^(DispatchKeySet other) const {
|
||
return DispatchKeySet(repr_ ^ other.repr_);
|
||
}
|
||
bool operator==(DispatchKeySet other) const {
|
||
return repr_ == other.repr_;
|
||
}
|
||
bool operator!=(DispatchKeySet other) const {
|
||
return repr_ != other.repr_;
|
||
}
|
||
// Add a DispatchKey to the DispatchKey set. Does NOT mutate,
|
||
// returns the extended DispatchKeySet!
|
||
[[nodiscard]] constexpr DispatchKeySet add(DispatchKey t) const {
|
||
return *this | DispatchKeySet(t);
|
||
}
|
||
[[nodiscard]] constexpr DispatchKeySet add(DispatchKeySet ks) const {
|
||
return *this | ks;
|
||
}
|
||
|
||
// Remove a DispatchKey from the DispatchKey set.
|
||
// This is generally not an operation you should be doing
|
||
// (it's used to implement the printing overload, operator<<)
|
||
//
|
||
// Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
|
||
// Only functionality bits are allowed to be removed from a keyset.
|
||
// For now, we're only allowing removal of "functionality bits" from the
|
||
// keyset, which is specifically needed by the fallthrough key calculation
|
||
// logic. Why is removing backend bits problematic? Consider this example:
|
||
//
|
||
// DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
|
||
// DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
|
||
// DispatchKeySet([DispatchKey.CPU,
|
||
// DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
|
||
//
|
||
// What do we want to happen?
|
||
// Technically, we'd like it to be true that after removal,
|
||
// the first keyset still has the CUDA dispatch key while the second doesn't.
|
||
// Unfortunately there's no way to represent that, because the two keysets are
|
||
// represented the same way internally: functionality bits: Autograd, Dense
|
||
// backend bits: CPU, CUDA
|
||
//
|
||
// Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
|
||
// bit from the bitset.
|
||
[[nodiscard]] constexpr DispatchKeySet remove(DispatchKey t) const {
|
||
return DispatchKeySet(
|
||
repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
|
||
}
|
||
// You're allowed to remove a backend bit from a DispatchKeySet,
|
||
// but you have to be explicit about it (remove_backend() instead of
|
||
// remove()).
|
||
constexpr DispatchKeySet remove_backend(BackendComponent b) const {
|
||
return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
|
||
}
|
||
// Is the set empty? (AKA undefined tensor)
|
||
bool empty() const {
|
||
return repr_ == 0;
|
||
}
|
||
uint64_t raw_repr() const {
|
||
return repr_;
|
||
}
|
||
|
||
static DispatchKeySet from_raw_repr(uint64_t x) {
|
||
return DispatchKeySet(RAW, x);
|
||
}
|
||
|
||
DispatchKey highestFunctionalityKey() const {
|
||
auto functionality_idx = indexOfHighestBit();
|
||
// This means that none of the functionality bits were set.
|
||
if (functionality_idx < num_backends)
|
||
return DispatchKey::Undefined;
|
||
// The first num_backend bits in the keyset don't correspond to real
|
||
// dispatch keys.
|
||
return static_cast<DispatchKey>(functionality_idx - num_backends);
|
||
}
|
||
|
||
// This is similar like toBackendComponent(DispatchKey), but less restrictive.
|
||
// toBackendComponent() errors out if the key that it was passed has no
|
||
// backend bits, which is useful for error checking. We need a version of that
|
||
// here that can also handle "fake" backends like FPGA, because they need to
|
||
// map to the AutogradOther key. For those backends, we return
|
||
// BackendComponent::InvalidBit.
|
||
BackendComponent highestBackendKey() const {
|
||
// mask to mask out functionality bits
|
||
auto backend_idx =
|
||
DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
|
||
// all zeros across the backend bits means that no backend bits are set.
|
||
if (backend_idx == 0)
|
||
return BackendComponent::InvalidBit;
|
||
return static_cast<BackendComponent>(backend_idx);
|
||
}
|
||
|
||
// returns the DispatchKey of highest priority in the set.
|
||
DispatchKey highestPriorityTypeId() const {
|
||
auto functionality_k = highestFunctionalityKey();
|
||
if (isPerBackendFunctionalityKey(functionality_k)) {
|
||
return toRuntimePerBackendFunctionalityKey(
|
||
functionality_k, highestBackendKey());
|
||
}
|
||
return functionality_k;
|
||
}
|
||
|
||
// Returns the index of the most-significant bit in the keyset.
|
||
// This is used to as part of the calculation into the operator table to get:
|
||
// - the highest "functionality" bit in the keyset.
|
||
// - the highest "backend" bit in the keyset.
|
||
uint8_t indexOfHighestBit() const {
|
||
return 64 - llvm::countLeadingZeros(repr_);
|
||
}
|
||
|
||
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
|
||
// [Note: Trimmed Mobile Dispatch Keys]
|
||
/**
|
||
* The method below maps the dispatch key in the enum DispatchKey to an
|
||
* integer index in the dispatchTable_ array in OperatorEntry. The array
|
||
* is trimmed for mobile to reduce peak memory usage since it's
|
||
* unnecessary to reserve additional space for dispatch keys that will
|
||
* never be used on mobile.
|
||
*/
|
||
int getDispatchTableIndexForDispatchKeySet() const {
|
||
auto dk = highestPriorityTypeId();
|
||
switch (dk) {
|
||
case DispatchKey::Undefined:
|
||
return 0;
|
||
case DispatchKey::CPU:
|
||
return 1;
|
||
case DispatchKey::QuantizedCPU:
|
||
return 2;
|
||
case DispatchKey::SparseCPU:
|
||
return 3;
|
||
case DispatchKey::BackendSelect:
|
||
return 4;
|
||
case DispatchKey::ADInplaceOrView:
|
||
return 5;
|
||
case DispatchKey::AutogradOther:
|
||
return 6;
|
||
case DispatchKey::AutogradCPU:
|
||
return 7;
|
||
default:
|
||
return -1;
|
||
}
|
||
}
|
||
#else
|
||
// returns the index in the operator table of highest priority key in the the
|
||
// keyset Note that we could in theory implement this using
|
||
// highestPriorityTypeId(), but this code is very hotpath and we can do it
|
||
// faster without it.
|
||
int getDispatchTableIndexForDispatchKeySet() const {
|
||
auto functionality_idx =
|
||
DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
|
||
auto offset_and_mask = offsetsAndMasks()[functionality_idx];
|
||
// Mask the functionality bits out first, then right-shift by 1.
|
||
// right-shifting by 1 because everything is zero-indexed.
|
||
// E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
|
||
// give us an offset of 1, etc.
|
||
auto backend_idx =
|
||
DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
|
||
return offset_and_mask.offset + backend_idx;
|
||
}
|
||
#endif
|
||
|
||
// returns the "index" of the highest priority backend in the keyset.
|
||
// This is pretty similar to getBackendKey(), but:
|
||
// - It's hotpath code (part of the runtime bitset calculation)
|
||
// - I's returns an integer index, not an enum value
|
||
// - Everything is shifted to the right by 1.
|
||
// BackendComponent::InvalidBit is technically the lowest enum value,
|
||
// but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
|
||
// etc.
|
||
uint64_t getBackendIndex() const {
|
||
return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
|
||
}
|
||
|
||
private:
|
||
constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
|
||
uint64_t repr_ = 0;
|
||
|
||
public:
|
||
// STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
|
||
// in the set. The iterator is only invalidated by the destruction of the
|
||
// underlying DispatchKeySet as the iterator stores a pointer to the raw
|
||
// representation of the DispatchKeySet. Note: When we encounter a per-backend
|
||
// functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
|
||
// in the keyset, for that functionality. For example, if the next
|
||
// functionality key to iterate over is Autograd, and the backend bits in the
|
||
// keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
|
||
// then the next two keys we return will be DispatchKey::AutogradCPU,
|
||
// DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
|
||
// CUDA in DispatchKey.h).
|
||
class iterator {
|
||
public:
|
||
using self_type = iterator;
|
||
using iterator_category = std::input_iterator_tag;
|
||
using value_type = DispatchKey;
|
||
using difference_type = ptrdiff_t;
|
||
using reference = value_type&;
|
||
using pointer = value_type*;
|
||
// final mask value should mask out the entire keyset
|
||
static const uint8_t end_iter_mask_val =
|
||
num_backends + num_functionality_keys;
|
||
// final key value should be the last DispatchKey
|
||
static const uint8_t end_iter_key_val = num_functionality_keys;
|
||
|
||
// current_dispatchkey_idx_ will iterate through all functionality bits.
|
||
// current_backendcomponent_idx_ will iterate through all backend bits.
|
||
explicit iterator(
|
||
const uint64_t* data_ptr,
|
||
uint8_t next_functionality = num_backends,
|
||
uint8_t next_backend = 0)
|
||
: data_ptr_(data_ptr),
|
||
next_functionality_(next_functionality),
|
||
next_backend_(next_backend),
|
||
// These are in an invalid state at construction time, and set by the
|
||
// first increment call
|
||
current_dispatchkey_idx_(end_iter_key_val),
|
||
current_backendcomponent_idx_(end_iter_key_val) {
|
||
// Go to the first key in the set
|
||
TORCH_INTERNAL_ASSERT(
|
||
next_functionality_ >= num_backends,
|
||
"num_backends=",
|
||
static_cast<uint32_t>(num_backends),
|
||
"next_functionality_=",
|
||
static_cast<uint32_t>(next_functionality_));
|
||
++(*this);
|
||
}
|
||
|
||
C10_API self_type& operator++();
|
||
|
||
self_type operator++(int) {
|
||
self_type previous_iterator = *this;
|
||
++(*this);
|
||
return previous_iterator;
|
||
}
|
||
|
||
bool operator==(const self_type& rhs) const {
|
||
return next_functionality_ == rhs.next_functionality_ &&
|
||
current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
|
||
next_backend_ == rhs.next_backend_ &&
|
||
current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
|
||
}
|
||
bool operator!=(const self_type& rhs) const {
|
||
return next_functionality_ != rhs.next_functionality_ ||
|
||
current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
|
||
next_backend_ != rhs.next_backend_ ||
|
||
current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
|
||
}
|
||
DispatchKey operator*() const {
|
||
auto functionality_key =
|
||
static_cast<DispatchKey>(current_dispatchkey_idx_);
|
||
if (isPerBackendFunctionalityKey(functionality_key)) {
|
||
auto next_key = toRuntimePerBackendFunctionalityKey(
|
||
functionality_key,
|
||
static_cast<BackendComponent>(current_backendcomponent_idx_));
|
||
// We expect all of the Dense, Sparse, Quantized, and Autograd keys to
|
||
// be ordered the same way with respect to their backends
|
||
TORCH_INTERNAL_ASSERT(
|
||
toBackendComponent(next_key) ==
|
||
static_cast<BackendComponent>(current_backendcomponent_idx_),
|
||
"Tried to map functionality key ",
|
||
toString(functionality_key),
|
||
" and backend bit ",
|
||
toString(
|
||
static_cast<BackendComponent>(current_backendcomponent_idx_)),
|
||
" to a runtime key, but ended up with ",
|
||
toString(next_key),
|
||
". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
|
||
" Please double check that enum for inconsistencies.");
|
||
return next_key;
|
||
} else {
|
||
return functionality_key;
|
||
}
|
||
}
|
||
|
||
private:
|
||
const uint64_t* data_ptr_;
|
||
uint8_t next_functionality_;
|
||
uint8_t next_backend_;
|
||
uint8_t current_dispatchkey_idx_;
|
||
uint8_t current_backendcomponent_idx_;
|
||
};
|
||
|
||
public:
|
||
// Returns iterator to the first key in the set. If no keys are in the
|
||
// set, then will return the end iterator.
|
||
iterator begin() const {
|
||
return iterator(&repr_);
|
||
}
|
||
|
||
// We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
|
||
// this as the end iterator.
|
||
iterator end() const {
|
||
return iterator(&repr_, iterator::end_iter_mask_val);
|
||
}
|
||
};
|
||
|
||
// String and stream formatting for DispatchKeySet; only declared here.
C10_API std::string toString(DispatchKeySet ks);
C10_API std::ostream& operator<<(std::ostream& os, DispatchKeySet ks);
|
||
|
||
// Operator-table index for a single dispatch key: wraps k in a one-key
// DispatchKeySet and reuses the keyset-based lookup.
inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
  const DispatchKeySet single_key_set(k);
  return single_key_set.getDispatchTableIndexForDispatchKeySet();
}
|
||
|
||
// Alias key DispatchKey::Autograd maps to
|
||
// (autograd_dispatch_keyset x full_backend_mask)
|
||
// NB: keys in this set also get associated with CompositeImplicitAutograd
|
||
//
|
||
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
|
||
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
|
||
// directly in autograd_dispatch_keyset.
|
||
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
|
||
// autograd keys from a DispatchKeySet throughout the code base. However, you
|
||
// are only allowed to remove functionality bits from a keyset, not backend
|
||
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
|
||
// Keys] for details. To be consistent and avoid confusion, we're explicitly
|
||
// setting up autograd_dispatch_keyset to not have any backend bits.
|
||
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
|
||
DispatchKey::AutogradFunctionality,
|
||
DispatchKey::AutogradOther,
|
||
DispatchKey::AutogradNestedTensor,
|
||
});
|
||
|
||
// The autocast functionality keys, one per autocast-capable backend.
constexpr DispatchKeySet autocast_dispatch_keyset{
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
    DispatchKey::AutocastMTIA,
    DispatchKey::AutocastMAIA,
};
|
||
|
||
// See Note [TLS Initialization]
|
||
constexpr DispatchKeySet default_included_set = DispatchKeySet({
|
||
DispatchKey::BackendSelect,
|
||
DispatchKey::ADInplaceOrView,
|
||
});
|
||
|
||
// Counterpart of default_included_set (presumably also governed by
// Note [TLS Initialization]). It contains exactly the autocast keys, so it is
// defined in terms of autocast_dispatch_keyset instead of repeating the list.
// That way a newly added autocast backend only has to be listed in one place.
constexpr DispatchKeySet default_excluded_set = autocast_dispatch_keyset;
|
||
|
||
// The autograd keyset extended with ADInplaceOrView.
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
    autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);

// Python and PythonTLSSnapshot.
constexpr DispatchKeySet python_ks{
    DispatchKey::Python,
    DispatchKey::PythonTLSSnapshot,
};

// Keysets built from one dispatch key each.
constexpr DispatchKeySet sparse_ks(DispatchKey::Sparse);

constexpr DispatchKeySet sparse_csr_ks(DispatchKey::SparseCsr);

constexpr DispatchKeySet mkldnn_ks(DispatchKey::MkldnnCPU);
|
||
|
||
// backend dispatch keys that map to DispatchKey::AutogradOther
|
||
// NB: keys in this set also get associated with CompositeImplicitAutograd
|
||
constexpr DispatchKeySet autogradother_backends =
|
||
DispatchKeySet(
|
||
// HIP and VE aren't in this list: they now have their own backend bits
|
||
// which means that they can now have their own Autograd keys.
|
||
// Technically, HIP will now redispatch to its own custom AutogradHIP
|
||
// slot in the runtime table.
|
||
{DispatchKey::FPGA,
|
||
DispatchKey::Vulkan,
|
||
DispatchKey::Metal,
|
||
DispatchKey::CustomRNGKeyId,
|
||
DispatchKey::MkldnnCPU,
|
||
// Sparse and Quantized backends also live here.
|
||
DispatchKey::Sparse,
|
||
DispatchKey::SparseCsr,
|
||
DispatchKey::Quantized})
|
||
// Including the backend bits because this keyset is used during op
|
||
// registration, which requires looping over all runtime autogradother
|
||
// backend keys.
|
||
| DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
|
||
|
||
// The set of dispatch keys that come after autograd
|
||
// n.b. this relies on the fact that AutogradOther is currently the lowest
|
||
// Autograd key
|
||
constexpr DispatchKeySet after_autograd_keyset =
|
||
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
|
||
|
||
// The set of dispatch keys that come after ADInplaceOrView
|
||
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
|
||
DispatchKeySet::FULL_AFTER,
|
||
c10::DispatchKey::ADInplaceOrView);
|
||
|
||
// The set of dispatch keys that come after Functionalize
|
||
constexpr DispatchKeySet after_func_keyset =
|
||
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
|
||
.remove(
|
||
// NOTE: we also need to remove ADInplaceOrView from the keyset when
|
||
// redispatching after the func kernels. This is because we're not
|
||
// calling the same op; we originally called an inplace op, and now
|
||
// we aren't. The original key calculation figured out which keys
|
||
// were Fallthrough based on the inplace op. That means that it did
|
||
// not include the ADInPlaceOrView kernel as a fallthrough key.
|
||
// However, we WANT the ADInPlaceOrView kernel to be ignored now
|
||
// that we're calling an out-of-place op. Re-invoking
|
||
// Dispatcher::call would re-run the Fallthrough key calculation and
|
||
// get us that, But at::redispatch is more performant. We can get
|
||
// away with it by explicitly removing the key here.
|
||
c10::DispatchKey::ADInplaceOrView);
|
||
|
||
constexpr DispatchKeySet backend_bitset_mask =
|
||
DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
|
||
|
||
constexpr auto inplace_or_view_ks =
|
||
DispatchKeySet(DispatchKey::ADInplaceOrView);
|
||
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
|
||
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
|
||
constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
|
||
constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA);
|
||
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
|
||
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
|
||
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
|
||
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
|
||
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
|
||
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
|
||
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
|
||
constexpr auto autograd_privateuse1_ks =
|
||
DispatchKeySet(DispatchKey::AutogradPrivateUse1);
|
||
constexpr auto autograd_privateuse2_ks =
|
||
DispatchKeySet(DispatchKey::AutogradPrivateUse2);
|
||
constexpr auto autograd_privateuse3_ks =
|
||
DispatchKeySet(DispatchKey::AutogradPrivateUse3);
|
||
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
|
||
constexpr auto autograd_nested =
|
||
DispatchKeySet(DispatchKey::AutogradNestedTensor);
|
||
// keyset corresponding to functorch keys that have their own dedicated
|
||
// TensorImpl subclass.
|
||
constexpr auto functorch_transforms_ks = DispatchKeySet(
|
||
{DispatchKey::FuncTorchBatched,
|
||
DispatchKey::FuncTorchVmapMode,
|
||
DispatchKey::Batched,
|
||
DispatchKey::VmapMode,
|
||
DispatchKey::FuncTorchGradWrapper});
|
||
|
||
constexpr auto functorch_batched_ks =
|
||
DispatchKeySet({DispatchKey::FuncTorchBatched});
|
||
|
||
// This keyset has:
|
||
// (1) the functionality bits corresponding to backends (dense, sparse,
|
||
// quantized) (2) all of the backend bits set
|
||
constexpr DispatchKeySet backend_functionality_keys =
|
||
DispatchKeySet({
|
||
DispatchKey::Dense,
|
||
DispatchKey::Quantized,
|
||
DispatchKey::Sparse,
|
||
DispatchKey::SparseCsr,
|
||
}) |
|
||
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
|
||
|
||
// An offset (presumably into the operator table) paired with a mask over the
// backend bits.
//
// Both members carry default member initializers, matching
// FunctionalityOffsetAndMask. A default-constructed instance is therefore
// zeroed instead of holding indeterminate values. The struct remains an
// aggregate, so `OpTableOffsetAndMask{offset, mask}` still works.
struct OpTableOffsetAndMask {
  uint16_t offset{};
  // Needs one bit per backend; the static_assert on num_backends that follows
  // this struct guards its width.
  uint16_t backend_mask{};
};
|
||
|
||
// OpTableOffsetAndMask::backend_mask is a 16-bit field, so at most 16 backends
// fit; widen that field before raising this limit.
static_assert(
    16 >= num_backends,
    "Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
    " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
|
||
|
||
// true if t is a backend dispatch key
|
||
C10_API bool isBackendDispatchKey(DispatchKey t);
|
||
|
||
// Resolve alias dispatch key to DispatchKeySet if applicable
|
||
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
|
||
|
||
// Resolve alias dispatch key to DispatchKeySet if applicable,
|
||
// and check if k is a part of that set
|
||
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
|
||
|
||
// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
|
||
// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd.
|
||
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
|
||
|
||
// Returns a DispatchKeySet of autograd related keys mapped to backend.
|
||
// for a given backend key, use the associated autograd key.
|
||
// for non-backend keys, use AutogradOther as a default.
|
||
// Note: it's convenient and fast to return a default here rather than (say)
|
||
// returning an std::optional<DispatchKey>, or throwing. But it makes callers
|
||
// responsible for either a) enforcing the invariant that only backend keys
|
||
// be passed as arguments, or b) interpreting our return value carefully.
|
||
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
|
||
switch (t) {
|
||
case BackendComponent::CPUBit:
|
||
return inplace_or_view_ks | autograd_cpu_ks;
|
||
case BackendComponent::IPUBit:
|
||
return inplace_or_view_ks | autograd_ipu_ks;
|
||
case BackendComponent::MTIABit:
|
||
return inplace_or_view_ks | autograd_mtia_ks;
|
||
case BackendComponent::MAIABit:
|
||
return inplace_or_view_ks | autograd_maia_ks;
|
||
case BackendComponent::XPUBit:
|
||
return inplace_or_view_ks | autograd_xpu_ks;
|
||
case BackendComponent::CUDABit:
|
||
return inplace_or_view_ks | autograd_cuda_ks;
|
||
case BackendComponent::XLABit:
|
||
return inplace_or_view_ks | autograd_xla_ks;
|
||
case BackendComponent::LazyBit:
|
||
return inplace_or_view_ks | autograd_lazy_ks;
|
||
case BackendComponent::MetaBit:
|
||
return inplace_or_view_ks | autograd_meta_ks;
|
||
case BackendComponent::MPSBit:
|
||
return inplace_or_view_ks | autograd_mps_ks;
|
||
case BackendComponent::HPUBit:
|
||
return inplace_or_view_ks | autograd_hpu_ks;
|
||
case BackendComponent::PrivateUse1Bit:
|
||
return inplace_or_view_ks | autograd_privateuse1_ks;
|
||
case BackendComponent::PrivateUse2Bit:
|
||
return inplace_or_view_ks | autograd_privateuse2_ks;
|
||
case BackendComponent::PrivateUse3Bit:
|
||
return inplace_or_view_ks | autograd_privateuse3_ks;
|
||
default:
|
||
return inplace_or_view_ks | autograd_other_ks;
|
||
}
|
||
}
|
||
|
||
// Returns a DispatchKeySet of autocast related keys mapped to backend.
|
||
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
|
||
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
|
||
constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
|
||
constexpr auto autocast_maia_ks = DispatchKeySet(DispatchKey::AutocastMAIA);
|
||
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
|
||
constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
|
||
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
|
||
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
|
||
constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
|
||
constexpr auto autocast_privateuse1_ks =
|
||
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
|
||
constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
|
||
switch (t) {
|
||
case BackendComponent::CPUBit:
|
||
return autocast_cpu_ks;
|
||
case BackendComponent::MTIABit:
|
||
return autocast_mtia_ks;
|
||
case BackendComponent::MAIABit:
|
||
return autocast_maia_ks;
|
||
case BackendComponent::XPUBit:
|
||
return autocast_xpu_ks;
|
||
case BackendComponent::IPUBit:
|
||
return autocast_ipu_ks;
|
||
case BackendComponent::HPUBit:
|
||
return autocast_hpu_ks;
|
||
case BackendComponent::CUDABit:
|
||
return autocast_cuda_ks;
|
||
case BackendComponent::XLABit:
|
||
return autocast_xla_ks;
|
||
case BackendComponent::PrivateUse1Bit:
|
||
return autocast_privateuse1_ks;
|
||
case BackendComponent::MPSBit:
|
||
return autocast_mps_ks;
|
||
default:
|
||
return DispatchKeySet();
|
||
}
|
||
}
|
||
|
||
// returns the "backend" DispatchKey of highest priority in the set.
|
||
// This is basically like highestBackendKey(), except that we have some
|
||
// "functionality" bits that correspond to backends (Sparse, Quantized)
|
||
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
|
||
return (ks & backend_functionality_keys).highestPriorityTypeId();
|
||
}
|
||
|
||
// This API exists because we have a use case for checking
|
||
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
|
||
// in OperatorEntry.cpp but we disallow it in has() API.
|
||
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);
|
||
|
||
// Historically, every tensor only had a single DispatchKey, and it was always
|
||
// something like CPU, and there wasn't any of this business where TLS
|
||
// could cause the DispatchKey of a tensor to change. But we still have some
|
||
// legacy code that is still using DispatchKey for things like instanceof
|
||
// checks; if at all possible, refactor the code to stop using DispatchKey in
|
||
// those cases.
|
||
inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
|
||
// NB: If you add any extra keys that can be stored in TensorImpl on
|
||
// top of existing "backend" keys like CPU/CUDA, you need to add it
|
||
// here. At the moment, autograd keys and ADInplaceOrView key need this
|
||
// treatment;
|
||
return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
|
||
autocast_dispatch_keyset -
|
||
DispatchKeySet(
|
||
{DispatchKey::Functionalize,
|
||
DispatchKey::PythonTLSSnapshot,
|
||
DispatchKey::FuncTorchGradWrapper,
|
||
DispatchKey::FuncTorchVmapMode,
|
||
DispatchKey::FuncTorchBatched,
|
||
DispatchKey::Python}))
|
||
.highestPriorityTypeId();
|
||
}
|
||
|
||
template <class T>
|
||
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;
|
||
|
||
// Given a function type, constructs a function_traits type that drops the first
|
||
// parameter type if the first parameter is of type DispatchKeySet. NB:
|
||
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
|
||
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
|
||
// the Dispatcher] for details). If at any point in the future we need to expose
|
||
// this type to JIT, revisit the usage of this type alias.
|
||
template <class FuncType>
|
||
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
|
||
typename guts::infer_function_traits_t<FuncType>::return_type,
|
||
typename std::conditional_t<
|
||
std::is_same_v<
|
||
DispatchKeySet,
|
||
typename guts::typelist::head_with_default_t<
|
||
void,
|
||
typename guts::infer_function_traits_t<
|
||
FuncType>::parameter_types>>,
|
||
guts::typelist::drop_if_nonempty_t<
|
||
typename guts::infer_function_traits_t<FuncType>::parameter_types,
|
||
1>,
|
||
typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
|
||
} // namespace c10
|