mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add static dispatch kernel registration to open source (#160439)
Summary: static dispatch registry should be moved to open source. the rest can maintain internally for now, since delegates will all go through ET hop. Test Plan: spot checked existing tests and didn't see any missing registrations Differential Revision: D80099377 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160439 Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
This commit is contained in:
@ -634,6 +634,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/graph/passes/SubgraphRewriter.cpp",
|
||||
"torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
|
||||
"torch/nativert/graph/passes/pass_manager/PassManager.cpp",
|
||||
"torch/nativert/kernels/KernelHandlerRegistry.cpp",
|
||||
]
|
||||
|
||||
torch_mobile_tracer_sources = [
|
||||
|
@ -39,6 +39,7 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
|
||||
)
|
||||
|
||||
add_executable(test_nativert
|
||||
|
@ -0,0 +1,15 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nativert/kernels/KernelFactory.h>
|
||||
#include <torch/nativert/kernels/KernelHandlerRegistry.h>
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace torch::nativert;
|
||||
|
||||
TEST(StaticDispatchKernelRegistrationTests, TestRegistration) {
|
||||
EXPECT_FALSE(KernelFactory::isHandlerRegistered("static_cpu"));
|
||||
register_kernel_handlers();
|
||||
EXPECT_TRUE(KernelFactory::isHandlerRegistered("static_cpu"));
|
||||
// try to re-register, which should be a no-op
|
||||
register_kernel_handlers();
|
||||
}
|
@ -10,6 +10,7 @@
|
||||
#include <torch/nativert/executor/Placement.h>
|
||||
#include <torch/nativert/graph/GraphPasses.h>
|
||||
#include <torch/nativert/graph/Serialization.h>
|
||||
#include <torch/nativert/kernels/KernelHandlerRegistry.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
@ -55,6 +56,7 @@ std::shared_ptr<Weights> loadWeightsDefault(
|
||||
ModelRunner::ModelRunner(
|
||||
const std::string& packagePath,
|
||||
const std::string& modelName) {
|
||||
register_kernel_handlers();
|
||||
auto pytorchStreamReader =
|
||||
std::make_shared<caffe2::serialize::PyTorchStreamReader>(
|
||||
std::make_unique<caffe2::serialize::FileAdapter>(packagePath));
|
||||
|
@ -77,6 +77,13 @@ void KernelFactory::registerHandler(
|
||||
});
|
||||
}
|
||||
|
||||
/* static */ bool KernelFactory::isHandlerRegistered(
|
||||
const std::string& handler) {
|
||||
return getKernelFactoryRegistry().withLock([&](auto&& reg) {
|
||||
return reg.handlers.find(handler) != reg.handlers.end();
|
||||
});
|
||||
}
|
||||
|
||||
ExecutionKernels KernelFactory::initializeNodeKernels(
|
||||
const Graph& graph,
|
||||
const std::shared_ptr<Weights>& weights,
|
||||
|
@ -75,6 +75,8 @@ class KernelFactory {
|
||||
static void registerHandler(
|
||||
const std::string& name,
|
||||
KernelFactoryHandler handler);
|
||||
|
||||
static bool isHandlerRegistered(const std::string& handler);
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
|
68
torch/nativert/kernels/KernelHandlerRegistry.cpp
Normal file
68
torch/nativert/kernels/KernelHandlerRegistry.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
#include <torch/nativert/kernels/KernelHandlerRegistry.h>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
#include <torch/nativert/graph/GraphPasses.h>
|
||||
#include <torch/nativert/graph/GraphUtils.h>
|
||||
#include <torch/nativert/kernels/KernelFactory.h>
|
||||
#include <torch/nativert/kernels/KernelRegistry.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
std::string maybeRevisedStaticDispatchTarget(const Node& node) {
|
||||
auto overloadName = selectScalarOverloadName(node);
|
||||
|
||||
if (!overloadName.empty() && !c10::ends_with(node.target(), overloadName)) {
|
||||
const std::string& newTarget =
|
||||
std::string(node.target())
|
||||
.replace(node.target().rfind('.'), std::string::npos, overloadName);
|
||||
LOG(INFO) << fmt::format(
|
||||
"Converting Tensor to {} for node: {} -> {}",
|
||||
overloadName,
|
||||
node.target(),
|
||||
newTarget);
|
||||
return newTarget;
|
||||
}
|
||||
return std::string(node.target());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void register_kernel_handlers() {
|
||||
static c10::once_flag flag;
|
||||
c10::call_once(flag, []() {
|
||||
using OpKernelPtr = KernelFactoryHandler::OpKernelPtr;
|
||||
using DelegateExecutorPtr = KernelFactoryHandler::DelegateExecutorPtr;
|
||||
KernelFactory::registerHandler(
|
||||
"static_cpu",
|
||||
KernelFactoryHandler(
|
||||
[](const Node& node,
|
||||
const torch::nativert::ExecutorConfig& executorConfig) {
|
||||
if (!executorConfig.enableStaticCPUKernels ||
|
||||
!torch::nativert::areAllIOTensorsAttributesOnCpu(node)) {
|
||||
return false;
|
||||
}
|
||||
const std::string target = maybeRevisedStaticDispatchTarget(node);
|
||||
return torch::nativert::StaticallyDispatchedCPUKernelRegistry()
|
||||
->Has(target);
|
||||
},
|
||||
[](const Node& node,
|
||||
// NOLINTNEXTLINE(performance-unnecessary-value-param)
|
||||
std::shared_ptr<Weights> weights,
|
||||
const torch::nativert::ExecutorConfig& executorConfig,
|
||||
caffe2::serialize::PyTorchStreamReader* packageReader)
|
||||
-> std::pair<OpKernelPtr, DelegateExecutorPtr> {
|
||||
return {
|
||||
torch::nativert::StaticallyDispatchedCPUKernelRegistry()
|
||||
->Create(maybeRevisedStaticDispatchTarget(node), &node),
|
||||
nullptr};
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
7
torch/nativert/kernels/KernelHandlerRegistry.h
Normal file
7
torch/nativert/kernels/KernelHandlerRegistry.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
void register_kernel_handlers();
|
||||
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user