Files
pytorch/torch/nativert/kernels/KernelFactory.h
dolpm 1471b20cb3 add static dispatch kernel registration to open source (#160439)
Summary: The static dispatch registry should be moved to open source. The rest can be maintained internally for now, since delegates will all go through the ET hop.

Test Plan: spot checked existing tests and didn't see any missing registrations

Differential Revision: D80099377

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160439
Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
2025-08-20 17:58:00 +00:00

83 lines
2.6 KiB
C++

#pragma once
#include <memory>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/OpKernel.h>
namespace torch::nativert {
// Wraps the graph executor that runs a constant-folding pass.
// Presumably one instance exists per foldable subgraph — confirm in the
// KernelFactory implementation (.cpp).
struct ConstFoldingExecution {
// Owning handle to the executor; GraphExecutorBase is the abstract
// execution interface (see torch/nativert/executor/GraphExecutorBase.h).
std::unique_ptr<GraphExecutorBase> executor;
};
// Aggregate result of KernelFactory::initializeNodeKernels(): everything
// needed to execute a graph. All members are owning.
struct ExecutionKernels {
// One kernel per executable graph node.
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
// Executors for delegated (lowered) portions of the graph, if any.
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
// Constant-folding executions prepared for this graph, if any.
std::vector<ConstFoldingExecution> constFoldingExecutions;
};
class KernelFactoryHandler {
public:
using OpKernelPtr = std::unique_ptr<OpKernel>;
using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
using Matcher = c10::function_ref<
bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
using Callback =
c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
const Node&,
std::shared_ptr<Weights> weights,
const torch::nativert::ExecutorConfig& executorConfig,
caffe2::serialize::PyTorchStreamReader* pytorchStreamReader)>;
KernelFactoryHandler(Matcher matcher, Callback callback)
: matcher_(matcher), callback_(callback) {}
KernelFactoryHandler() = delete;
KernelFactoryHandler(const KernelFactoryHandler&) = default;
KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default;
KernelFactoryHandler(KernelFactoryHandler&&) = default;
KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
~KernelFactoryHandler() = default;
bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
const {
return matcher_(node, config);
}
std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
const Node& node,
std::shared_ptr<Weights> weights,
const torch::nativert::ExecutorConfig& executorConfig,
caffe2::serialize::PyTorchStreamReader* pytorchStreamReader) const {
return callback_(node, weights, executorConfig, pytorchStreamReader);
}
private:
Matcher matcher_;
Callback callback_;
};
// Builds everything needed to execute a Graph: per-node kernels, delegate
// executors, and constant-folding executions. Custom creation logic can be
// plugged in process-wide via registerHandler() (see KernelFactoryHandler).
class KernelFactory {
 public:
  KernelFactory() = default;

  // Instantiate kernels for every node of `graph`.
  //
  // @param graph               graph whose nodes need kernels
  // @param weights             model weights made available during creation
  // @param executorConfig      execution options consulted while building
  // @param pytorchStreamReader optional archive reader; may be null.
  //        NOTE(review): presumably used to load delegate payloads — confirm
  //        against the definition in the .cpp.
  // @return owning ExecutionKernels aggregate (see struct above)
  ExecutionKernels initializeNodeKernels(
      const Graph& graph,
      const std::shared_ptr<Weights>& weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr);

  // Register `handler` under `name`. Static: the registry is shared by all
  // KernelFactory instances in the process.
  static void registerHandler(
      const std::string& name,
      KernelFactoryHandler handler);

  // True iff a handler has been registered under `name`.
  // (Parameter renamed from the misleading `handler` to `name` for
  // consistency with registerHandler; a declaration-only parameter rename
  // affects neither callers nor the out-of-line definition.)
  static bool isHandlerRegistered(const std::string& name);
};
} // namespace torch::nativert