Files
pytorch/torch/nativert/ModelRunner.h
Jacob Szwejbka 65985937d9 expose number of outputs in native runtime for unified runtime (#161723)
This counts only user outputs, which is what we want. I spoke to @zhxchen17, though, and it seems nativeRT might have some bugs in propagating updates for things like input mutation or buffer mutation. Something to take a look at in a follow-up.

Also, I have no idea where the nativeRT tests are. Any pointers, @zhxchen17 @SherlockNoMad?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161723
Approved by: https://github.com/zhxchen17
2025-09-04 01:20:31 +00:00

63 lines
1.9 KiB
C++

#pragma once
#include <fmt/format.h>
#include <c10/macros/Export.h>
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/ModelRunnerHandle.h>
#include <torch/nativert/detail/ITree.h>
#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/executor/Placement.h>
namespace torch::nativert {
/**
 * Exported (TORCH_API) runner type that holds a graph, an executor, the
 * input/output tree specs and the exported program for one model (see the
 * private members) and exposes entry points to execute it.
 *
 * The type is move-only: copy construction/assignment are deleted while move
 * construction/assignment are defaulted.
 *
 * All member functions are only declared here; their definitions live in a
 * translation unit that is not part of this header.
 */
class TORCH_API ModelRunner {
public:
/**
 * Constructs a runner from a package path and a model name.
 * NOTE(review): presumably `packagePath` points at a packaged model archive
 * and `modelName` selects a model inside it; confirm against the definition.
 */
ModelRunner(const std::string& packagePath, const std::string& modelName);
// Movable (defaulted) ...
ModelRunner(ModelRunner&&) = default;
ModelRunner& operator=(ModelRunner&&) = default;
// ... but not copyable.
ModelRunner(const ModelRunner&) = delete;
ModelRunner& operator=(const ModelRunner&) = delete;
~ModelRunner() = default;
/**
 * Runs with positional arguments `args` and keyword arguments `kwargs` and
 * returns a single IValue.
 * NOTE(review): presumably the inputs are flattened according to inputSpec_
 * and the result is rebuilt according to outputSpec_; confirm against the
 * definition.
 */
c10::IValue run(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);
/**
* A low level API which expects user to always pass in flattened inputs.
* The ownership of the entire input list must be transferred to the
* executor via std::move or in-place construction.
*/
// `flatInputs` is taken by value, which is what lets the caller hand over
// ownership as described above; the result is returned as a flat vector.
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
std::vector<c10::IValue> flatInputs);
/**
 * Returns the number of outputs as an unsigned 64-bit count; does not modify
 * the runner (const).
 * NOTE(review): intended to count user outputs only; confirm against the
 * definition.
 */
uint64_t numOutputs() const;
/**
 * Produces a shared Weights object for `graph`, given a PyTorchStreamReader.
 * `graph` is taken by non-const reference.
 * NOTE(review): presumably reads the weights from the archive behind
 * `reader`; whether `graph` is mutated cannot be told from this declaration.
 */
std::shared_ptr<Weights> loadWeightsDefault(
Graph& graph,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>& reader);
private:
/**
 * Returns a string-to-string map built from the stream reader, a config
 * format and a model name.
 * NOTE(review): the meaning of the keys/values and the accepted
 * `configFormat` values are not visible here; see the definition.
 */
std::unordered_map<std::string, std::string> getPayloadConfig(
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader,
std::string_view configFormat,
const std::string& modelName);
// original non-delegated graph from torch.export()
std::shared_ptr<Graph> graph_;
// Executor instance exclusively owned by this runner (unique_ptr).
std::unique_ptr<Executor> executor_;
// Tree specs for the inputs and outputs.
// NOTE(review): presumably used to flatten/unflatten in run(); confirm.
ITreeSpec inputSpec_;
ITreeSpec outputSpec_;
// Deserialized exported-program record (generated serialization type).
torch::_export::ExportedProgram exportedProgram_;
// String-to-string maps for tensors and constants.
// NOTE(review): presumably name -> path within the package; confirm.
std::unordered_map<std::string, std::string> tensorPaths_;
std::unordered_map<std::string, std::string> constantPaths_;
};
} // namespace torch::nativert