add python API to print all operators that have kernels registered to a particular DispatchKey (#63575)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63575

Test Plan: Imported from OSS

Reviewed By: ezyang, Chillee

Differential Revision: D30426919

Pulled By: bdhirsh

fbshipit-source-id: b0e487e48dfe02f7b9d678403f0a2b5bfe146f4e
Author: Brian Hirsh
Date: 2021-09-22 09:14:01 -07:00
Committed by: Facebook GitHub Bot
Parent commit: 9324d682fd
Commit: bcc6e3ab5e

9 changed files with 164 additions and 37 deletions
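
For orientation, the Python entry point added by this commit can be exercised roughly as follows (a minimal sketch; it assumes the binding lands on torch._C, as the test changes below suggest):

    import torch

    # Print the name of every operator the Dispatcher knows about.
    torch._C._dispatch_print_registrations_for_dispatch_key()

    # Print only the operators that have a kernel registered for the CPU key.
    torch._C._dispatch_print_registrations_for_dispatch_key("CPU")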

View File

@@ -323,6 +323,19 @@ std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
});
}
std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const {
return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
std::vector<OperatorName> op_names;
for (const auto& op : operatorLookupTable) {
// If no DispatchKey is specified, include every registered operator.
if (!k || op.second.hasKernelForDispatchKey(*k)) {
op_names.push_back(op.first);
}
}
return op_names;
});
}
int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate

View File

@@ -252,6 +252,13 @@ public:
*/
std::vector<OperatorHandle> findDanglingImpls() const;
/**
* Useful for inspecting global Dispatcher registration state.
* Returns the names of all operators with a kernel registered for the specified DispatchKey.
* If no DispatchKey is specified, it returns all registered operators.
*/
std::vector<OperatorName> getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const;
private:
Dispatcher();
@@ -318,6 +325,11 @@ public:
return operatorDef_->op.dumpState();
}
bool hasKernelForDispatchKey(DispatchKey k) const {
return operatorDef_->op.hasKernelForDispatchKey(k);
}
std::string dumpComputedTable() const {
return operatorDef_->op.dumpComputedTable();
}

View File

@@ -187,6 +187,17 @@ public:
std::string listAllDispatchKeys() const;
// Returns true if kernel_ has entry for any key in ks.
//
// Invariant: There are no alias keys in the passed-in dispatch key set.
// Note [No Alias Keys in DispatchKeySet]
// Alias keys should be checked using `hasKernelForDispatchKey`
// Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
// have a value > 63 (causing overflow).
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
// Returns true if kernel_ has entry for a particular key.
bool hasKernelForDispatchKey(DispatchKey k) const;
private:
OperatorName name_;
@@ -266,17 +277,6 @@
void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
// Like above, but for ALL entries in the dispatch table.
void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
// Returns true if kernel_ has entry for any key in ks.
//
// Invariant: There are no alias keys in the passed-in dispatch key set.
// Note [No Alias Keys in DispatchKeySet]
// Alias keys should be checked using `hasKernelForDispatchKey`
// Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
// have a value > 63 (causing overflow).
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
// Returns true if kernel_ has entry for a particular key.
bool hasKernelForDispatchKey(DispatchKey k) const;
// Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
const AnnotatedKernel* getKernelForDispatchKey(DispatchKey dispatch_key) const;
};

View File

@@ -18,6 +18,8 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <algorithm>
using c10::RegisterOperators;
using c10::OperatorKernel;
using c10::OperatorHandle;
@@ -2101,6 +2103,23 @@ TEST(OperatorRegistrationTest, callKernelsWithDispatchKeySetConvention_mixedCall
EXPECT_TRUE(called_kernel_cpu);
}
TEST(OperatorRegistrationTest, getRegistrationsForDispatchKey) {
// should return every registered op
auto all_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::nullopt);
// should return every registered op with a cpu kernel
auto cpu_ops = Dispatcher::singleton().getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
ASSERT_TRUE(all_ops.size() > 0);
ASSERT_TRUE(cpu_ops.size() > 0);
auto cmp_lambda = [](const c10::OperatorName& a, const c10::OperatorName& b) -> bool {
return c10::toString(a) < c10::toString(b);
};
std::sort(all_ops.begin(), all_ops.end(), cmp_lambda);
std::sort(cpu_ops.begin(), cpu_ops.end(), cmp_lambda);
ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
}
}
#pragma GCC diagnostic pop

View File

@@ -1,5 +1,7 @@
#include <c10/core/DispatchKey.h>
#include <unordered_map>
namespace c10 {
const char* toString(DispatchKey t) {
@@ -202,4 +204,82 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
}
}
c10::DispatchKey parseDispatchKey(const std::string& k) {
static std::unordered_map<std::string, c10::DispatchKey> key_map = {
{"Undefined", c10::DispatchKey::Undefined},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},
{"HIP", c10::DispatchKey::HIP},
{"FPGA", c10::DispatchKey::FPGA},
{"ORT", c10::DispatchKey::ORT},
{"XLA", c10::DispatchKey::XLA},
{"MLC", c10::DispatchKey::MLC},
{"Vulkan", c10::DispatchKey::Vulkan},
{"Metal", c10::DispatchKey::Metal},
{"XPU", c10::DispatchKey::XPU},
{"HPU", c10::DispatchKey::HPU},
{"VE", c10::DispatchKey::VE},
{"Lazy", c10::DispatchKey::Lazy},
{"Meta", c10::DispatchKey::Meta},
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
{"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
{"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
{"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
{"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
{"SparseCPU", c10::DispatchKey::SparseCPU},
{"SparseCUDA", c10::DispatchKey::SparseCUDA},
{"SparseHIP", c10::DispatchKey::SparseHIP},
{"SparseXPU", c10::DispatchKey::SparseXPU},
{"SparseVE", c10::DispatchKey::SparseVE},
{"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
{"NestedTensor", c10::DispatchKey::NestedTensor},
{"PrivateUse1", c10::DispatchKey::PrivateUse1},
{"PrivateUse2", c10::DispatchKey::PrivateUse2},
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
{"FuncTorchPython", c10::DispatchKey::FuncTorchPython},
{"Named", c10::DispatchKey::Named},
{"Conjugate", c10::DispatchKey::Conjugate},
{"Negative", c10::DispatchKey::Negative},
{"FuncTorchDynamicLayerBackMode",
c10::DispatchKey::FuncTorchDynamicLayerBackMode},
{"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
{"AutogradOther", c10::DispatchKey::AutogradOther},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},
{"AutogradLazy", c10::DispatchKey::AutogradLazy},
{"AutogradXPU", c10::DispatchKey::AutogradXPU},
{"AutogradMLC", c10::DispatchKey::AutogradMLC},
{"AutogradHPU", c10::DispatchKey::AutogradHPU},
{"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
{"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
{"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
{"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
{"Tracer", c10::DispatchKey::Tracer},
{"AutocastCPU", c10::DispatchKey::AutocastCPU},
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
{"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
{"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
{"Batched", c10::DispatchKey::Batched},
{"VmapMode", c10::DispatchKey::VmapMode},
{"FuncTorchGradWrapper", c10::DispatchKey::FuncTorchGradWrapper},
{"FuncTorchDynamicLayerFrontMode",
c10::DispatchKey::FuncTorchDynamicLayerFrontMode},
{"TESTING_ONLY_GenericWrapper",
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"Autograd", c10::DispatchKey::Autograd},
{"CompositeImplicitAutograd",
c10::DispatchKey::CompositeImplicitAutograd},
{"CompositeExplicitAutograd",
c10::DispatchKey::CompositeExplicitAutograd},
};
auto it = key_map.find(k);
TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k);
return it->second;
}
} // namespace c10
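
The strings accepted by parseDispatchKey are exactly the keys of the map above; anything else trips the TORCH_CHECK, which surfaces in Python as a RuntimeError. A small sketch of that behavior through the new binding (again assuming it is exposed on torch._C):

    import torch

    # A name present in the map parses and the call succeeds.
    torch._C._dispatch_print_registrations_for_dispatch_key("AutogradCPU")

    # An unknown name raises: "could not parse dispatch key: NotARealKey"
    try:
        torch._C._dispatch_print_registrations_for_dispatch_key("NotARealKey")
    except RuntimeError as err:
        print(err)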

View File

@@ -363,6 +363,10 @@ C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t);
// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
// These are some convenience identifiers for dispatch keys which are
// shorter to type than their long counterparts. Note that some of these
// dispatch keys directly correspond to DeviceType; and most APIs that

View File

@@ -786,6 +786,12 @@ CPU: registered at {}:5 :: () -> () [ boxed unboxed ]
'''.format(extension_path),
impls[0])
def test_dispatch_print_registrations_for_dispatch_key_invalid(self):
with self.assertRaisesRegex(
RuntimeError,
"could not parse dispatch key: invalid_key"):
C._dispatch_print_registrations_for_dispatch_key('invalid_key')
class TestPythonDispatcher(TestCase):
def test_basic(self):
dispatcher = PythonDispatcher()

View File

@@ -26,29 +26,6 @@ torch::Library::Kind parseKind(const std::string& k) {
TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
return it->second;
}
c10::optional<c10::DispatchKey> parseDispatchKey(const std::string& k) {
static std::unordered_map<std::string, c10::DispatchKey> key_map = {
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},
{"XLA", c10::DispatchKey::XLA},
{"Lazy", c10::DispatchKey::Lazy},
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
{"CompositeImplicitAutograd", c10::DispatchKey::CompositeImplicitAutograd},
{"Autograd", c10::DispatchKey::Autograd},
{"CompositeExplicitAutograd", c10::DispatchKey::CompositeExplicitAutograd},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"", c10::DispatchKey::Undefined},
};
auto it = key_map.find(k);
TORCH_CHECK(it != key_map.end(), "could not parse ", k);
if (it->second == c10::DispatchKey::Undefined) {
return c10::nullopt;
} else {
return c10::make_optional(it->second);
}
}
c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
{"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
@@ -64,7 +41,7 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
auto mb_key = parseDispatchKey(key);
auto mb_key = std::string(key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(key));
if (mb_key) {
return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
} else {
@@ -154,7 +131,12 @@ void initDispatchBindings(PyObject* module) {
;
m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch) {
return std::make_unique<torch::Library>(parseKind(kind), std::move(name), parseDispatchKey(dispatch), "/dev/null", 0);
return std::make_unique<torch::Library>(
parseKind(kind),
std::move(name),
std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)),
"/dev/null",
0);
});
m.def("_dispatch_dump", [](const char* name) -> std::string {
@@ -198,6 +180,18 @@ void initDispatchBindings(PyObject* module) {
return states;
});
// Prints out the name of every operator that has a kernel registered to the Dispatcher
// under [dispatch_key].
// If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of.
// This can be useful to answer questions like "list all operators that do not have a CPU kernel".
m.def("_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") {
auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key));
auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
for (auto& op : op_names) {
std::cout << op << std::endl;
}
}, py::arg("dispatch_key") = static_cast<const char*>(""));
}
}}} // namespace torch::impl::dispatch
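
Note that the binding writes with std::cout rather than Python's sys.stdout, so contextlib.redirect_stdout will not capture its output. One way to turn the printed lists into sets, e.g. to answer "which operators have no CPU kernel?", is to run the call in a child interpreter; registrations() below is a hypothetical helper sketch, not part of this commit:

    import subprocess
    import sys

    def registrations(key=""):
        # Run in a subprocess so the C++-level stdout can be captured.
        code = ("import torch; "
                "torch._C._dispatch_print_registrations_for_dispatch_key(%r)" % key)
        out = subprocess.run([sys.executable, "-c", code],
                             capture_output=True, text=True, check=True).stdout
        return set(out.splitlines())

    # Operators known to the Dispatcher that have no CPU kernel registered.
    missing_cpu = registrations() - registrations("CPU")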

View File

@@ -2810,7 +2810,6 @@ def get_tensors_from(args, kwargs):
return set([arg for arg in args if isinstance(arg, Tensor)] +
[v for v in kwargs.values() if isinstance(v, Tensor)])
def has_breakpad():
# We always build with breakpad in CI
if IS_IN_CI: