Make SparseCsr a functionality dispatch key (#120703)
As in the title: make SparseCsr a per-backend functionality dispatch key, enabling meta and fake tensor support for sparse compressed tensors, in line with the existing meta/fake tensor support for sparse COO tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120703
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: eee040c939
Commit: 70d4d109f2
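To make the mechanics concrete, here is a minimal standalone sketch (not part of the diff) of what "functionality dispatch key" means after this change: each runtime SparseCsr<Backend> key decomposes into the SparseCsr functionality half plus a backend component, and the two halves recombine, exactly as for Dense/Sparse/Quantized. It assumes the c10 headers are available on the include path.

```cpp
// Minimal sketch (not part of the diff), assuming the c10 headers are on the
// include path. It shows the per-backend functionality decomposition that this
// change gives the SparseCsr keys.
#include <c10/core/DispatchKey.h>
#include <iostream>

int main() {
  using namespace c10;

  // The runtime key SparseCsrCUDA now splits into a functionality half ...
  static_assert(
      toFunctionalityKey(DispatchKey::SparseCsrCUDA) == DispatchKey::SparseCsr,
      "functionality half is SparseCsr");
  // ... and a backend half ...
  static_assert(
      toBackendComponent(DispatchKey::SparseCsrCUDA) == BackendComponent::CUDABit,
      "backend half is the CUDA bit");
  // ... and the two halves recombine into the runtime key.
  static_assert(
      toRuntimePerBackendFunctionalityKey(
          DispatchKey::SparseCsr, BackendComponent::CUDABit) ==
          DispatchKey::SparseCsrCUDA,
      "functionality + backend = runtime key");

  std::cout << "SparseCsrCUDA = SparseCsr functionality + CUDA backend bit\n";
}
```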
@@ -43,8 +43,7 @@ constexpr auto kTensorSubclassLike =
     // no matter the backend component
     DispatchKey::Batched,
     DispatchKey::Sparse,
-    DispatchKey::SparseCsrCPU,
-    DispatchKey::SparseCsrCUDA,
+    DispatchKey::SparseCsr,
     DispatchKey::Python}) |
     DispatchKeySet(BackendComponent::MetaBit);
@@ -42,6 +42,10 @@ enum class Backend {
   SparseVE,
   SparseXPU,
   SparsePrivateUse1,
+  SparseCsrHIP,
+  SparseCsrVE,
+  SparseCsrXPU,
+  SparseCsrPrivateUse1,
   ORT,
   XLA,
   Vulkan,
@@ -100,6 +104,12 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::SparseCsrCPU;
   } else if (t == DispatchKey::SparseCsrCUDA) {
     return Backend::SparseCsrCUDA;
+  } else if (t == DispatchKey::SparseCsrHIP) {
+    return Backend::SparseCsrHIP;
+  } else if (t == DispatchKey::SparseCsrVE) {
+    return Backend::SparseCsrVE;
+  } else if (t == DispatchKey::SparseCsrPrivateUse1) {
+    return Backend::SparseCsrPrivateUse1;
   } else if (t == DispatchKey::MkldnnCPU) {
     return Backend::MkldnnCPU;
   } else if (t == DispatchKey::QuantizedCPU) {
@@ -112,6 +122,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::XPU;
   } else if (t == DispatchKey::SparseXPU) {
     return Backend::SparseXPU;
+  } else if (t == DispatchKey::SparseCsrXPU) {
+    return Backend::SparseCsrXPU;
   } else if (t == DispatchKey::QuantizedXPU) {
     return Backend::QuantizedXPU;
   } else if (t == DispatchKey::QuantizedPrivateUse1) {
@@ -154,6 +166,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
      return DispatchKey::XPU;
    case Backend::SparseXPU:
      return DispatchKey::SparseXPU;
+   case Backend::SparseCsrXPU:
+     return DispatchKey::SparseCsrXPU;
    case Backend::SparseCPU:
      return DispatchKey::SparseCPU;
    case Backend::SparseCUDA:
@@ -168,6 +182,12 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
      return DispatchKey::SparseCsrCPU;
    case Backend::SparseCsrCUDA:
      return DispatchKey::SparseCsrCUDA;
+   case Backend::SparseCsrHIP:
+     return DispatchKey::SparseCsrHIP;
+   case Backend::SparseCsrVE:
+     return DispatchKey::SparseCsrVE;
+   case Backend::SparseCsrPrivateUse1:
+     return DispatchKey::SparseCsrPrivateUse1;
    case Backend::MkldnnCPU:
      return DispatchKey::MkldnnCPU;
    case Backend::Vulkan:
@@ -226,10 +246,15 @@ static inline DeviceType backendToDeviceType(Backend b) {
      return DeviceType::HIP;
    case Backend::SparseVE:
      return DeviceType::VE;
+   case Backend::SparseCsrHIP:
+     return DeviceType::HIP;
+   case Backend::SparseCsrVE:
+     return DeviceType::VE;
    case Backend::IPU:
      return DeviceType::IPU;
    case Backend::XPU:
    case Backend::SparseXPU:
+   case Backend::SparseCsrXPU:
    case Backend::QuantizedXPU:
      return DeviceType::XPU;
    case Backend::Vulkan:
@@ -246,6 +271,7 @@ static inline DeviceType backendToDeviceType(Backend b) {
      return DeviceType::MTIA;
    case Backend::PrivateUse1:
    case Backend::SparsePrivateUse1:
+   case Backend::SparseCsrPrivateUse1:
    case Backend::QuantizedPrivateUse1:
      return DeviceType::PrivateUse1;
    case Backend::Undefined:
@@ -296,6 +322,14 @@ static inline const char* toString(Backend b) {
      return "SparseCsrCPU";
    case Backend::SparseCsrCUDA:
      return "SparseCsrCUDA";
+   case Backend::SparseCsrHIP:
+     return "SparseCsrHIP";
+   case Backend::SparseCsrVE:
+     return "SparseCsrVE";
+   case Backend::SparseCsrXPU:
+     return "SparseCsrXPU";
+   case Backend::SparseCsrPrivateUse1:
+     return "SparseCsrPrivateUse1";
    case Backend::MkldnnCPU:
      return "MkldnnCPU";
    case Backend::Vulkan:
@@ -339,8 +373,12 @@ static inline bool isSparse(Backend b) {

 static inline bool isSparseCsr(Backend b) {
   switch (b) {
+    case Backend::SparseCsrXPU:
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
+    case Backend::SparseCsrHIP:
+    case Backend::SparseCsrVE:
+    case Backend::SparseCsrPrivateUse1:
       return true;
     default:
       return false;
@@ -91,10 +91,9 @@ const char* toString(DispatchKey t) {
     case DispatchKey::Sparse:
       return "Sparse";
-    case DispatchKey::SparseCsrCPU:
-      return "SparseCsrCPU";
-    case DispatchKey::SparseCsrCUDA:
-      return "SparseCsrCUDA";
+    case DispatchKey::SparseCsr:
+      return "SparseCsr";

     case DispatchKey::NestedTensor:
       return "NestedTensor";
@@ -274,8 +273,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
       {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
       {"Sparse", c10::DispatchKey::Sparse},
-      {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
-      {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
+      {"SparseCsr", c10::DispatchKey::SparseCsr},
       {"BackendSelect", c10::DispatchKey::BackendSelect},
       {"Python", c10::DispatchKey::Python},
       {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
@@ -346,6 +344,14 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"SparseMeta", c10::DispatchKey::SparseMeta},
       {"SparsePrivateUse1", c10::DispatchKey::SparsePrivateUse1},

+      {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
+      {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
+      {"SparseCsrHIP", c10::DispatchKey::SparseCsrHIP},
+      {"SparseCsrXPU", c10::DispatchKey::SparseCsrXPU},
+      {"SparseCsrVE", c10::DispatchKey::SparseCsrVE},
+      {"SparseCsrMeta", c10::DispatchKey::SparseCsrMeta},
+      {"SparseCsrPrivateUse1", c10::DispatchKey::SparseCsrPrivateUse1},
+
       {"AutogradCPU", c10::DispatchKey::AutogradCPU},
       {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
       {"AutogradXLA", c10::DispatchKey::AutogradXLA},
@@ -57,6 +57,7 @@ namespace c10 {
   _(Dense, ) \
   _(Quantized, Quantized) \
   _(Sparse, Sparse) \
+  _(SparseCsr, SparseCsr) \
   _(NestedTensor, NestedTensor) \
   _(AutogradFunctionality, Autograd)
@@ -217,9 +218,7 @@ enum class DispatchKey : uint16_t {
   // See [Note: Per-Backend Functionality Dispatch Keys]
   Sparse,

-  // TODO: Make SparseCsr a functionality key
-  SparseCsrCPU,
-  SparseCsrCUDA,
+  SparseCsr,

   NestedTensor,
@@ -548,7 +547,8 @@ constexpr bool isAliasDispatchKey(DispatchKey k) {

 constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
   if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
-      k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality ||
+      k == DispatchKey::Sparse || k == DispatchKey::SparseCsr ||
+      k == DispatchKey::AutogradFunctionality ||
       k == DispatchKey::NestedTensor) {
     return true;
   } else {
@@ -635,6 +635,12 @@ constexpr BackendComponent toBackendComponent(DispatchKey k) {
     return static_cast<BackendComponent>(
         static_cast<uint8_t>(k) -
         static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
+  } else if (
+      k >= DispatchKey::StartOfSparseCsrBackends &&
+      k <= DispatchKey::EndOfSparseCsrBackends) {
+    return static_cast<BackendComponent>(
+        static_cast<uint8_t>(k) -
+        static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends));
   } else if (
       k >= DispatchKey::StartOfNestedTensorBackends &&
       k <= DispatchKey::EndOfNestedTensorBackends) {
@@ -662,6 +668,8 @@ constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
     return DispatchKey::Quantized;
   } else if (k <= DispatchKey::EndOfSparseBackends) {
     return DispatchKey::Sparse;
+  } else if (k <= DispatchKey::EndOfSparseCsrBackends) {
+    return DispatchKey::SparseCsr;
   } else if (k <= DispatchKey::EndOfNestedTensorBackends) {
     return DispatchKey::NestedTensor;
   } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
@@ -692,6 +700,11 @@ constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
         static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
         static_cast<uint8_t>(backend_k));
   }
+  if (functionality_k == DispatchKey::SparseCsr) {
+    return static_cast<DispatchKey>(
+        static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends) +
+        static_cast<uint8_t>(backend_k));
+  }
   if (functionality_k == DispatchKey::Quantized) {
     return static_cast<DispatchKey>(
         static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
@@ -79,6 +79,7 @@ C10_ALWAYS_INLINE static const std::
 // we have:
 // - "Dense": CPU, CUDA, XLA, ... (~12 keys)
 // - "Sparse": SparseCPU, SparseCUDA, ...
+// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
 // - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
 // - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
 // The problem is that total number of keys grows quadratically with [#
@@ -92,7 +93,7 @@ C10_ALWAYS_INLINE static const std::
 // (1) "Building block" keys
 //    (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
 //    CUDABit) (b) functionalities: (per-backend) functionality-bit DispatchKeys
-//    (e.g. AutogradFunctionality, Sparse, Dense)
+//    (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
 // (2) "Runtime" keys
 //    (a) "non-customizable backends" (e.g. FPGA)
 //    (b) "non-customizable functionalities" (e.g. Functionalize)
@@ -116,14 +117,16 @@ C10_ALWAYS_INLINE static const std::
 // Backend keys and functionality keys that count as "building blocks" will
 // contribute to a full cross product of functionality that can be overriden.
 //
-// For example, right now we have at least 12 "backend" building blocks (CPU,
-// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
-// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
-// every dispatcher operator to be customized in up to 12*4 different ways. Each
-// of those requires a slot in the operator table of every dispatcher operator.
-// Not every piece of functionality necessarily needs to be customizable
-// per-backend, and not every backend necessarily needs to be able to customize
-// every type of functionality.
+// For example, right now we have at least 12 "backend" building
+// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
+// building blocks (Dense, Sparse, SparseCsr, Quantized,
+// AutogradFunctionality, ...). These keys together allow every
+// dispatcher operator to be customized in up to 12*4 different
+// ways. Each of those requires a slot in the operator table of every
+// dispatcher operator. Not every piece of functionality necessarily
+// needs to be customizable per-backend, and not every backend
+// necessarily needs to be able to customize every type of
+// functionality.
 //
 //
 // (2) Every runtime key corresponds directly to a slot in an operator's runtime
@@ -307,6 +310,7 @@ class DispatchKeySet final {
           DispatchKey::Dense,
           DispatchKey::Quantized,
           DispatchKey::Sparse,
+          DispatchKey::SparseCsr,
           DispatchKey::AutogradFunctionality,
       })
       .repr_) == 0));
@@ -685,8 +689,7 @@ constexpr DispatchKeySet python_ks = DispatchKeySet({

 constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

-constexpr DispatchKeySet sparse_csr_ks =
-    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
+constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);

 constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
@@ -702,12 +705,11 @@ constexpr DispatchKeySet autogradother_backends =
         DispatchKey::ORT,
         DispatchKey::Vulkan,
         DispatchKey::Metal,
-        DispatchKey::SparseCsrCPU,
-        DispatchKey::SparseCsrCUDA,
         DispatchKey::CustomRNGKeyId,
         DispatchKey::MkldnnCPU,
         // Sparse and Quantized backends also live here.
         DispatchKey::Sparse,
+        DispatchKey::SparseCsr,
         DispatchKey::Quantized})
     // Including the backend bits because this keyset is used during op
     // registration, which requires looping over all runtime autogradother
@@ -785,6 +787,7 @@ constexpr DispatchKeySet backend_functionality_keys =
         DispatchKey::Dense,
         DispatchKey::Quantized,
         DispatchKey::Sparse,
+        DispatchKey::SparseCsr,
     }) |
     DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
@@ -41,9 +41,12 @@ inline Layout layout_from_backend(Backend backend) {
       return Layout::Mkldnn;
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
+    case Backend::SparseCsrHIP:
+    case Backend::SparseCsrVE:
+    case Backend::SparseCsrXPU:
       TORCH_CHECK(
           false,
-          "Cannot map Backend SparseCsrCPU|SparseCsrCUDA to a unique layout.");
+          "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
     default:
       return Layout::Strided;
   }
@@ -1066,6 +1066,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return layout() == kSparseCsr;
   }

+  // Whether a tensor is sparse CSR/CSC/BSR/BSC or not.
+  bool is_sparse_compressed() const {
+    return key_set_.has_all(c10::sparse_csr_ks);
+  }
+
   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
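Because sparse_csr_ks is now built from the single SparseCsr functionality key, the new TensorImpl::is_sparse_compressed() check key_set_.has_all(c10::sparse_csr_ks) matches every sparse compressed backend key without enumerating them. Below is a minimal standalone sketch of that key-set behaviour (again not part of the diff, assuming the c10 headers are available on the include path).

```cpp
// Minimal sketch (not part of the diff): the SparseCsr functionality bit is
// present in the key set of any SparseCsr<Backend> runtime key, so a single
// has_all() query identifies all sparse compressed tensors.
#include <c10/core/DispatchKeySet.h>
#include <cassert>
#include <iostream>

int main() {
  using namespace c10;

  constexpr DispatchKeySet sparse_csr_functionality(DispatchKey::SparseCsr);

  // A CUDA sparse compressed tensor carries SparseCsr plus the CUDA backend bit.
  DispatchKeySet csr_cuda(DispatchKey::SparseCsrCUDA);
  assert(csr_cuda.has_all(sparse_csr_functionality));

  // A sparse COO tensor carries the Sparse bit instead, so it does not match.
  DispatchKeySet coo_cpu(DispatchKey::SparseCPU);
  assert(!coo_cpu.has_all(sparse_csr_functionality));

  std::cout << "has_all(SparseCsr) identifies sparse compressed tensors on any backend\n";
}
```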
@@ -1269,7 +1274,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
       return kStrided;
     } else if (is_sparse()) {
       return kSparse;
-    } else if (key_set_.has_any(c10::sparse_csr_ks)) {
+    } else if (is_sparse_compressed()) {
       // Typically, the tensor dispatch keys define the tensor layout
       // uniquely. This allows using non-virtual layout method for
       // better performance. However, when tensor's layout depends,
@@ -2035,8 +2040,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
       constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
       return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
     };
+    auto is_sparse_compressed = [](DispatchKeySet ts) {
+      constexpr auto sparse_compressed_k =
+          DispatchKeySet(DispatchKey::SparseCsr);
+      return ts.has_any(sparse_compressed_k);
+    };
     return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
-        (is_sparse(key_set_) && is_sparse(from));
+        (is_sparse(key_set_) && is_sparse(from)) ||
+        (is_sparse_compressed(key_set_) && is_sparse_compressed(from));
+    ;
   }

 private:
@@ -359,10 +359,18 @@ struct C10_API TensorOptions {
     return layout_ == c10::Layout::Sparse;
   }

+  /// Returns if the layout is sparse CSR, deprecated, use
+  /// is_sparse_compressed() instead
   bool is_sparse_csr() const {
     return layout_ == c10::Layout::SparseCsr;
   }

+  bool is_sparse_compressed() const {
+    return layout_ == c10::Layout::SparseCsr ||
+        layout_ == c10::Layout::SparseCsc ||
+        layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc;
+  }
+
   // For compatibility with legacy tensor.type() comparisons
   bool type_equal(const TensorOptions& other) const {
     return computeDispatchKey() == other.computeDispatchKey() &&
@@ -696,12 +704,15 @@ inline DispatchKey computeDispatchKey(
     case Layout::SparseBsr:
     case Layout::SparseBsc:
       switch (device_.type()) {
-        case c10::DeviceType::CPU:
-          return DispatchKey::SparseCsrCPU;
-        case c10::DeviceType::CUDA:
-          return DispatchKey::SparseCsrCUDA;
+#define DO_CASE(device, _)                 \
+  case c10::DeviceType::device: {          \
+    return DispatchKey::SparseCsr##device; \
+  }
+        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
+#undef DO_CASE
         default:
-          AT_ERROR(
+          TORCH_CHECK_NOT_IMPLEMENTED(
+              false,
               "Unsupported device type for ",
               layout_,
               " layout: ",
@@ -718,13 +729,11 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
       C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
       return Layout::Sparse;
-    case DispatchKey::SparseCsrCPU:
-    case DispatchKey::SparseCsrCUDA:
-      TORCH_CHECK(
-          false,
-          "Cannot map DispatchKey ",
-          dispatch_key,
-          " to a unique layout.");
+#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc:
+      C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
+#undef DO_CASE
+      TORCH_CHECK(
+          false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout.");
     case DispatchKey::MkldnnCPU:
       return Layout::Mkldnn;
     default:
@@ -223,12 +223,14 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
       // Skip these because they aren't real keys.
       if (tid1 == DispatchKey::StartOfDenseBackends ||
           tid1 == DispatchKey::StartOfSparseBackends ||
+          tid1 == DispatchKey::StartOfSparseCsrBackends ||
           tid1 == DispatchKey::StartOfQuantizedBackends ||
           tid1 == DispatchKey::StartOfNestedTensorBackends ||
           tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
         continue;
       if (tid2 == DispatchKey::StartOfDenseBackends ||
           tid2 == DispatchKey::StartOfSparseBackends ||
+          tid2 == DispatchKey::StartOfSparseCsrBackends ||
           tid2 == DispatchKey::StartOfQuantizedBackends ||
           tid2 == DispatchKey::StartOfNestedTensorBackends ||
           tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
@@ -326,6 +328,12 @@ TEST(DispatchKeySet, getHighestPriorityBackendTypeId) {
   ASSERT_EQ(
       DispatchKey::SparseCUDA, c10::highestPriorityBackendTypeId(sparse_cuda));

+  DispatchKeySet sparse_compressed_cuda(
+      {DispatchKey::Functionalize, DispatchKey::SparseCsrCUDA});
+  ASSERT_EQ(
+      DispatchKey::SparseCsrCUDA,
+      c10::highestPriorityBackendTypeId(sparse_compressed_cuda));
+
   // quantizedCUDA has higher priority than CUDA
   DispatchKeySet quantized_cuda(
       {DispatchKey::CUDA, DispatchKey::QuantizedCUDA});
@@ -417,6 +425,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
         k == DispatchKey::StartOfDenseBackends ||
         k == DispatchKey::StartOfQuantizedBackends ||
         k == DispatchKey::StartOfSparseBackends ||
+        k == DispatchKey::StartOfSparseCsrBackends ||
         k == DispatchKey::StartOfNestedTensorBackends ||
         k == DispatchKey::StartOfAutogradFunctionalityBackends)
       continue;
@@ -54,7 +54,14 @@ DEFAULT_KERNEL_NAMESPACE = "at::native"

 # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
 BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
-FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]
+FUNCTIONALITY_KEYS = [
+    "",
+    "Quantized",
+    "Sparse",
+    "SparseCsr",
+    "NestedTensor",
+    "Autograd",
+]

 # This list guards dispatches that can be used in derivatives.yaml
 # For now we omit AutogradFunctionality and AutogradOther
@@ -82,8 +89,7 @@ class DispatchKey(Enum):
     CustomRNGKeyId = auto()
     MkldnnCPU = auto()
     Sparse = auto()
-    SparseCsrCPU = auto()
-    SparseCsrCUDA = auto()
+    SparseCsr = auto()
     NestedTensor = auto()
     Dense = auto()
@@ -165,6 +171,21 @@ class DispatchKey(Enum):
     SparsePrivateUse1 = auto()
     SparsePrivateUse2 = auto()
     SparsePrivateUse3 = auto()
+    SparseCsrCPU = auto()
+    SparseCsrCUDA = auto()
+    SparseCsrHIP = auto()
+    SparseCsrXLA = auto()
+    SparseCsrMTIA = auto()
+    SparseCsrMPS = auto()
+    SparseCsrIPU = auto()
+    SparseCsrXPU = auto()
+    SparseCsrHPU = auto()
+    SparseCsrVE = auto()
+    SparseCsrLazy = auto()
+    SparseCsrMeta = auto()
+    SparseCsrPrivateUse1 = auto()
+    SparseCsrPrivateUse2 = auto()
+    SparseCsrPrivateUse3 = auto()
     NestedTensorCPU = auto()
     NestedTensorCUDA = auto()
     NestedTensorHIP = auto()
@@ -260,6 +281,7 @@ dispatch_keys = [
     # kernels
     DispatchKey.Meta,
     DispatchKey.SparseMeta,
+    DispatchKey.SparseCsrMeta,
     DispatchKey.QuantizedMeta,
     DispatchKey.NestedTensorMeta,
     DispatchKey.ZeroTensor,