#pragma once

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include <c10/util/FbcodeMaps.h>
#include <c10/util/Semaphore.h>
#include <c10/util/Synchronized.h>
#include <caffe2/serialize/inline_container.h>

#include <torch/nativert/detail/ITree.h>
#include <torch/nativert/detail/MPMCQueue.h>
#include <torch/nativert/executor/ConstantFolder.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/executor/memory/LayoutPlanner.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphSignature.h>
#include <torch/nativert/kernels/KernelFactory.h>

namespace torch::nativert {

using namespace torch::nativert::detail;

struct DistributedRunConfig;

/**
 * A very dumb executor. Basically just runs each node in order and contains a
 * giant unordered map for every intermediate, no optimizations applied.
 */
class Executor {
  class ExecutorFrameDeleter {
   public:
    explicit ExecutorFrameDeleter(Executor& e) : e_(&e) {}
    ExecutorFrameDeleter(ExecutorFrameDeleter&&) = default;
    ExecutorFrameDeleter& operator=(ExecutorFrameDeleter&&) = default;
    ExecutorFrameDeleter(const ExecutorFrameDeleter&) = default;
    ExecutorFrameDeleter& operator=(const ExecutorFrameDeleter&) = default;
    ~ExecutorFrameDeleter() = default;

    void operator()(ExecutionFrame* p) {
      e_->returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame>(p));
    }

   private:
    Executor* e_;
  };

  class ExecutorFramePtr {
   public:
    ExecutorFramePtr(std::unique_ptr<ExecutionFrame> ptr, Executor& e)
        : ptr_(std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter>(
              ptr.release(),
              ExecutorFrameDeleter{e})) {}
    ExecutorFramePtr() = delete;
    ExecutorFramePtr(ExecutorFramePtr&&) = default;
    ExecutorFramePtr& operator=(ExecutorFramePtr&&) = default;
    ExecutorFramePtr(const ExecutorFramePtr&) = delete;
    ExecutorFramePtr& operator=(const ExecutorFramePtr&) = delete;
    ~ExecutorFramePtr() = default;

    ExecutionFrame& operator*() {
      return *ptr_;
    }

    ExecutionFrame* operator->() {
      return ptr_.get();
    }

   private:
    std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter> ptr_;
  };

 public:
  // Constructor used for the inference path.
  Executor(
      torch::nativert::ExecutorConfig executorConfig,
      std::shared_ptr<Graph> graph,
      const std::shared_ptr<Weights>& weights,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr);

  std::shared_ptr<Weights> getWeights() {
    std::shared_ptr<Weights> ret;
    weights_.withLock([&](auto& w) { ret = w; });
    return ret;
  }

  void processWeights(const std::shared_ptr<Weights>& weights);
  void atomicSwapWeights(std::shared_ptr<Weights> weights);

  // This API only returns the flattened UserOutputs,
  // intended to be used for the inference path.
  // TODO: Investigate whether we should remove this; it still seems
  // useful for testing.
  std::vector<c10::IValue> execute(std::vector<c10::IValue> inputs);

  std::vector<c10::IValue> execute(
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs,
      const ITreeSpec& inputTreeSpec);

  ProfileMetrics benchmarkIndividualNodes(
      const std::vector<std::vector<c10::IValue>>& inputsList,
      const uint32_t warmupRuns,
      const uint32_t mainRuns);

  const torch::nativert::GraphSignature& graphSignature() const {
    return graph_->signature();
  }

  static std::string className() {
    return "Executor.v0";
  }

  const torch::nativert::ExecutorConfig& executorConfig() const {
    return executorConfig_;
  }

  std::vector<DelegateExecutor*> getDelegates();

  // Get the number of execution frames in the pool.
  auto getNumExecutionFrames() const {
    return numExecutionFrames_.load();
  }

  static c10::FastMap<std::string, FunctionSchema> getKernelSchemas(
      const std::vector<std::unique_ptr<OpKernel>>& kernels);

 protected:
  torch::nativert::ExecutorConfig executorConfig_;

  std::shared_ptr<Graph> graph_;

  // Manages the parameters, buffers, and tensor constants.
  c10::Synchronized<std::shared_ptr<Weights>> weights_;

  void initialize(
      const std::shared_ptr<Weights>& weights,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader);

  ExecutorFramePtr getExecutorFrameFromPool();
  void returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame> frame);

  // Clears stale execution frames from the pool.
  void clearStaleExecutionFrames();

 private:
  void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
  void validateInputs(const std::vector<c10::IValue>& inputs) const;

  // Helper method to get the current timestamp in seconds.
  int64_t getCurrentTimestampSeconds() const;

  void initWeights(const std::shared_ptr<Weights>& weights);

  std::unique_ptr<GraphExecutorBase> graphExecutor_;

  // NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;

  std::vector<ConstFoldingExecution> constFoldingExecutions_;

  std::optional<ConstantFolder> constantFolder_;

  c10::Semaphore sem_;
  torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
      executionFrames_;
  torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
      inactiveExecutionFrames_;
  std::atomic_int64_t numExecutionFrames_;

  std::unique_ptr<LayoutPlanner> layoutPlanner_;

  std::atomic_int64_t lastClearedTimestamp_;
  std::mutex cleanupLock_;
  std::atomic_bool clearingInProgress_{false};
};

} // namespace torch::nativert