diff --git a/build_variables.bzl b/build_variables.bzl
index c3c99014d9f4..dfae1d527bb7 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -634,6 +634,7 @@ libtorch_nativert_sources = [
     "torch/nativert/graph/passes/SubgraphRewriter.cpp",
     "torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
     "torch/nativert/graph/passes/pass_manager/PassManager.cpp",
+    "torch/nativert/kernels/KernelHandlerRegistry.cpp",
 ]
 
 torch_mobile_tracer_sources = [
diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt
index 822ed7c3bd99..1b7024f75488 100644
--- a/test/cpp/nativert/CMakeLists.txt
+++ b/test/cpp/nativert/CMakeLists.txt
@@ -39,6 +39,7 @@ set(NATIVERT_TEST_SRCS
   ${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
 )
 
 add_executable(test_nativert
diff --git a/test/cpp/nativert/test_static_dispatch_kernel_registration.cpp b/test/cpp/nativert/test_static_dispatch_kernel_registration.cpp
new file mode 100644
index 000000000000..df5f427879e1
--- /dev/null
+++ b/test/cpp/nativert/test_static_dispatch_kernel_registration.cpp
@@ -0,0 +1,19 @@
+#include
+
+#include
+#include
+
+using namespace ::testing;
+using namespace torch::nativert;
+
+TEST(StaticDispatchKernelRegistrationTests, TestRegistration) {
+  // Registration is process-wide and one-shot (and ModelRunner's constructor
+  // triggers it too), so this first expectation assumes nothing earlier in
+  // the test binary has already registered the handler.
+  EXPECT_FALSE(KernelFactory::isHandlerRegistered("static_cpu"));
+  register_kernel_handlers();
+  EXPECT_TRUE(KernelFactory::isHandlerRegistered("static_cpu"));
+  // try to re-register, which should be a no-op
+  register_kernel_handlers();
+  EXPECT_TRUE(KernelFactory::isHandlerRegistered("static_cpu"));
+}
diff --git a/torch/nativert/ModelRunner.cpp b/torch/nativert/ModelRunner.cpp
index 83cb0e00bd72..633a66c1bd93 100644
--- a/torch/nativert/ModelRunner.cpp
+++ b/torch/nativert/ModelRunner.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 
 namespace torch::nativert {
 
@@ -55,6 +56,7 @@ std::shared_ptr loadWeightsDefault(
 ModelRunner::ModelRunner(
     const std::string& packagePath,
     const std::string& modelName) {
+  register_kernel_handlers();
   auto pytorchStreamReader =
       std::make_shared(
           std::make_unique(packagePath));
diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp
index adf9bae8877a..1702751e704b 100644
--- a/torch/nativert/kernels/KernelFactory.cpp
+++ b/torch/nativert/kernels/KernelFactory.cpp
@@ -77,6 +77,13 @@ void KernelFactory::registerHandler(
   });
 }
 
+/* static */ bool KernelFactory::isHandlerRegistered(
+    const std::string& handler) {
+  return getKernelFactoryRegistry().withLock([&](auto&& reg) {
+    return reg.handlers.find(handler) != reg.handlers.end();
+  });
+}
+
 ExecutionKernels KernelFactory::initializeNodeKernels(
     const Graph& graph,
     const std::shared_ptr& weights,
diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h
index 05773dc5e4c5..4b5486cd322b 100644
--- a/torch/nativert/kernels/KernelFactory.h
+++ b/torch/nativert/kernels/KernelFactory.h
@@ -75,6 +75,9 @@ class KernelFactory {
   static void registerHandler(
       const std::string& name,
       KernelFactoryHandler handler);
+
+  // Returns true if a handler is currently registered under `handler`.
+  static bool isHandlerRegistered(const std::string& handler);
 };
 
 } // namespace torch::nativert
diff --git a/torch/nativert/kernels/KernelHandlerRegistry.cpp b/torch/nativert/kernels/KernelHandlerRegistry.cpp
new file mode 100644
index 000000000000..653ca5dfcb81
--- /dev/null
+++ b/torch/nativert/kernels/KernelHandlerRegistry.cpp
@@ -0,0 +1,88 @@
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace torch::nativert {
+
+namespace {
+// Picks the static-dispatch target name for `node`.
+//
+// If selectScalarOverloadName() yields a non-empty overload name that the
+// node's target does not already end with, everything from the last '.' of
+// the target onward is replaced by that overload name. Otherwise the target
+// is returned unchanged.
+//
+// NOTE: std::string::replace throws std::out_of_range when the target
+// contains no '.' (rfind yields npos).
+std::string maybeRevisedStaticDispatchTarget(const Node& node) {
+  const auto overloadName = selectScalarOverloadName(node);
+
+  if (overloadName.empty() || c10::ends_with(node.target(), overloadName)) {
+    return std::string(node.target());
+  }
+
+  // Keep the revised name in an owned std::string: replace() returns a
+  // reference to the string it was called on, so binding a reference to
+  // `std::string(...).replace(...)` would dangle once that temporary is
+  // destroyed at the end of the statement.
+  std::string newTarget(node.target());
+  newTarget.replace(newTarget.rfind('.'), std::string::npos, overloadName);
+  LOG(INFO) << fmt::format(
+      "Converting Tensor to {} for node: {} -> {}",
+      overloadName,
+      node.target(),
+      newTarget);
+  return newTarget;
+}
+} // namespace
+
+// Registers the "static_cpu" kernel handler with KernelFactory. Guarded by a
+// function-local once_flag, so calls after the first one are no-ops.
+void register_kernel_handlers() {
+  static c10::once_flag flag;
+  c10::call_once(flag, []() {
+    using OpKernelPtr = KernelFactoryHandler::OpKernelPtr;
+    using DelegateExecutorPtr = KernelFactoryHandler::DelegateExecutorPtr;
+    KernelFactory::registerHandler(
+        "static_cpu",
+        KernelFactoryHandler(
+            // Matcher: requires enableStaticCPUKernels, a passing
+            // areAllIOTensorsAttributesOnCpu(node) check, and a registry
+            // entry for the (possibly revised) target.
+            [](const Node& node,
+               const torch::nativert::ExecutorConfig& executorConfig) {
+              if (!executorConfig.enableStaticCPUKernels ||
+                  !torch::nativert::areAllIOTensorsAttributesOnCpu(node)) {
+                return false;
+              }
+              const std::string target = maybeRevisedStaticDispatchTarget(node);
+              return torch::nativert::StaticallyDispatchedCPUKernelRegistry()
+                  ->Has(target);
+            },
+            // Factory: creates the kernel for the (possibly revised) target;
+            // the second element of the returned pair is always nullptr.
+            [](const Node& node,
+               // NOLINTNEXTLINE(performance-unnecessary-value-param)
+               std::shared_ptr weights,
+               const torch::nativert::ExecutorConfig& executorConfig,
+               caffe2::serialize::PyTorchStreamReader* packageReader)
+                -> std::pair {
+              return {
+                  torch::nativert::StaticallyDispatchedCPUKernelRegistry()
+                      ->Create(maybeRevisedStaticDispatchTarget(node), &node),
+                  nullptr};
+            }));
+  });
+}
+
+} // namespace torch::nativert
diff --git a/torch/nativert/kernels/KernelHandlerRegistry.h b/torch/nativert/kernels/KernelHandlerRegistry.h
new file mode 100644
index 000000000000..985ca0819a9a
--- /dev/null
+++ b/torch/nativert/kernels/KernelHandlerRegistry.h
@@ -0,0 +1,9 @@
+#pragma once
+
+namespace torch::nativert {
+
+// Registers the built-in kernel handlers (currently "static_cpu") with
+// KernelFactory. Safe to call repeatedly; only the first call registers.
+void register_kernel_handlers();
+
+} // namespace torch::nativert