Make SparseCsr a functionality dispatch key (#120703)

As in the title.

This enables meta and fake tensor support for sparse compressed tensors, matching the existing meta/fake tensor support for sparse COO tensors.
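
The kind of usage this is meant to unblock looks roughly as follows (illustrative sketch only; this commit only restructures the dispatch keys, and the Python-level behaviour depends on the meta kernels for sparse compressed tensors built on top of it):

    import torch

    # Build a small CSR tensor and move it to the meta device. With SparseCsr as a
    # per-backend functionality key, SparseCsrMeta exists as an ordinary runtime key,
    # so meta/fake tensor support can treat compressed tensors like COO tensors.
    crow_indices = torch.tensor([0, 2, 4])
    col_indices = torch.tensor([0, 1, 0, 1])
    values = torch.tensor([1.0, 2.0, 3.0, 4.0])
    x = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
    m = x.to(device="meta")    # carries sizes/dtype/layout only, no storage
    print(m.layout, m.device)  # expected: torch.sparse_csr meta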

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120703
Approved by: https://github.com/ezyang
Author: Pearu Peterson
Date: 2024-03-01 10:06:21 +02:00
Committed by: PyTorch MergeBot
Parent: eee040c939
Commit: 70d4d109f2
10 changed files with 157 additions and 43 deletions

View File

@@ -43,8 +43,7 @@ constexpr auto kTensorSubclassLike =
// no matter the backend component
DispatchKey::Batched,
DispatchKey::Sparse,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::SparseCsr,
DispatchKey::Python}) |
DispatchKeySet(BackendComponent::MetaBit);

View File

@@ -42,6 +42,10 @@ enum class Backend {
SparseVE,
SparseXPU,
SparsePrivateUse1,
SparseCsrHIP,
SparseCsrVE,
SparseCsrXPU,
SparseCsrPrivateUse1,
ORT,
XLA,
Vulkan,
@@ -100,6 +104,12 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::SparseCsrCPU;
} else if (t == DispatchKey::SparseCsrCUDA) {
return Backend::SparseCsrCUDA;
} else if (t == DispatchKey::SparseCsrHIP) {
return Backend::SparseCsrHIP;
} else if (t == DispatchKey::SparseCsrVE) {
return Backend::SparseCsrVE;
} else if (t == DispatchKey::SparseCsrPrivateUse1) {
return Backend::SparseCsrPrivateUse1;
} else if (t == DispatchKey::MkldnnCPU) {
return Backend::MkldnnCPU;
} else if (t == DispatchKey::QuantizedCPU) {
@@ -112,6 +122,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::XPU;
} else if (t == DispatchKey::SparseXPU) {
return Backend::SparseXPU;
} else if (t == DispatchKey::SparseCsrXPU) {
return Backend::SparseCsrXPU;
} else if (t == DispatchKey::QuantizedXPU) {
return Backend::QuantizedXPU;
} else if (t == DispatchKey::QuantizedPrivateUse1) {
@@ -154,6 +166,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::XPU;
case Backend::SparseXPU:
return DispatchKey::SparseXPU;
case Backend::SparseCsrXPU:
return DispatchKey::SparseCsrXPU;
case Backend::SparseCPU:
return DispatchKey::SparseCPU;
case Backend::SparseCUDA:
@@ -168,6 +182,12 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::SparseCsrCPU;
case Backend::SparseCsrCUDA:
return DispatchKey::SparseCsrCUDA;
case Backend::SparseCsrHIP:
return DispatchKey::SparseCsrHIP;
case Backend::SparseCsrVE:
return DispatchKey::SparseCsrVE;
case Backend::SparseCsrPrivateUse1:
return DispatchKey::SparseCsrPrivateUse1;
case Backend::MkldnnCPU:
return DispatchKey::MkldnnCPU;
case Backend::Vulkan:
@@ -226,10 +246,15 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::HIP;
case Backend::SparseVE:
return DeviceType::VE;
case Backend::SparseCsrHIP:
return DeviceType::HIP;
case Backend::SparseCsrVE:
return DeviceType::VE;
case Backend::IPU:
return DeviceType::IPU;
case Backend::XPU:
case Backend::SparseXPU:
case Backend::SparseCsrXPU:
case Backend::QuantizedXPU:
return DeviceType::XPU;
case Backend::Vulkan:
@@ -246,6 +271,7 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::MTIA;
case Backend::PrivateUse1:
case Backend::SparsePrivateUse1:
case Backend::SparseCsrPrivateUse1:
case Backend::QuantizedPrivateUse1:
return DeviceType::PrivateUse1;
case Backend::Undefined:
@@ -296,6 +322,14 @@ static inline const char* toString(Backend b) {
return "SparseCsrCPU";
case Backend::SparseCsrCUDA:
return "SparseCsrCUDA";
case Backend::SparseCsrHIP:
return "SparseCsrHIP";
case Backend::SparseCsrVE:
return "SparseCsrVE";
case Backend::SparseCsrXPU:
return "SparseCsrXPU";
case Backend::SparseCsrPrivateUse1:
return "SparseCsrPrivateUse1";
case Backend::MkldnnCPU:
return "MkldnnCPU";
case Backend::Vulkan:
@@ -339,8 +373,12 @@ static inline bool isSparse(Backend b) {
static inline bool isSparseCsr(Backend b) {
switch (b) {
case Backend::SparseCsrXPU:
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
case Backend::SparseCsrHIP:
case Backend::SparseCsrVE:
case Backend::SparseCsrPrivateUse1:
return true;
default:
return false;

View File

@@ -91,10 +91,9 @@ const char* toString(DispatchKey t) {
case DispatchKey::Sparse:
return "Sparse";
case DispatchKey::SparseCsrCPU:
return "SparseCsrCPU";
case DispatchKey::SparseCsrCUDA:
return "SparseCsrCUDA";
case DispatchKey::SparseCsr:
return "SparseCsr";
case DispatchKey::NestedTensor:
return "NestedTensor";
@@ -274,8 +273,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
{"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
{"Sparse", c10::DispatchKey::Sparse},
{"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
{"SparseCsr", c10::DispatchKey::SparseCsr},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
{"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
@@ -346,6 +344,14 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"SparseMeta", c10::DispatchKey::SparseMeta},
{"SparsePrivateUse1", c10::DispatchKey::SparsePrivateUse1},
{"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
{"SparseCsrHIP", c10::DispatchKey::SparseCsrHIP},
{"SparseCsrXPU", c10::DispatchKey::SparseCsrXPU},
{"SparseCsrVE", c10::DispatchKey::SparseCsrVE},
{"SparseCsrMeta", c10::DispatchKey::SparseCsrMeta},
{"SparseCsrPrivateUse1", c10::DispatchKey::SparseCsrPrivateUse1},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},

View File

@@ -57,6 +57,7 @@ namespace c10 {
_(Dense, ) \
_(Quantized, Quantized) \
_(Sparse, Sparse) \
_(SparseCsr, SparseCsr) \
_(NestedTensor, NestedTensor) \
_(AutogradFunctionality, Autograd)
@@ -217,9 +218,7 @@ enum class DispatchKey : uint16_t {
// See [Note: Per-Backend Functionality Dispatch Keys]
Sparse,
// TODO: Make SparseCsr a functionality key
SparseCsrCPU,
SparseCsrCUDA,
SparseCsr,
NestedTensor,
@@ -548,7 +547,8 @@ constexpr bool isAliasDispatchKey(DispatchKey k) {
constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality ||
k == DispatchKey::Sparse || k == DispatchKey::SparseCsr ||
k == DispatchKey::AutogradFunctionality ||
k == DispatchKey::NestedTensor) {
return true;
} else {
@@ -635,6 +635,12 @@ constexpr BackendComponent toBackendComponent(DispatchKey k) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
} else if (
k >= DispatchKey::StartOfSparseCsrBackends &&
k <= DispatchKey::EndOfSparseCsrBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends));
} else if (
k >= DispatchKey::StartOfNestedTensorBackends &&
k <= DispatchKey::EndOfNestedTensorBackends) {
@@ -662,6 +668,8 @@ constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
return DispatchKey::Quantized;
} else if (k <= DispatchKey::EndOfSparseBackends) {
return DispatchKey::Sparse;
} else if (k <= DispatchKey::EndOfSparseCsrBackends) {
return DispatchKey::SparseCsr;
} else if (k <= DispatchKey::EndOfNestedTensorBackends) {
return DispatchKey::NestedTensor;
} else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
@@ -692,6 +700,11 @@ constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::SparseCsr) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Quantized) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
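
The StartOfSparseCsrBackends/EndOfSparseCsrBackends range added above follows the same pattern as the existing Sparse and Quantized ranges: the runtime keys of one functionality occupy a contiguous block, so the backend component is recovered by subtracting the block start and the runtime key is rebuilt by adding it back. A toy Python model of that arithmetic (made-up numbering, not the real c10 enum values):

    # Abbreviated backend list, in BackendComponent order (illustrative only).
    BACKEND_COMPONENTS = ["CPU", "CUDA", "HIP", "XLA", "XPU"]
    START_OF_SPARSE_CSR_BACKENDS = 100  # stand-in for DispatchKey::StartOfSparseCsrBackends

    def to_runtime_per_backend_key(backend_index: int) -> int:
        # SparseCsr functionality + backend component -> SparseCsr<Backend> runtime key
        return START_OF_SPARSE_CSR_BACKENDS + backend_index

    def to_backend_component(runtime_key: int) -> str:
        # Inverse mapping, mirroring toBackendComponent() above
        return BACKEND_COMPONENTS[runtime_key - START_OF_SPARSE_CSR_BACKENDS]

    assert to_backend_component(to_runtime_per_backend_key(1)) == "CUDA"  # i.e. SparseCsrCUDA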

View File

@@ -79,6 +79,7 @@ C10_ALWAYS_INLINE static const std::
// we have:
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse": SparseCPU, SparseCUDA, ...
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
// The problem is that total number of keys grows quadratically with [#
@@ -92,7 +93,7 @@ C10_ALWAYS_INLINE static const std::
// (1) "Building block" keys
// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
// CUDABit) (b) functionalities: (per-backend) functionality-bit DispatchKeys
// (e.g. AutogradFunctionality, Sparse, Dense)
// (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
// (2) "Runtime" keys
// (a) "non-customizable backends" (e.g. FPGA)
// (b) "non-customizable functionalities" (e.g. Functionalize)
@@ -116,14 +117,16 @@ C10_ALWAYS_INLINE static const std::
// Backend keys and functionality keys that count as "building blocks" will
contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building blocks (CPU,
// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
// every dispatcher operator to be customized in up to 12*4 different ways. Each
// of those requires a slot in the operator table of every dispatcher operator.
// Not every piece of functionality necessarily needs to be customizable
// per-backend, and not every backend necessarily needs to be able to customize
// every type of functionality.
// For example, right now we have at least 12 "backend" building
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
// building blocks (Dense, Sparse, SparseCsr, Quantized,
// AutogradFunctionality, ...). These keys together allow every
// dispatcher operator to be customized in up to 12*5 different
// ways. Each of those requires a slot in the operator table of every
// dispatcher operator. Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend
// necessarily needs to be able to customize every type of
// functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
@@ -307,6 +310,7 @@ class DispatchKeySet final {
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::SparseCsr,
DispatchKey::AutogradFunctionality,
})
.repr_) == 0));
@@ -685,8 +689,7 @@ constexpr DispatchKeySet python_ks = DispatchKeySet({
constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
constexpr DispatchKeySet sparse_csr_ks =
DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);
constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
@@ -702,12 +705,11 @@ constexpr DispatchKeySet autogradother_backends =
DispatchKey::ORT,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
// Sparse and Quantized backends also live here.
DispatchKey::Sparse,
DispatchKey::SparseCsr,
DispatchKey::Quantized})
// Including the backend bits because this keyset is used during op
// registration, which requires looping over all runtime autogradother
@@ -785,6 +787,7 @@ constexpr DispatchKeySet backend_functionality_keys =
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::SparseCsr,
}) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
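
The comment in this file frames per-backend functionality keys as a cross product of functionality and backend building blocks. A rough sketch of that naming scheme (an illustration, not the C++ bitset implementation; the lists mirror the torchgen ones updated later in this commit):

    BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
    FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "SparseCsr", "NestedTensor", "Autograd"]

    # Each per-backend functionality key expands into one runtime key per backend
    # component; promoting SparseCsr to a functionality key is what yields
    # SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, ... without listing them by hand.
    runtime_keys = [f + b for f in FUNCTIONALITY_KEYS for b in BACKEND_COMPONENTS]
    assert "SparseCsrCUDA" in runtime_keys and "SparseCsrMeta" in runtime_keys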

View File

@@ -41,9 +41,12 @@ inline Layout layout_from_backend(Backend backend) {
return Layout::Mkldnn;
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
case Backend::SparseCsrHIP:
case Backend::SparseCsrVE:
case Backend::SparseCsrXPU:
TORCH_CHECK(
false,
"Cannot map Backend SparseCsrCPU|SparseCsrCUDA to a unique layout.");
"Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
default:
return Layout::Strided;
}

View File

@@ -1066,6 +1066,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return layout() == kSparseCsr;
}
// Whether a tensor is sparse CSR/CSC/BSR/BSC or not.
bool is_sparse_compressed() const {
return key_set_.has_all(c10::sparse_csr_ks);
}
bool is_quantized() const {
// NB: This method is not virtual and avoid dispatches for performance
// reasons.
@@ -1269,7 +1274,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return kStrided;
} else if (is_sparse()) {
return kSparse;
} else if (key_set_.has_any(c10::sparse_csr_ks)) {
} else if (is_sparse_compressed()) {
// Typically, the tensor dispatch keys define the tensor layout
// uniquely. This allows using non-virtual layout method for
// better performance. However, when tensor's layout depends,
@@ -2035,8 +2040,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
};
auto is_sparse_compressed = [](DispatchKeySet ts) {
constexpr auto sparse_compressed_k =
DispatchKeySet(DispatchKey::SparseCsr);
return ts.has_any(sparse_compressed_k);
};
return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
(is_sparse(key_set_) && is_sparse(from));
(is_sparse(key_set_) && is_sparse(from)) ||
(is_sparse_compressed(key_set_) && is_sparse_compressed(from));
;
}
private:

View File

@@ -359,10 +359,18 @@ struct C10_API TensorOptions {
return layout_ == c10::Layout::Sparse;
}
/// Returns if the layout is sparse CSR, deprecated, use
/// is_sparse_compressed() instead
bool is_sparse_csr() const {
return layout_ == c10::Layout::SparseCsr;
}
bool is_sparse_compressed() const {
return layout_ == c10::Layout::SparseCsr ||
layout_ == c10::Layout::SparseCsc ||
layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc;
}
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return computeDispatchKey() == other.computeDispatchKey() &&
@@ -696,12 +704,15 @@ inline DispatchKey computeDispatchKey(
case Layout::SparseBsr:
case Layout::SparseBsc:
switch (device_.type()) {
case c10::DeviceType::CPU:
return DispatchKey::SparseCsrCPU;
case c10::DeviceType::CUDA:
return DispatchKey::SparseCsrCUDA;
#define DO_CASE(device, _) \
case c10::DeviceType::device: { \
return DispatchKey::SparseCsr##device; \
}
C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
default:
AT_ERROR(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for ",
layout_,
" layout: ",
@@ -718,13 +729,11 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
return Layout::Sparse;
case DispatchKey::SparseCsrCPU:
case DispatchKey::SparseCsrCUDA:
TORCH_CHECK(
false,
"Cannot map DispatchKey ",
dispatch_key,
" to a unique layout.");
#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc:
C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
TORCH_CHECK(
false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout.");
case DispatchKey::MkldnnCPU:
return Layout::Mkldnn;
default:
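
The is_sparse_compressed() helpers added to TensorImpl and TensorOptions above cover all four compressed layouts (CSR, CSC, BSR, BSC), which now all dispatch through the single SparseCsr functionality key. A rough Python-level counterpart, for illustration only (not a PyTorch API):

    import torch

    def is_sparse_compressed(t: torch.Tensor) -> bool:
        # The four compressed layouts share the SparseCsr dispatch key.
        return t.layout in (
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        )

    x = torch.eye(3).to_sparse_csr()
    assert is_sparse_compressed(x) and not is_sparse_compressed(torch.eye(3))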

View File

@@ -223,12 +223,14 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfSparseCsrBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfNestedTensorBackends ||
tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfSparseCsrBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfNestedTensorBackends ||
tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
@@ -326,6 +328,12 @@ TEST(DispatchKeySet, getHighestPriorityBackendTypeId) {
ASSERT_EQ(
DispatchKey::SparseCUDA, c10::highestPriorityBackendTypeId(sparse_cuda));
DispatchKeySet sparse_compressed_cuda(
{DispatchKey::Functionalize, DispatchKey::SparseCsrCUDA});
ASSERT_EQ(
DispatchKey::SparseCsrCUDA,
c10::highestPriorityBackendTypeId(sparse_compressed_cuda));
// quantizedCUDA has higher priority than CUDA
DispatchKeySet quantized_cuda(
{DispatchKey::CUDA, DispatchKey::QuantizedCUDA});
@@ -417,6 +425,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends)
continue;

View File

@@ -54,7 +54,14 @@ DEFAULT_KERNEL_NAMESPACE = "at::native"
# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]
FUNCTIONALITY_KEYS = [
"",
"Quantized",
"Sparse",
"SparseCsr",
"NestedTensor",
"Autograd",
]
# This list guards dispatches that can be used in derivatives.yaml
# For now we omit AutogradFunctionality and AutogradOther
@@ -82,8 +89,7 @@ class DispatchKey(Enum):
CustomRNGKeyId = auto()
MkldnnCPU = auto()
Sparse = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
SparseCsr = auto()
NestedTensor = auto()
Dense = auto()
@@ -165,6 +171,21 @@ class DispatchKey(Enum):
SparsePrivateUse1 = auto()
SparsePrivateUse2 = auto()
SparsePrivateUse3 = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
SparseCsrHIP = auto()
SparseCsrXLA = auto()
SparseCsrMTIA = auto()
SparseCsrMPS = auto()
SparseCsrIPU = auto()
SparseCsrXPU = auto()
SparseCsrHPU = auto()
SparseCsrVE = auto()
SparseCsrLazy = auto()
SparseCsrMeta = auto()
SparseCsrPrivateUse1 = auto()
SparseCsrPrivateUse2 = auto()
SparseCsrPrivateUse3 = auto()
NestedTensorCPU = auto()
NestedTensorCUDA = auto()
NestedTensorHIP = auto()
@@ -260,6 +281,7 @@ dispatch_keys = [
# kernels
DispatchKey.Meta,
DispatchKey.SparseMeta,
DispatchKey.SparseCsrMeta,
DispatchKey.QuantizedMeta,
DispatchKey.NestedTensorMeta,
DispatchKey.ZeroTensor,