[nativert] Add OSS version of ModelRunner (#159268)

Summary: Implement a ModelRunner from scratch with the minimum features for OSS only

Test Plan:
test_export -r NativeRT

Rollback Plan:

Differential Revision: D78979812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159268
Approved by: https://github.com/dolpm
This commit is contained in:
Zhengxu Chen
2025-07-29 21:08:10 +00:00
committed by PyTorch MergeBot
parent c0c24b61ff
commit 8460131087
16 changed files with 561 additions and 4 deletions

View File

@ -679,6 +679,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/nativert/**/*.h",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/lib/libshm/*.h",
],

View File

@ -944,6 +944,7 @@ def define_buck_targets(
[
("torch/csrc/api/include", "torch/**/*.h"),
("", "torch/csrc/**/*.h"),
("", "torch/nativert/**/*.h"),
("", "torch/headeronly/**/*.h"),
("", "torch/script.h"),
("", "torch/library.h"),

View File

@ -593,6 +593,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
libtorch_nativert_sources = [
"torch/nativert/ModelRunner.cpp",
"torch/nativert/graph/Graph.cpp",
"torch/nativert/graph/GraphPasses.cpp",
"torch/nativert/graph/GraphSignature.cpp",
@ -986,6 +987,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp",
"torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources
libtorch_python_distributed_core_sources = [

View File

@ -5,8 +5,10 @@ file(GLOB_RECURSE NATIVERT_ALL_TEST_FILES "${NATIVERT_TEST_ROOT}/test_*.cpp")
# Build the cpp gtest binary containing the cpp-only tests.
set(NATIVERT_TEST_SRCS
${NATIVERT_ALL_TEST_FILES}
${TORCH_ROOT}/torch/nativert/ModelRunner.cpp
${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp

View File

@ -1543,6 +1543,9 @@ graph():
torch.export.export(M(), (torch.randn(7),), strict=strict)
def test_cond_branches_return_constant_int(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class M(torch.nn.Module):
def forward(self, x):
idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
@ -1588,6 +1591,7 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(m(*args), ep.module()(*args))
@testing.expectedFailureCppRuntimeNonStrict
def test_cond_access_identical_symint_closure(self):
class Example2(torch.nn.Module):
def forward(self, x, trigger, target):
@ -2278,6 +2282,9 @@ def forward(self, x, y):
ep = export(model, inputs)
def test_subclasses_parameterization(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2330,6 +2337,7 @@ graph():
self.assertEqual(res, ref_out)
@testing.expectedFailureCppRuntimeNonStrict
def test_subclasses_parameterization_nested(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -4970,6 +4978,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
)
def test_simple_unbacked_view(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class Foo(torch.nn.Module):
def forward(self, x):
u0 = x.item()
@ -5301,6 +5312,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
ep.module()(torch.randn(6, 3), torch.randn(7, 4))
def test_map(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class Module(torch.nn.Module):
def forward(self, xs, y, z):
def body(x, y, z):
@ -7947,6 +7961,9 @@ def forward(self, x):
self.assertEqual(ref_out, ep.module()(ref_x, mod))
def test_unbacked_noncontig_lin(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
@ -8011,6 +8028,9 @@ def forward(self, x):
self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
def test_sym_or_sym_and(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
from torch.fx.experimental.symbolic_shapes import sym_and, sym_or
class Foo(torch.nn.Module):
@ -8326,6 +8346,9 @@ def forward(self, b_a_buffer, x):
@requires_cuda
def test_export_associative_scan_lifted_buffers(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
device = torch.device("cuda")
combine_mode = "pointwise"
@ -13242,6 +13265,9 @@ def forward(self, x, y):
self.assertTrue(placeholders[2].meta["val"].requires_grad)
def test_unbacked_expand(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
class Foo(torch.nn.Module):
def forward(self, xs):
u0, u1, u2 = xs.tolist()

View File

@ -0,0 +1,199 @@
# Owner(s): ["oncall: export"]
import copy
import pathlib
import tempfile
import unittest
import torch
from torch._C._nativert import PyModelRunner
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils import _pytree as pytree
try:
from . import test_export, testing
except ImportError:
import test_export
import testing
from torch.export import export
test_classes = {}
def _use_real_inputs(ep):
ep = copy.copy(ep)
has_fake_tensor = False
def _to_real_tensor(t):
if isinstance(t, torch.nn.Parameter):
return torch.nn.Parameter(_to_real_tensor(t.data))
if isinstance(t, FakeTensor):
nonlocal has_fake_tensor
has_fake_tensor = True
return torch.randn(t.shape, device=t.device, requires_grad=t.requires_grad)
return t
new_example_inputs = pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.example_inputs
)
if has_fake_tensor:
ep.example_inputs = new_example_inputs
ep = ep._update(
ep.graph_module,
ep.graph_signature,
state_dict=pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.state_dict
),
constants=pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.constants
),
)
return ep
def _is_supported_types(arg) -> bool:
if isinstance(arg, list):
return (
all(_is_supported_types(a) for a in arg)
and len({type(a) for a in arg}) <= 1
)
elif isinstance(arg, tuple):
return all(_is_supported_types(a) for a in arg)
elif isinstance(arg, dict):
return (
all(_is_supported_types(a) for a in arg.values())
and len({type(a) for a in arg.values()}) <= 1
)
elif isinstance(arg, (torch.Tensor, int, float, bool, str)):
return True
elif arg is None:
return True
else:
return False
def run_with_nativert(ep):
# Downstream tests might mutate the exported program in subtle ways, so
# we need to make a copy here.
ep_infer = copy.deepcopy(ep)
ep_infer = _use_real_inputs(ep_infer.run_decompositions())
MODEL_NAME = "forward"
# TODO Does named tempfile have collision?
with tempfile.NamedTemporaryFile(delete=False) as f:
torch.export.pt2_archive._package.package_pt2(
f, exported_programs={MODEL_NAME: ep_infer}
)
filename = f.name
try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) == type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
return ep
def mocked_nativert_export_strict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=True)
run_with_nativert(ep)
return ep
def mocked_nativert_export_nonstrict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=False)
run_with_nativert(ep)
return ep
def make_dynamic_cls(cls, strict=False):
cls_prefix = "NativeRT"
if strict:
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.CPP_RUNTIME_STRICT_SUFFIX,
mocked_nativert_export_strict,
xfail_prop="_expected_failure_cpp_runtime",
test_only_if_no_xfail=True,
)
else:
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.CPP_RUNTIME_NONSTRICT_SUFFIX,
mocked_nativert_export_nonstrict,
xfail_prop="_expected_failure_cpp_runtime_non_strict",
test_only_if_no_xfail=True,
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
tests = [
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test, strict=True)
make_dynamic_cls(test, strict=False)
del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -138,6 +138,8 @@
#include <torch/csrc/itt.h>
#endif
#include <torch/nativert/python/Bindings.h>
namespace py = pybind11;
static PyObject* module;
@ -2780,6 +2782,8 @@ Call this whenever a new thread is created in order to propagate values from
#ifdef USE_KINETO
torch::global_kineto_init();
#endif
auto nativert_module = py_module.def_submodule("_nativert");
torch::nativert::initModelRunnerPybind(nativert_module);
return module;
END_HANDLE_TH_ERRORS
}

View File

@ -669,6 +669,16 @@ void Unpickler::readGlobal(
// See [NOTE] skip_next_read_global
this->skip_next_read_global--;
if (this->skip_next_read_global == 1) {
if (module_name == "torch" && class_name == "Tensor") {
// This is a special case when we are unpickling a subclassed tensor
// with type torch.nn.Buffer. We didn't frequently run into this because
// torch.nn.Buffer is introduced later in PyTorch 2 and this type IValue
// will not be used in C++.
rebuildTensor(false);
stack_.emplace_back(int64_t(globals_.size() - 1));
this->skip_next_read_global = 0;
return;
}
// Pass through to the correct handler
} else if (this->skip_next_read_global == 0) {
// Corresponds to the type of `Tensor` being unpickled
@ -773,6 +783,10 @@ void Unpickler::readGlobal(
// Unpickle a Tensor with Python attributes or
// a Subclassed Tensor.
rebuildTensorFromTypeV2();
} else if (
module_name == "torch._utils" && (class_name == "_rebuild_parameter")) {
// Unpickle a Parameter
rebuildParameter();
} else if (
module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
rebuildSparseTensor();
@ -1024,6 +1038,18 @@ void Unpickler::rebuildTensorFromTypeV2() {
});
}
void Unpickler::rebuildParameter() {
globals_.emplace_back([this] {
auto args = pop(stack_).toTuple();
size_t tup_idx = 0;
const auto args_elems = args->elements();
auto result = args_elems.at(tup_idx++).toTensor();
auto requires_grad = args_elems.at(tup_idx++).toBool();
result.requires_grad_(requires_grad);
stack_.emplace_back(std::move(result));
});
}
#ifdef USE_RPC
void Unpickler::rebuildRRef() {
globals_.emplace_back([this] {

View File

@ -137,6 +137,7 @@ class TORCH_API Unpickler {
const std::string& module_name,
const std::string& class_name);
void rebuildTensor(bool quantized);
void rebuildParameter();
void rebuildTensorFromTypeV2();
void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED

View File

@ -0,0 +1,139 @@
#include <torch/nativert/ModelRunner.h>
#include <variant>
#include <nlohmann/json.hpp>
#include <caffe2/serialize/file_adapter.h>
#include <torch/csrc/export/pt2_archive_constants.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/graph/GraphPasses.h>
#include <torch/nativert/graph/Serialization.h>
namespace torch::nativert {
using torch::nativert::jsonToGraph;
using torch::nativert::detail::itreeSpecLoads;
namespace {
std::shared_ptr<Weights> loadWeightsDefault(
Graph& graph,
caffe2::serialize::PyTorchStreamReader& reader,
std::string_view modelName) {
auto weightsPath = fmt::format(
"{}{}.pt", torch::_export::archive_spec::WEIGHTS_DIR, modelName);
auto constantsPath = fmt::format(
"{}{}.pt", torch::_export::archive_spec::CONSTANTS_DIR, modelName);
TORCH_CHECK(
reader.hasRecord(weightsPath), weightsPath, " not found in package");
TORCH_CHECK(
reader.hasRecord(constantsPath), constantsPath, " not found in package");
const auto& [weightsData, weightsSize] = reader.getRecord(weightsPath);
auto weights =
torch::jit::pickle_load_obj(
std::string_view{static_cast<char*>(weightsData.get()), weightsSize})
.toGenericDict();
const auto& [constantsData, constantsSize] = reader.getRecord(constantsPath);
auto constants =
torch::jit::pickle_load_obj(
std::string_view{
static_cast<char*>(constantsData.get()), constantsSize})
.toGenericDict();
std::unordered_map<std::string, c10::IValue> stateDict;
std::unordered_map<std::string, c10::IValue> constantsDict;
for (const auto& item : weights) {
stateDict[item.key().toStringRef()] = item.value();
}
for (const auto& item : constants) {
constantsDict[item.key().toStringRef()] = item.value();
}
return std::make_shared<Weights>(&graph, stateDict, constantsDict);
}
} // namespace
ModelRunner::ModelRunner(
const std::string& packagePath,
const std::string& modelName) {
auto pytorchStreamReader =
std::make_shared<caffe2::serialize::PyTorchStreamReader>(
std::make_unique<caffe2::serialize::FileAdapter>(packagePath));
std::string modelFilePath = fmt::format(
torch::_export::archive_spec::MODELS_FILENAME_FORMAT, modelName);
LOG(INFO) << "Loading model from: " << modelFilePath;
TORCH_CHECK(
pytorchStreamReader->hasRecord(modelFilePath),
modelFilePath,
" not found in package");
const auto& [modelData, modelSize] =
pytorchStreamReader->getRecord(modelFilePath);
const std::string modelSerialized{
reinterpret_cast<char*>(modelData.get()), modelSize};
exportedProgram_ = nlohmann::json::parse(modelSerialized)
.template get<torch::_export::ExportedProgram>();
TORCH_CHECK(exportedProgram_.get_graph_module()
.get_module_call_graph()[0]
.get_fqn()
.empty());
graph_ = jsonToGraph(exportedProgram_.get_graph_module());
std::vector<const Value*> userInputs(
graph_->userInputs().begin(), graph_->userInputs().end());
const auto& signatureOpt = exportedProgram_.get_graph_module()
.get_module_call_graph()[0]
.get_signature();
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
const auto& signature = signatureOpt.value();
inputSpec_ = itreeSpecLoads(signature.get_in_spec(), userInputs);
const auto& userOutputs = graph_->userOutputs();
std::vector<const Value*> updatedUserOutput(userOutputs.size(), nullptr);
for (size_t i = 0; i < userOutputs.size(); ++i) {
if (const auto* valuePtr = std::get_if<Value*>(&userOutputs[i])) {
updatedUserOutput[i] = *valuePtr;
}
}
outputSpec_ = itreeSpecLoads(signature.get_out_spec(), updatedUserOutput);
torch::nativert::Placement placement;
graph_->applyDevicePlacement(placement);
selectScalarOverload(graph_.get());
auto weights = loadWeightsDefault(*graph_, *pytorchStreamReader, modelName);
weights->validateAllWeightsLoaded();
torch::nativert::ExecutorConfig config;
executor_ = std::make_unique<Executor>(
config, graph_, std::move(weights), pytorchStreamReader);
}
c10::IValue ModelRunner::run(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
TORCH_CHECK(executor_, "ModelRunner not initialized");
// ModelRunner is only used for inference
c10::InferenceMode mode;
return itreeUnflatten(
executor_->execute(args, kwargs, inputSpec_), outputSpec_);
}
std::vector<c10::IValue> ModelRunner::runWithFlatInputsAndOutputs(
std::vector<c10::IValue> flatInputs) {
TORCH_CHECK(executor_, "ModelRunner not initialized");
// ModelRunner is only used for inference
c10::InferenceMode mode;
return executor_->execute(std::move(flatInputs));
}
} // namespace torch::nativert

View File

@ -0,0 +1,45 @@
#pragma once
#include <fmt/format.h>
#include <c10/macros/Export.h>
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/detail/ITree.h>
#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/executor/Placement.h>
namespace torch::nativert {
class TORCH_API ModelRunner {
public:
ModelRunner(const std::string& packagePath, const std::string& modelName);
ModelRunner(ModelRunner&&) = default;
ModelRunner& operator=(ModelRunner&&) = default;
ModelRunner(const ModelRunner&) = delete;
ModelRunner& operator=(const ModelRunner&) = delete;
~ModelRunner() = default;
c10::IValue run(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);
/**
* A low level API which expects user to always pass in flattened inputs.
* The ownership of the entire input list must be transferred to the
* executor via std::move or in-place construction.
*/
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
std::vector<c10::IValue> flatInputs);
private:
// original non-delegated graph from torch.export()
std::shared_ptr<Graph> graph_;
std::unique_ptr<Executor> executor_;
ITreeSpec inputSpec_;
ITreeSpec outputSpec_;
torch::_export::ExportedProgram exportedProgram_;
};
} // namespace torch::nativert

View File

@ -25,13 +25,26 @@ WeightVersion Weights::globalVersion_ = 0;
Weights::Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict)
stateDict,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
constants)
: graph_(graph),
weightsMeta_(graph->weightsMeta()),
version_(globalVersion_++) {
if (stateDict.has_value()) {
loadStateDict(stateDict.value());
}
if (constants.has_value()) {
for (const auto& [name, value] : constants.value()) {
if (value.isTensor()) {
allValues_[name] = value.toTensor();
} else if (value.isCustomClass()) {
customObjs_[name] = value;
} else {
TORCH_CHECK(false, "Unknown constant type: ", value.tagKind());
}
}
}
}
Weights::Weights(

View File

@ -20,10 +20,12 @@ using WeightVersion = int;
*/
class Weights {
public:
explicit Weights(
Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict = std::nullopt);
stateDict = std::nullopt,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
constants = std::nullopt);
// Arguments
// - pytorchStreamReader: the reader for the model archive

View File

@ -51,7 +51,10 @@ class LayoutPlanner {
kernelSchemas,
const std::vector<bool>& persistentValues,
const torch::nativert::LayoutPlannerSettings& settings);
~LayoutPlanner();
#if !defined(_MSC_VER)
TORCH_API // TODO Doesn't work on msvc.
#endif
~LayoutPlanner();
LayoutPlanner(LayoutPlanner&& other) = delete;
LayoutPlanner(const LayoutPlanner& other) = delete;

View File

@ -0,0 +1,80 @@
#include <unordered_map>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/nativert/ModelRunner.h>
namespace py = pybind11;
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
namespace torch {
namespace nativert {
using torch::nativert::detail::argsToIValue;
void initModelRunnerPybind(py::module& m) {
#if !defined(OVRSOURCE)
shared_ptr_class_<ModelRunner>(m, "PyModelRunner")
.def(
py::init<const std::string&, const std::string&>(),
py::arg("packagePath"),
py::arg("modelName"))
.def(
"run",
[](torch::nativert::ModelRunner& self,
py::args pyargs,
const py::kwargs& pykwargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
std::unordered_map<std::string, c10::IValue> kwargs;
for (const auto& [key, pyarg] : pykwargs) {
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
kwargs[py::str(key)] = std::move(ivalue);
}
c10::IValue ret = self.run(args, kwargs);
return torch::jit::createPyObjectForStack({ret});
})
.def(
"__call__",
[](torch::nativert::ModelRunner& self,
py::args pyargs,
const py::kwargs& pykwargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
std::unordered_map<std::string, c10::IValue> kwargs;
for (const auto& [key, pyarg] : pykwargs) {
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
kwargs[py::str(key)] = std::move(ivalue);
}
c10::IValue ret = self.run(args, kwargs);
return torch::jit::createPyObjectForStack({ret});
})
.def(
"run_with_flat_inputs_and_outputs",
[](torch::nativert::ModelRunner& self, py::args pyargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
auto rets = self.runWithFlatInputsAndOutputs(std::move(args));
return torch::jit::createPyObjectForStack(std::move(rets));
});
#endif // !defined(OVRSOURCE)
}
} // namespace nativert
} // namespace torch

View File

@ -0,0 +1,13 @@
#pragma once
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch {
namespace nativert {
void initModelRunnerPybind(pybind11::module& m);
} // namespace nativert
} // namespace torch