mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
add python API to print all operators that have kernels registered to a particular DispatchKey (#63575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63575

Test Plan: Imported from OSS

Reviewed By: ezyang, Chillee

Differential Revision: D30426919

Pulled By: bdhirsh

fbshipit-source-id: b0e487e48dfe02f7b9d678403f0a2b5bfe146f4e
commit bcc6e3ab5e
parent 9324d682fd
committed by Facebook GitHub Bot
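
For orientation before the diffs: the feature surfaces in Python as a single binding. A minimal usage sketch of that binding, assuming a torch build that includes this commit:

    import torch

    # Print every operator with a kernel registered for the CPU dispatch key.
    torch._C._dispatch_print_registrations_for_dispatch_key("CPU")

    # With no argument, print every operator the Dispatcher knows of.
    torch._C._dispatch_print_registrations_for_dispatch_key()

    # An unrecognized name raises RuntimeError ("could not parse dispatch key: ...").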
@@ -323,6 +323,19 @@ std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
   });
 }
 
+std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const {
+  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
+    std::vector<OperatorName> op_names;
+    for (const auto& op : operatorLookupTable) {
+      // If no DispatchKey is specified, print all of the operators.
+      if (!k || op.second.hasKernelForDispatchKey(*k)) {
+        op_names.push_back(op.first);
+      }
+    }
+    return op_names;
+  });
+}
+
 int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey) {
   int64_t seq_num = -1;
   // Setting sequence number in the Autograd case to associate
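A hedged Python analogue of the filtering rule above, on hypothetical data (the real lookup walks a flat_hash_map of OperatorHandles): no key returns every operator name, while a concrete key keeps only operators with a kernel registered for it.

    def registrations_for(table, key=None):
        # table: operator name -> set of dispatch keys that have kernels
        return [name for name, keys in table.items()
                if key is None or key in keys]

    table = {"aten::add": {"CPU", "CUDA"}, "aten::foo": {"CUDA"}}
    assert registrations_for(table) == ["aten::add", "aten::foo"]
    assert registrations_for(table, "CPU") == ["aten::add"]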
@@ -252,6 +252,13 @@ public:
    */
  std::vector<OperatorHandle> findDanglingImpls() const;
 
+  /**
+   * Useful for inspecting global Dispatcher registration state.
+   * Returns the names of all operators with a kernel registered for the specified DispatchKey.
+   * If no DispatchKey is specified, it returns all registered operators.
+   */
+  std::vector<OperatorName> getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const;
+
 private:
   Dispatcher();
 
@@ -318,6 +325,11 @@ public:
     return operatorDef_->op.dumpState();
   }
 
+  bool hasKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.hasKernelForDispatchKey(k);
+  }
+
+
   std::string dumpComputedTable() const {
     return operatorDef_->op.dumpComputedTable();
   }
@@ -187,6 +187,17 @@ public:
 
   std::string listAllDispatchKeys() const;
 
+  // Returns true if kernel_ has entry for any key in ks.
+  //
+  // Invariant: There are no alias keys in the passed-in dispatch key set.
+  // Note [No Alias Keys in DispatchKeySet]
+  // Alias keys should be checked using `hasKernelForDispatchKey`
+  // Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
+  // have a value > 63 (causing overflow).
+  bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
+  // Returns true if kernel_ has entry for a particular key.
+  bool hasKernelForDispatchKey(DispatchKey k) const;
+
 private:
 
   OperatorName name_;
@@ -266,17 +277,6 @@ private:
   void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
   // Like above, but for ALL entries in the dispatch table.
   void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
-
-  // Returns true if kernel_ has entry for any key in ks.
-  //
-  // Invariant: There are no alias keys in the passed-in dispatch key set.
-  // Note [No Alias Keys in DispatchKeySet]
-  // Alias keys should be checked using `hasKernelForDispatchKey`
-  // Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
-  // have a value > 63 (causing overflow).
-  bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
-  // Returns true if kernel_ has entry for a particular key.
-  bool hasKernelForDispatchKey(DispatchKey k) const;
   // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
   const AnnotatedKernel* getKernelForDispatchKey(DispatchKey dispatch_key) const;
 };
@@ -18,6 +18,8 @@
 
 #include <ATen/core/LegacyTypeDispatch.h>
 
+#include <algorithm>
+
 using c10::RegisterOperators;
 using c10::OperatorKernel;
 using c10::OperatorHandle;
@@ -2101,6 +2103,23 @@ TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_mixedCall
   EXPECT_TRUE(called_kernel_cpu);
 }
 
+TEST(OperatorRegistrationTest, getRegistrationsForDispatchKey) {
+  // should return every registered op
+  auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::nullopt);
+  // should return every registered op with a cpu kernel
+  auto cpu_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
+  ASSERT_TRUE(all_ops.size() > 0);
+  ASSERT_TRUE(cpu_ops.size() > 0);
+
+  auto cmp_lambda = [](const c10::OperatorName a, const c10::OperatorName& b) -> bool {
+      return c10::toString(a) < c10::toString(b);
+  };
+
+  std::sort(all_ops.begin(), all_ops.end(), cmp_lambda);
+  std::sort(cpu_ops.begin(), cpu_ops.end(), cmp_lambda);
+  ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
+}
+
 }
 
 #pragma GCC diagnostic pop
@@ -1,5 +1,7 @@
 #include <c10/core/DispatchKey.h>
 
+#include <unordered_map>
+
 namespace c10 {
 
 const char* toString(DispatchKey t) {
@@ -202,4 +204,82 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
   }
 }
 
+c10::DispatchKey parseDispatchKey(const std::string& k) {
+  static std::unordered_map<std::string, c10::DispatchKey> key_map = {
+      {"Undefined", c10::DispatchKey::Undefined},
+      {"CPU", c10::DispatchKey::CPU},
+      {"CUDA", c10::DispatchKey::CUDA},
+      {"HIP", c10::DispatchKey::HIP},
+      {"FPGA", c10::DispatchKey::FPGA},
+      {"ORT", c10::DispatchKey::ORT},
+      {"XLA", c10::DispatchKey::XLA},
+      {"MLC", c10::DispatchKey::MLC},
+      {"Vulkan", c10::DispatchKey::Vulkan},
+      {"Metal", c10::DispatchKey::Metal},
+      {"XPU", c10::DispatchKey::XPU},
+      {"HPU", c10::DispatchKey::HPU},
+      {"VE", c10::DispatchKey::VE},
+      {"Lazy", c10::DispatchKey::Lazy},
+      {"Meta", c10::DispatchKey::Meta},
+      {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
+      {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
+      {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
+      {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
+      {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
+      {"SparseCPU", c10::DispatchKey::SparseCPU},
+      {"SparseCUDA", c10::DispatchKey::SparseCUDA},
+      {"SparseHIP", c10::DispatchKey::SparseHIP},
+      {"SparseXPU", c10::DispatchKey::SparseXPU},
+      {"SparseVE", c10::DispatchKey::SparseVE},
+      {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
+      {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
+      {"NestedTensor", c10::DispatchKey::NestedTensor},
+      {"PrivateUse1", c10::DispatchKey::PrivateUse1},
+      {"PrivateUse2", c10::DispatchKey::PrivateUse2},
+      {"PrivateUse3", c10::DispatchKey::PrivateUse3},
+      {"BackendSelect", c10::DispatchKey::BackendSelect},
+      {"Python", c10::DispatchKey::Python},
+      {"FuncTorchPython", c10::DispatchKey::FuncTorchPython},
+      {"Named", c10::DispatchKey::Named},
+      {"Conjugate", c10::DispatchKey::Conjugate},
+      {"Negative", c10::DispatchKey::Negative},
+      {"FuncTorchDynamicLayerBackMode",
+       c10::DispatchKey::FuncTorchDynamicLayerBackMode},
+      {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
+      {"AutogradOther", c10::DispatchKey::AutogradOther},
+      {"AutogradCPU", c10::DispatchKey::AutogradCPU},
+      {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
+      {"AutogradXLA", c10::DispatchKey::AutogradXLA},
+      {"AutogradLazy", c10::DispatchKey::AutogradLazy},
+      {"AutogradXPU", c10::DispatchKey::AutogradXPU},
+      {"AutogradMLC", c10::DispatchKey::AutogradMLC},
+      {"AutogradHPU", c10::DispatchKey::AutogradHPU},
+      {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
+      {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
+      {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
+      {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
+      {"Tracer", c10::DispatchKey::Tracer},
+      {"AutocastCPU", c10::DispatchKey::AutocastCPU},
+      {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
+      {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
+      {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
+      {"Batched", c10::DispatchKey::Batched},
+      {"VmapMode", c10::DispatchKey::VmapMode},
+      {"FuncTorchGradWrapper", c10::DispatchKey::FuncTorchGradWrapper},
+      {"FuncTorchDynamicLayerFrontMode",
+       c10::DispatchKey::FuncTorchDynamicLayerFrontMode},
+      {"TESTING_ONLY_GenericWrapper",
+       c10::DispatchKey::TESTING_ONLY_GenericWrapper},
+      {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
+      {"Autograd", c10::DispatchKey::Autograd},
+      {"CompositeImplicitAutograd",
+       c10::DispatchKey::CompositeImplicitAutograd},
+      {"CompositeExplicitAutograd",
+       c10::DispatchKey::CompositeExplicitAutograd},
+  };
+  auto it = key_map.find(k);
+  TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k);
+  return it->second;
+}
+
 } // namespace c10
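Note that the map above also accepts alias keys (Autograd, CompositeImplicitAutograd, CompositeExplicitAutograd), so the Python binding should be able to filter on those as well; a small sketch:

    import torch
    torch._C._dispatch_print_registrations_for_dispatch_key("CompositeImplicitAutograd")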
@@ -363,6 +363,10 @@ C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
 
 C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t);
 
+// Parses a string into a dispatch key.
+// If the string cannot be correctly parsed, throws an exception.
+C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
+
 // These are some convenience identifiers for dispatch keys which are
 // shorter to type than their long counterparts. Note that some of these
 // dispatch keys directly correspond to DeviceType; and most APIs that
@@ -786,6 +786,12 @@ CPU: registered at {}:5 :: () -> () [ boxed unboxed ]
 '''.format(extension_path),
             impls[0])
 
+    def test_dispatch_print_registrations_for_dispatch_key_invalid(self):
+        with self.assertRaisesRegex(
+                RuntimeError,
+                "could not parse dispatch key: invalid_key"):
+            C._dispatch_print_registrations_for_dispatch_key('invalid_key')
+
 class TestPythonDispatcher(TestCase):
     def test_basic(self):
         dispatcher = PythonDispatcher()
@@ -26,29 +26,6 @@ torch::Library::Kind parseKind(const std::string& k) {
   TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
   return it->second;
 }
-
-c10::optional<c10::DispatchKey> parseDispatchKey(const std::string& k) {
-  static std::unordered_map<std::string, c10::DispatchKey> key_map = {
-    {"CPU", c10::DispatchKey::CPU},
-    {"CUDA", c10::DispatchKey::CUDA},
-    {"XLA", c10::DispatchKey::XLA},
-    {"Lazy", c10::DispatchKey::Lazy},
-    {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
-    {"CompositeImplicitAutograd", c10::DispatchKey::CompositeImplicitAutograd},
-    {"Autograd", c10::DispatchKey::Autograd},
-    {"CompositeExplicitAutograd", c10::DispatchKey::CompositeExplicitAutograd},
-    {"AutogradCPU", c10::DispatchKey::AutogradCPU},
-    {"", c10::DispatchKey::Undefined},
-  };
-  auto it = key_map.find(k);
-  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
-  if (it->second == c10::DispatchKey::Undefined) {
-    return c10::nullopt;
-  } else {
-    return c10::make_optional(it->second);
-  }
-}
-
 c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
   static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
     {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
@@ -64,7 +41,7 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
 
 template <typename Func>
 inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
-  auto mb_key = parseDispatchKey(key);
+  auto mb_key = std::string(key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(key));
   if (mb_key) {
     return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
   } else {
@@ -154,7 +131,12 @@ void initDispatchBindings(PyObject* module) {
     ;
 
   m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch) {
-    return std::make_unique<torch::Library>(parseKind(kind), std::move(name), parseDispatchKey(dispatch), "/dev/null", 0);
+    return std::make_unique<torch::Library>(
+      parseKind(kind),
+      std::move(name),
+      std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)),
+      "/dev/null",
+      0);
   });
 
   m.def("_dispatch_dump", [](const char* name) -> std::string {
@@ -198,6 +180,18 @@ void initDispatchBindings(PyObject* module) {
 
     return states;
   });
+
+  // Prints out the name of every operator that has a kernel registered to the Dispatcher
+  // under [dispatch_key].
+  // If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of.
+  // This can be useful to answer questions like "list all operators that do not have a CPU kernel".
+  m.def("_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") {
+    auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key));
+    auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
+    for (auto& op : op_names) {
+      std::cout << op << std::endl;
+    }
+  }, py::arg("dispatch_key") = static_cast<const char*>(""));
 }
 
 }}} // namespace torch::impl::dispatch
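The binding above writes through std::cout, which bypasses sys.stdout redirection, so a hedged way to turn the printed dump into data (a hypothetical helper, not part of this commit) is to run the dump in a subprocess and capture that process's stdout:

    import subprocess
    import sys

    def registrations(dispatch_key=""):
        code = ("import torch; "
                f"torch._C._dispatch_print_registrations_for_dispatch_key({dispatch_key!r})")
        out = subprocess.run([sys.executable, "-c", code],
                             capture_output=True, text=True, check=True).stdout
        return set(out.splitlines())

    # "List all operators that do not have a CPU kernel":
    missing_cpu = sorted(registrations() - registrations("CPU"))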
@@ -2810,7 +2810,6 @@ def get_tensors_from(args, kwargs):
     return set([arg for arg in args if isinstance(arg, Tensor)] +
                [v for v in kwargs.values() if isinstance(v, Tensor)])
 
-
 def has_breakpad():
     # We always build with breakpad in CI
     if IS_IN_CI: