add static dispatch kernel registration to open source (#160439)

Summary: The static dispatch kernel registry is moved to open source. The rest can stay internal for now, since delegates will all go through the ET hop.
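
For context, a minimal sketch (not part of this diff) of how an open-source embedder could now exercise the entry points added here, register_kernel_handlers() and KernelFactory::isHandlerRegistered(). It assumes the headers land under torch/nativert/kernels/ as in the files below and that the program links against libtorch:

    #include <torch/nativert/kernels/KernelFactory.h>
    #include <torch/nativert/kernels/KernelHandlerRegistry.h>

    int main() {
      // Registration is guarded by c10::call_once inside register_kernel_handlers(),
      // so calling it more than once is a no-op.
      torch::nativert::register_kernel_handlers();
      // "static_cpu" is the handler name registered by this change.
      return torch::nativert::KernelFactory::isHandlerRegistered("static_cpu") ? 0 : 1;
    }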

Test Plan: Spot-checked existing tests and didn't see any missing registrations.

Differential Revision: D80099377

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160439
Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
Author: dolpm
Date: 2025-08-20 17:58:00 +00:00
Committed by: PyTorch MergeBot
parent b2632e7982
commit 1471b20cb3
8 changed files with 103 additions and 0 deletions


@@ -634,6 +634,7 @@ libtorch_nativert_sources = [
    "torch/nativert/graph/passes/SubgraphRewriter.cpp",
    "torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
    "torch/nativert/graph/passes/pass_manager/PassManager.cpp",
    "torch/nativert/kernels/KernelHandlerRegistry.cpp",
]

torch_mobile_tracer_sources = [


@@ -39,6 +39,7 @@ set(NATIVERT_TEST_SRCS
  ${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
  ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
  ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
  ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
)

add_executable(test_nativert


@@ -0,0 +1,15 @@
#include <gtest/gtest.h>

#include <torch/nativert/kernels/KernelFactory.h>
#include <torch/nativert/kernels/KernelHandlerRegistry.h>

using namespace ::testing;
using namespace torch::nativert;

TEST(StaticDispatchKernelRegistrationTests, TestRegistration) {
  EXPECT_FALSE(KernelFactory::isHandlerRegistered("static_cpu"));
  register_kernel_handlers();
  EXPECT_TRUE(KernelFactory::isHandlerRegistered("static_cpu"));
  // try to re-register, which should be a no-op
  register_kernel_handlers();
}


@@ -10,6 +10,7 @@
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/graph/GraphPasses.h>
#include <torch/nativert/graph/Serialization.h>
#include <torch/nativert/kernels/KernelHandlerRegistry.h>
namespace torch::nativert {
@@ -55,6 +56,7 @@ std::shared_ptr<Weights> loadWeightsDefault(
ModelRunner::ModelRunner(
    const std::string& packagePath,
    const std::string& modelName) {
  register_kernel_handlers();
  auto pytorchStreamReader =
      std::make_shared<caffe2::serialize::PyTorchStreamReader>(
          std::make_unique<caffe2::serialize::FileAdapter>(packagePath));


@@ -77,6 +77,13 @@ void KernelFactory::registerHandler(
  });
}

/* static */ bool KernelFactory::isHandlerRegistered(
    const std::string& handler) {
  return getKernelFactoryRegistry().withLock([&](auto&& reg) {
    return reg.handlers.find(handler) != reg.handlers.end();
  });
}

ExecutionKernels KernelFactory::initializeNodeKernels(
    const Graph& graph,
    const std::shared_ptr<Weights>& weights,


@@ -75,6 +75,8 @@ class KernelFactory {
  static void registerHandler(
      const std::string& name,
      KernelFactoryHandler handler);

  static bool isHandlerRegistered(const std::string& handler);
};

} // namespace torch::nativert


@@ -0,0 +1,68 @@
#include <torch/nativert/kernels/KernelHandlerRegistry.h>

#include <c10/util/Logging.h>
#include <fmt/format.h>

#include <ATen/core/ivalue.h>
#include <c10/util/CallOnce.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphPasses.h>
#include <torch/nativert/graph/GraphUtils.h>
#include <torch/nativert/kernels/KernelFactory.h>
#include <torch/nativert/kernels/KernelRegistry.h>

namespace torch::nativert {

namespace {
std::string maybeRevisedStaticDispatchTarget(const Node& node) {
  auto overloadName = selectScalarOverloadName(node);
  if (!overloadName.empty() && !c10::ends_with(node.target(), overloadName)) {
    const std::string& newTarget =
        std::string(node.target())
            .replace(node.target().rfind('.'), std::string::npos, overloadName);
    LOG(INFO) << fmt::format(
        "Converting Tensor to {} for node: {} -> {}",
        overloadName,
        node.target(),
        newTarget);
    return newTarget;
  }
  return std::string(node.target());
}
} // namespace

void register_kernel_handlers() {
  static c10::once_flag flag;
  c10::call_once(flag, []() {
    using OpKernelPtr = KernelFactoryHandler::OpKernelPtr;
    using DelegateExecutorPtr = KernelFactoryHandler::DelegateExecutorPtr;
    KernelFactory::registerHandler(
        "static_cpu",
        KernelFactoryHandler(
            [](const Node& node,
               const torch::nativert::ExecutorConfig& executorConfig) {
              if (!executorConfig.enableStaticCPUKernels ||
                  !torch::nativert::areAllIOTensorsAttributesOnCpu(node)) {
                return false;
              }
              const std::string target = maybeRevisedStaticDispatchTarget(node);
              return torch::nativert::StaticallyDispatchedCPUKernelRegistry()
                  ->Has(target);
            },
            [](const Node& node,
               // NOLINTNEXTLINE(performance-unnecessary-value-param)
               std::shared_ptr<Weights> weights,
               const torch::nativert::ExecutorConfig& executorConfig,
               caffe2::serialize::PyTorchStreamReader* packageReader)
                -> std::pair<OpKernelPtr, DelegateExecutorPtr> {
              return {
                  torch::nativert::StaticallyDispatchedCPUKernelRegistry()
                      ->Create(maybeRevisedStaticDispatchTarget(node), &node),
                  nullptr};
            }));
  });
}

} // namespace torch::nativert
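
As an aside on maybeRevisedStaticDispatchTarget above: it replaces everything from the last '.' in the node target with whatever selectScalarOverloadName returns (which is outside this diff). A standalone sketch of just that string manipulation; the example target "torch.ops.aten.add.Tensor" and overload ".Scalar" (leading dot included) are illustrative assumptions, not values taken from this change:

    #include <iostream>
    #include <string>

    int main() {
      // Hypothetical inputs for illustration only.
      std::string target = "torch.ops.aten.add.Tensor";
      std::string overload = ".Scalar"; // stand-in for selectScalarOverloadName(node)

      // Mirror the rewrite in maybeRevisedStaticDispatchTarget: splice the overload
      // in from the position of the last '.' through the end of the target.
      std::string revised = target;
      revised.replace(revised.rfind('.'), std::string::npos, overload);

      std::cout << target << " -> " << revised << '\n';
      // Prints: torch.ops.aten.add.Tensor -> torch.ops.aten.add.Scalar
      return 0;
    }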


@@ -0,0 +1,7 @@
#pragma once

namespace torch::nativert {

void register_kernel_handlers();

} // namespace torch::nativert