mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
expose number of outputs in native runtime for unified runtime (#161723)
This is only user outputs which is what we want. Spoke to @zhxchen17 though and it seems like nativeRT might have some bugs on propogating updates to things like input mutation or buffer mutation though. Something to take a look at in a follow up. Also I have no idea where the nativeRT tests are. Any pointers @zhxchen17 @SherlockNoMad Pull Request resolved: https://github.com/pytorch/pytorch/pull/161723 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
fbf3d2027d
commit
65985937d9
@ -152,6 +152,11 @@ std::vector<c10::IValue> ModelRunner::runWithFlatInputsAndOutputs(
|
||||
return executor_->execute(std::move(flatInputs));
|
||||
}
|
||||
|
||||
uint64_t ModelRunner::numOutputs() const {
|
||||
TORCH_CHECK(executor_, "ModelRunner not initialized");
|
||||
return executor_->graphSignature().userOutputs().size();
|
||||
}
|
||||
|
||||
ModelRunnerHandle::ModelRunnerHandle(
|
||||
const std::string& packagePath,
|
||||
const std::string& modelName)
|
||||
|
@ -32,6 +32,8 @@ class TORCH_API ModelRunner {
|
||||
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
|
||||
std::vector<c10::IValue> flatInputs);
|
||||
|
||||
uint64_t numOutputs() const;
|
||||
|
||||
std::shared_ptr<Weights> loadWeightsDefault(
|
||||
Graph& graph,
|
||||
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>& reader);
|
||||
|
Reference in New Issue
Block a user