// Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
//
// Summary: static dispatch registry should be moved to open source. The rest
// can be maintained internally for now, since delegates will all go through
// the ET hop.
// Test Plan: spot-checked existing tests and didn't see any missing
// registrations.
// Differential Revision: D80099377
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/160439
// Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
#pragma once
|
|
|
|
#include <memory>
#include <utility>

#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/OpKernel.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
// Holds the executor for a constant-folding run produced during kernel
// initialization. Wrapped in a struct so ExecutionKernels can keep a vector
// of these and grow the payload later without changing its interface.
struct ConstFoldingExecution {
  // Owning executor for the const-folding subgraph.
  std::unique_ptr<GraphExecutorBase> executor;
};
|
|
|
|
// Aggregate result of KernelFactory::initializeNodeKernels: everything the
// executor needs to run a graph. All members are owning; the struct itself
// is move-only by virtue of the unique_ptr members.
struct ExecutionKernels {
  // One kernel per (non-delegated) node in the graph.
  std::vector<std::unique_ptr<OpKernel>> nodeKernels;
  // Executors for subgraphs lowered to backend delegates.
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
  // Const-folding executions discovered while building the kernels.
  std::vector<ConstFoldingExecution> constFoldingExecutions;
};
|
|
|
|
class KernelFactoryHandler {
|
|
public:
|
|
using OpKernelPtr = std::unique_ptr<OpKernel>;
|
|
using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
|
|
using Matcher = c10::function_ref<
|
|
bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
|
|
using Callback =
|
|
c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
|
|
const Node&,
|
|
std::shared_ptr<Weights> weights,
|
|
const torch::nativert::ExecutorConfig& executorConfig,
|
|
caffe2::serialize::PyTorchStreamReader* pytorchStreamReader)>;
|
|
|
|
KernelFactoryHandler(Matcher matcher, Callback callback)
|
|
: matcher_(matcher), callback_(callback) {}
|
|
|
|
KernelFactoryHandler() = delete;
|
|
KernelFactoryHandler(const KernelFactoryHandler&) = default;
|
|
KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default;
|
|
KernelFactoryHandler(KernelFactoryHandler&&) = default;
|
|
KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
|
|
~KernelFactoryHandler() = default;
|
|
|
|
bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
|
|
const {
|
|
return matcher_(node, config);
|
|
}
|
|
|
|
std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
|
|
const Node& node,
|
|
std::shared_ptr<Weights> weights,
|
|
const torch::nativert::ExecutorConfig& executorConfig,
|
|
caffe2::serialize::PyTorchStreamReader* pytorchStreamReader) const {
|
|
return callback_(node, weights, executorConfig, pytorchStreamReader);
|
|
}
|
|
|
|
private:
|
|
Matcher matcher_;
|
|
Callback callback_;
|
|
};
|
|
|
|
// Builds all execution kernels for a graph. Handlers registered via
// registerHandler() (a static, process-wide registry) are consulted per
// node; definitions live in the corresponding .cpp.
class KernelFactory {
 public:
  KernelFactory() = default;

  // Creates kernels / delegate executors / const-folding executions for
  // every node in `graph`.
  // `pytorchStreamReader` is optional (may be null) — presumably only
  // needed when delegate payloads must be loaded from a package; TODO
  // confirm against the implementation.
  ExecutionKernels initializeNodeKernels(
      const Graph& graph,
      const std::shared_ptr<Weights>& weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr);

  // Registers `handler` under `name` in the static handler registry.
  static void registerHandler(
      const std::string& name,
      KernelFactoryHandler handler);

  // True if a handler was registered under the given name.
  static bool isHandlerRegistered(const std::string& handler);
};
|
|
|
|
} // namespace torch::nativert
|