#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <c10/util/FunctionRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/graph/Graph.h>

namespace torch::nativert {

// Holds the executor that runs a const-folded (sub)graph.
struct ConstFoldingExecution {
  std::unique_ptr<GraphExecutorBase> executor;
};

// Everything KernelFactory produces for a graph: one kernel per node, plus
// any delegate executors and const-folding executions.
struct ExecutionKernels {
  std::vector<std::unique_ptr<OpKernel>> nodeKernels;
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
  std::vector<ConstFoldingExecution> constFoldingExecutions;
};

class KernelFactoryHandler {
 public:
  using OpKernelPtr = std::unique_ptr<OpKernel>;
  using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
  using Matcher = c10::function_ref<
      bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
  using Callback =
      c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
          const Node&,
          std::shared_ptr<Weights> weights,
          const torch::nativert::ExecutorConfig& executorConfig,
          caffe2::serialize::PyTorchStreamReader* pytorchStreamReader)>;

  KernelFactoryHandler(Matcher matcher, Callback callback)
      : matcher_(matcher), callback_(callback) {}

  KernelFactoryHandler() = delete;
  KernelFactoryHandler(const KernelFactoryHandler&) = default;
  KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default;
  KernelFactoryHandler(KernelFactoryHandler&&) = default;
  KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
  ~KernelFactoryHandler() = default;

  // Returns true if this handler wants to provide the kernel (or delegate
  // executor) for the given node under the given executor config.
  bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
      const {
    return matcher_(node, config);
  }

  // Produces the kernel and/or delegate executor for a matched node.
  std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
      const Node& node,
      std::shared_ptr<Weights> weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      caffe2::serialize::PyTorchStreamReader* pytorchStreamReader) const {
    return callback_(node, weights, executorConfig, pytorchStreamReader);
  }

 private:
  Matcher matcher_;
  Callback callback_;
};

class KernelFactory {
 public:
  KernelFactory() = default;

  // Builds the per-node kernels, delegate executors, and const-folding
  // executions for a graph.
  ExecutionKernels initializeNodeKernels(
      const Graph& graph,
      const std::shared_ptr<Weights>& weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr);

  static void registerHandler(
      const std::string& name,
      KernelFactoryHandler handler);

  static bool isHandlerRegistered(const std::string& handler);
};

} // namespace torch::nativert
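
// Usage sketch (illustration only, not part of this header): how a backend
// might register a KernelFactoryHandler so that nodes it recognizes are
// lowered to its own delegate executor. `MyDelegateExecutor` (assumed to
// derive from DelegateExecutor), the "my_backend" handler name, and the
// `node.target()` accessor used in the matcher are assumptions for
// illustration. Because Matcher and Callback are non-owning
// c10::function_ref, the callables below are free functions (static
// lifetime) rather than temporary lambdas.
//
//   namespace my_backend {
//
//   bool matches(
//       const torch::nativert::Node& node,
//       const torch::nativert::ExecutorConfig& /*config*/) {
//     // Claim nodes whose target lives in this backend's namespace
//     // (assumed Node::target() accessor).
//     return node.target().find("my_backend::") == 0;
//   }
//
//   std::pair<
//       torch::nativert::KernelFactoryHandler::OpKernelPtr,
//       torch::nativert::KernelFactoryHandler::DelegateExecutorPtr>
//   create(
//       const torch::nativert::Node& node,
//       std::shared_ptr<torch::nativert::Weights> weights,
//       const torch::nativert::ExecutorConfig& /*config*/,
//       caffe2::serialize::PyTorchStreamReader* /*reader*/) {
//     // No eager OpKernel for this node; hand it to a hypothetical
//     // delegate executor instead.
//     return {nullptr,
//             std::make_unique<MyDelegateExecutor>(node, std::move(weights))};
//   }
//
//   void registerMyBackend() {
//     if (!torch::nativert::KernelFactory::isHandlerRegistered("my_backend")) {
//       torch::nativert::KernelFactory::registerHandler(
//           "my_backend",
//           torch::nativert::KernelFactoryHandler(matches, create));
//     }
//   }
//
//   } // namespace my_backend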