expose number of outputs in native runtime for unified runtime (#161723)

This counts only user outputs, which is what we want. I spoke to @zhxchen17, though, and it seems like nativeRT might have some bugs in propagating updates for things like input mutation or buffer mutation. Something to take a look at in a follow-up.

Also, I have no idea where the nativeRT tests are. Any pointers, @zhxchen17 @SherlockNoMad?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161723
Approved by: https://github.com/zhxchen17
This commit is contained in:
Jacob Szwejbka
2025-09-04 01:20:27 +00:00
committed by PyTorch MergeBot
parent fbf3d2027d
commit 65985937d9
2 changed files with 7 additions and 0 deletions

View File

@ -152,6 +152,11 @@ std::vector<c10::IValue> ModelRunner::runWithFlatInputsAndOutputs(
return executor_->execute(std::move(flatInputs));
}
// Reports how many user outputs the executor's graph signature declares.
// Raises via TORCH_CHECK when no executor is present.
uint64_t ModelRunner::numOutputs() const {
  TORCH_CHECK(executor_, "ModelRunner not initialized");
  // Name each step of the lookup instead of chaining the calls.
  const auto& signature = executor_->graphSignature();
  const auto& userOutputs = signature.userOutputs();
  return userOutputs.size();
}
ModelRunnerHandle::ModelRunnerHandle(
const std::string& packagePath,
const std::string& modelName)

View File

@ -32,6 +32,8 @@ class TORCH_API ModelRunner {
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
std::vector<c10::IValue> flatInputs);
// Returns the number of user outputs listed in the executor's graph
// signature. The definition TORCH_CHECKs that the executor is set first.
uint64_t numOutputs() const;
std::shared_ptr<Weights> loadWeightsDefault(
Graph& graph,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>& reader);