From 8460131087b34191717aa723cb60a19e6bbbf5af Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 29 Jul 2025 21:08:10 +0000 Subject: [PATCH] [nativert] Add OSS version of ModelRunner (#159268) Summary: Implement a ModelRunner from scratch with the minimum features for OSS only Test Plan: test_export -r NativeRT Rollback Plan: Differential Revision: D78979812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159268 Approved by: https://github.com/dolpm --- BUILD.bazel | 1 + buckbuild.bzl | 1 + build_variables.bzl | 2 + test/cpp/nativert/CMakeLists.txt | 2 + test/export/test_export.py | 26 +++ test/export/test_nativert.py | 199 ++++++++++++++++++ torch/csrc/Module.cpp | 4 + torch/csrc/jit/serialization/unpickler.cpp | 26 +++ torch/csrc/jit/serialization/unpickler.h | 1 + torch/nativert/ModelRunner.cpp | 139 ++++++++++++ torch/nativert/ModelRunner.h | 45 ++++ torch/nativert/executor/Weights.cpp | 15 +- torch/nativert/executor/Weights.h | 6 +- .../nativert/executor/memory/LayoutPlanner.h | 5 +- torch/nativert/python/Bindings.cpp | 80 +++++++ torch/nativert/python/Bindings.h | 13 ++ 16 files changed, 561 insertions(+), 4 deletions(-) create mode 100644 test/export/test_nativert.py create mode 100644 torch/nativert/ModelRunner.cpp create mode 100644 torch/nativert/ModelRunner.h create mode 100644 torch/nativert/python/Bindings.cpp create mode 100644 torch/nativert/python/Bindings.h diff --git a/BUILD.bazel b/BUILD.bazel index 5a31eb6558aa..50ffa1257647 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -679,6 +679,7 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", + "torch/nativert/**/*.h", "torch/csrc/distributed/c10d/**/*.hpp", "torch/lib/libshm/*.h", ], diff --git a/buckbuild.bzl b/buckbuild.bzl index 4eb92674ceec..09a515584d97 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -944,6 +944,7 @@ def define_buck_targets( [ ("torch/csrc/api/include", "torch/**/*.h"), ("", "torch/csrc/**/*.h"), + ("", "torch/nativert/**/*.h"), ("", "torch/headeronly/**/*.h"), 
("", "torch/script.h"), ("", "torch/library.h"), diff --git a/build_variables.bzl b/build_variables.bzl index f6fba33dc4d4..6f55b156f8a5 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -593,6 +593,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full) libtorch_nativert_sources = [ + "torch/nativert/ModelRunner.cpp", "torch/nativert/graph/Graph.cpp", "torch/nativert/graph/GraphPasses.cpp", "torch/nativert/graph/GraphSignature.cpp", @@ -986,6 +987,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", "torch/csrc/instruction_counter/Module.cpp", + "torch/nativert/python/Bindings.cpp", ] + lazy_tensor_core_python_sources libtorch_python_distributed_core_sources = [ diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index b6e6cd20ced7..c05416ce0eef 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -5,8 +5,10 @@ file(GLOB_RECURSE NATIVERT_ALL_TEST_FILES "${NATIVERT_TEST_ROOT}/test_*.cpp") # Build the cpp gtest binary containing the cpp-only tests. 
set(NATIVERT_TEST_SRCS ${NATIVERT_ALL_TEST_FILES} + ${TORCH_ROOT}/torch/nativert/ModelRunner.cpp ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp + ${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp diff --git a/test/export/test_export.py b/test/export/test_export.py index a46a8815f93e..18726086edb1 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1543,6 +1543,9 @@ graph(): torch.export.export(M(), (torch.randn(7),), strict=strict) def test_cond_branches_return_constant_int(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class M(torch.nn.Module): def forward(self, x): idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple()) @@ -1588,6 +1591,7 @@ class GraphModule(torch.nn.Module): ) self.assertEqual(m(*args), ep.module()(*args)) + @testing.expectedFailureCppRuntimeNonStrict def test_cond_access_identical_symint_closure(self): class Example2(torch.nn.Module): def forward(self, x, trigger, target): @@ -2278,6 +2282,9 @@ def forward(self, x, y): ep = export(model, inputs) def test_subclasses_parameterization(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -2330,6 +2337,7 @@ graph(): self.assertEqual(res, ref_out) + @testing.expectedFailureCppRuntimeNonStrict def test_subclasses_parameterization_nested(self): class Foo(torch.nn.Module): def __init__(self): @@ -4970,6 +4978,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): ) def test_simple_unbacked_view(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class Foo(torch.nn.Module): def 
forward(self, x): u0 = x.item() @@ -5301,6 +5312,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): ep.module()(torch.randn(6, 3), torch.randn(7, 4)) def test_map(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class Module(torch.nn.Module): def forward(self, xs, y, z): def body(x, y, z): @@ -7947,6 +7961,9 @@ def forward(self, x): self.assertEqual(ref_out, ep.module()(ref_x, mod)) def test_unbacked_noncontig_lin(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -8011,6 +8028,9 @@ def forward(self, x): self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps))) def test_sym_or_sym_and(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + from torch.fx.experimental.symbolic_shapes import sym_and, sym_or class Foo(torch.nn.Module): @@ -8326,6 +8346,9 @@ def forward(self, b_a_buffer, x): @requires_cuda def test_export_associative_scan_lifted_buffers(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + device = torch.device("cuda") combine_mode = "pointwise" @@ -13242,6 +13265,9 @@ def forward(self, x, y): self.assertTrue(placeholders[2].meta["val"].requires_grad) def test_unbacked_expand(self): + if "cpp_runtime_nonstrict" in self.id(): + self.skipTest("TODO Unexpected success in OSS but not in fbcode.") + class Foo(torch.nn.Module): def forward(self, xs): u0, u1, u2 = xs.tolist() diff --git a/test/export/test_nativert.py b/test/export/test_nativert.py new file mode 100644 index 000000000000..044b6051400d --- /dev/null +++ b/test/export/test_nativert.py @@ -0,0 +1,199 @@ +# Owner(s): ["oncall: export"] + + +import copy +import pathlib +import tempfile +import unittest + +import torch +from 
torch._C._nativert import PyModelRunner +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils import _pytree as pytree + + +try: + from . import test_export, testing +except ImportError: + import test_export + import testing + +from torch.export import export + + +test_classes = {} + + +def _use_real_inputs(ep): + ep = copy.copy(ep) + + has_fake_tensor = False + + def _to_real_tensor(t): + if isinstance(t, torch.nn.Parameter): + return torch.nn.Parameter(_to_real_tensor(t.data)) + if isinstance(t, FakeTensor): + nonlocal has_fake_tensor + has_fake_tensor = True + return torch.randn(t.shape, device=t.device, requires_grad=t.requires_grad) + return t + + new_example_inputs = pytree.tree_map_only( + (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.example_inputs + ) + if has_fake_tensor: + ep.example_inputs = new_example_inputs + + ep = ep._update( + ep.graph_module, + ep.graph_signature, + state_dict=pytree.tree_map_only( + (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.state_dict + ), + constants=pytree.tree_map_only( + (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.constants + ), + ) + return ep + + +def _is_supported_types(arg) -> bool: + if isinstance(arg, list): + return ( + all(_is_supported_types(a) for a in arg) + and len({type(a) for a in arg}) <= 1 + ) + elif isinstance(arg, tuple): + return all(_is_supported_types(a) for a in arg) + elif isinstance(arg, dict): + return ( + all(_is_supported_types(a) for a in arg.values()) + and len({type(a) for a in arg.values()}) <= 1 + ) + elif isinstance(arg, (torch.Tensor, int, float, bool, str)): + return True + elif arg is None: + return True + else: + return False + + +def run_with_nativert(ep): + # Downstream tests might mutate the exported program in subtle ways, so + # we need to make a copy here. + ep_infer = copy.deepcopy(ep) + ep_infer = _use_real_inputs(ep_infer.run_decompositions()) + MODEL_NAME = "forward" + + # TODO Does named tempfile have collision? 
+ with tempfile.NamedTemporaryFile(delete=False) as f: + torch.export.pt2_archive._package.package_pt2( + f, exported_programs={MODEL_NAME: ep_infer} + ) + filename = f.name + + try: + ep_args, ep_kwargs = ep_infer.example_inputs + ep_args_copied, ep_kwargs_copied = ( + copy.deepcopy(ep_args), + copy.deepcopy(ep_kwargs), + ) + torch.manual_seed(0) + try: + flat_expected = pytree.tree_leaves( + ep_infer.module()(*ep_args_copied, **ep_kwargs_copied) + ) + except Exception as e: + raise unittest.case.SkipTest(str(e)) from e + + model_runner = PyModelRunner(filename, MODEL_NAME) + torch.manual_seed(0) + if _is_supported_types((ep_args, ep_kwargs)): + results = model_runner.run(*ep_args, **ep_kwargs) + else: + results = model_runner.run_with_flat_inputs_and_outputs( + *pytree.tree_leaves((ep_args, ep_kwargs)) + ) + flat_results = pytree.tree_leaves(results) + assert len(flat_results) == len(flat_expected) + for result, expected in zip(flat_results, flat_expected): + assert type(result) == type(expected) + if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor): + assert result.shape == expected.shape + assert result.dtype == expected.dtype + assert result.device == expected.device + torch.testing.assert_close(result, expected, equal_nan=True) + else: + assert result == expected + except RuntimeError as e: + # Users need to register the pytree type on the C++ side, which + # cannot be tested in a Python unittest. 
+ if "Unknown pytree node type" in str(e): + pass + else: + raise e + finally: + pathlib.Path(filename).unlink(missing_ok=True) + return ep + + +def mocked_nativert_export_strict(*args, **kwargs): + if "strict" in kwargs: + ep = export(*args, **kwargs) + else: + ep = export(*args, **kwargs, strict=True) + + run_with_nativert(ep) + return ep + + +def mocked_nativert_export_nonstrict(*args, **kwargs): + if "strict" in kwargs: + ep = export(*args, **kwargs) + else: + ep = export(*args, **kwargs, strict=False) + + run_with_nativert(ep) + return ep + + +def make_dynamic_cls(cls, strict=False): + cls_prefix = "NativeRT" + + if strict: + test_class = testing.make_test_cls_with_mocked_export( + cls, + cls_prefix, + test_export.CPP_RUNTIME_STRICT_SUFFIX, + mocked_nativert_export_strict, + xfail_prop="_expected_failure_cpp_runtime", + test_only_if_no_xfail=True, + ) + else: + test_class = testing.make_test_cls_with_mocked_export( + cls, + cls_prefix, + test_export.CPP_RUNTIME_NONSTRICT_SUFFIX, + mocked_nativert_export_nonstrict, + xfail_prop="_expected_failure_cpp_runtime_non_strict", + test_only_if_no_xfail=True, + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + + +tests = [ + test_export.TestExport, +] +for test in tests: + make_dynamic_cls(test, strict=True) + make_dynamic_cls(test, strict=False) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index aab2a31402aa..20116b97a481 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -138,6 +138,8 @@ #include #endif +#include + namespace py = pybind11; static PyObject* module; @@ -2780,6 +2782,8 @@ Call this whenever a new thread is created in order to propagate values from #ifdef USE_KINETO torch::global_kineto_init(); #endif + auto nativert_module = 
py_module.def_submodule("_nativert"); + torch::nativert::initModelRunnerPybind(nativert_module); return module; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 747dd8aae55d..0253a5588030 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -669,6 +669,16 @@ void Unpickler::readGlobal( // See [NOTE] skip_next_read_global this->skip_next_read_global--; if (this->skip_next_read_global == 1) { + if (module_name == "torch" && class_name == "Tensor") { + // This is a special case for unpickling a subclassed tensor + // of type torch.nn.Buffer. We rarely ran into this because + // torch.nn.Buffer was introduced later in PyTorch 2, and an IValue of this type + // will not be used in C++. + rebuildTensor(false); + stack_.emplace_back(int64_t(globals_.size() - 1)); + this->skip_next_read_global = 0; + return; + } // Pass through to the correct handler } else if (this->skip_next_read_global == 0) { // Corresponds to the type of `Tensor` being unpickled @@ -773,6 +783,10 @@ void Unpickler::readGlobal( // Unpickle a Tensor with Python attributes or // a Subclassed Tensor. 
rebuildTensorFromTypeV2(); + } else if ( + module_name == "torch._utils" && (class_name == "_rebuild_parameter")) { + // Unpickle a Parameter + rebuildParameter(); } else if ( module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { rebuildSparseTensor(); @@ -1024,6 +1038,18 @@ void Unpickler::rebuildTensorFromTypeV2() { }); } +void Unpickler::rebuildParameter() { + globals_.emplace_back([this] { + auto args = pop(stack_).toTuple(); + size_t tup_idx = 0; + const auto args_elems = args->elements(); + auto result = args_elems.at(tup_idx++).toTensor(); + auto requires_grad = args_elems.at(tup_idx++).toBool(); + result.requires_grad_(requires_grad); + stack_.emplace_back(std::move(result)); + }); +} + #ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index d66cf23f4789..702a1d8816e7 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -137,6 +137,7 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); + void rebuildParameter(); void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED diff --git a/torch/nativert/ModelRunner.cpp b/torch/nativert/ModelRunner.cpp new file mode 100644 index 000000000000..f1c2a35db14c --- /dev/null +++ b/torch/nativert/ModelRunner.cpp @@ -0,0 +1,139 @@ +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::nativert { + +using torch::nativert::jsonToGraph; +using torch::nativert::detail::itreeSpecLoads; + +namespace { +std::shared_ptr loadWeightsDefault( + Graph& graph, + caffe2::serialize::PyTorchStreamReader& reader, + std::string_view modelName) { + auto weightsPath = fmt::format( + "{}{}.pt", torch::_export::archive_spec::WEIGHTS_DIR, modelName); + auto constantsPath = fmt::format( + 
"{}{}.pt", torch::_export::archive_spec::CONSTANTS_DIR, modelName); + TORCH_CHECK( + reader.hasRecord(weightsPath), weightsPath, " not found in package"); + TORCH_CHECK( + reader.hasRecord(constantsPath), constantsPath, " not found in package"); + const auto& [weightsData, weightsSize] = reader.getRecord(weightsPath); + auto weights = + torch::jit::pickle_load_obj( + std::string_view{static_cast(weightsData.get()), weightsSize}) + .toGenericDict(); + const auto& [constantsData, constantsSize] = reader.getRecord(constantsPath); + auto constants = + torch::jit::pickle_load_obj( + std::string_view{ + static_cast(constantsData.get()), constantsSize}) + .toGenericDict(); + std::unordered_map stateDict; + std::unordered_map constantsDict; + for (const auto& item : weights) { + stateDict[item.key().toStringRef()] = item.value(); + } + for (const auto& item : constants) { + constantsDict[item.key().toStringRef()] = item.value(); + } + return std::make_shared(&graph, stateDict, constantsDict); +} +} // namespace + +ModelRunner::ModelRunner( + const std::string& packagePath, + const std::string& modelName) { + auto pytorchStreamReader = + std::make_shared( + std::make_unique(packagePath)); + std::string modelFilePath = fmt::format( + torch::_export::archive_spec::MODELS_FILENAME_FORMAT, modelName); + LOG(INFO) << "Loading model from: " << modelFilePath; + + TORCH_CHECK( + pytorchStreamReader->hasRecord(modelFilePath), + modelFilePath, + " not found in package"); + const auto& [modelData, modelSize] = + pytorchStreamReader->getRecord(modelFilePath); + const std::string modelSerialized{ + reinterpret_cast(modelData.get()), modelSize}; + + exportedProgram_ = nlohmann::json::parse(modelSerialized) + .template get(); + + TORCH_CHECK(exportedProgram_.get_graph_module() + .get_module_call_graph()[0] + .get_fqn() + .empty()); + + graph_ = jsonToGraph(exportedProgram_.get_graph_module()); + + std::vector userInputs( + graph_->userInputs().begin(), graph_->userInputs().end()); + const 
auto& signatureOpt = exportedProgram_.get_graph_module() + .get_module_call_graph()[0] + .get_signature(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + const auto& signature = signatureOpt.value(); + inputSpec_ = itreeSpecLoads(signature.get_in_spec(), userInputs); + + const auto& userOutputs = graph_->userOutputs(); + std::vector updatedUserOutput(userOutputs.size(), nullptr); + for (size_t i = 0; i < userOutputs.size(); ++i) { + if (const auto* valuePtr = std::get_if(&userOutputs[i])) { + updatedUserOutput[i] = *valuePtr; + } + } + outputSpec_ = itreeSpecLoads(signature.get_out_spec(), updatedUserOutput); + + torch::nativert::Placement placement; + + graph_->applyDevicePlacement(placement); + selectScalarOverload(graph_.get()); + + auto weights = loadWeightsDefault(*graph_, *pytorchStreamReader, modelName); + + weights->validateAllWeightsLoaded(); + + torch::nativert::ExecutorConfig config; + + executor_ = std::make_unique( + config, graph_, std::move(weights), pytorchStreamReader); +} + +c10::IValue ModelRunner::run( + const std::vector& args, + const std::unordered_map& kwargs) { + TORCH_CHECK(executor_, "ModelRunner not initialized"); + + // ModelRunner is only used for inference + c10::InferenceMode mode; + + return itreeUnflatten( + executor_->execute(args, kwargs, inputSpec_), outputSpec_); +} + +std::vector ModelRunner::runWithFlatInputsAndOutputs( + std::vector flatInputs) { + TORCH_CHECK(executor_, "ModelRunner not initialized"); + + // ModelRunner is only used for inference + c10::InferenceMode mode; + + return executor_->execute(std::move(flatInputs)); +} + +} // namespace torch::nativert diff --git a/torch/nativert/ModelRunner.h b/torch/nativert/ModelRunner.h new file mode 100644 index 000000000000..4c8875731885 --- /dev/null +++ b/torch/nativert/ModelRunner.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace torch::nativert { +class TORCH_API ModelRunner { + public: + 
ModelRunner(const std::string& packagePath, const std::string& modelName); + + ModelRunner(ModelRunner&&) = default; + ModelRunner& operator=(ModelRunner&&) = default; + ModelRunner(const ModelRunner&) = delete; + ModelRunner& operator=(const ModelRunner&) = delete; + ~ModelRunner() = default; + + c10::IValue run( + const std::vector& args, + const std::unordered_map& kwargs); + + /** + * A low level API which expects user to always pass in flattened inputs. + * The ownership of the entire input list must be transferred to the + * executor via std::move or in-place construction. + */ + std::vector runWithFlatInputsAndOutputs( + std::vector flatInputs); + + private: + // original non-delegated graph from torch.export() + std::shared_ptr graph_; + + std::unique_ptr executor_; + + ITreeSpec inputSpec_; + ITreeSpec outputSpec_; + + torch::_export::ExportedProgram exportedProgram_; +}; +} // namespace torch::nativert diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp index 918b532160c1..d685cc1a7816 100644 --- a/torch/nativert/executor/Weights.cpp +++ b/torch/nativert/executor/Weights.cpp @@ -25,13 +25,26 @@ WeightVersion Weights::globalVersion_ = 0; Weights::Weights( const Graph* graph, const std::optional>& - stateDict) + stateDict, + const std::optional>& + constants) : graph_(graph), weightsMeta_(graph->weightsMeta()), version_(globalVersion_++) { if (stateDict.has_value()) { loadStateDict(stateDict.value()); } + if (constants.has_value()) { + for (const auto& [name, value] : constants.value()) { + if (value.isTensor()) { + allValues_[name] = value.toTensor(); + } else if (value.isCustomClass()) { + customObjs_[name] = value; + } else { + TORCH_CHECK(false, "Unknown constant type: ", value.tagKind()); + } + } + } } Weights::Weights( diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h index 7108f32bba9a..e3c1469c0d5c 100644 --- a/torch/nativert/executor/Weights.h +++ 
b/torch/nativert/executor/Weights.h @@ -20,10 +20,12 @@ using WeightVersion = int; */ class Weights { public: - explicit Weights( + Weights( const Graph* graph, const std::optional>& - stateDict = std::nullopt); + stateDict = std::nullopt, + const std::optional>& + constants = std::nullopt); // Arguments // - pytorchStreamReader: the reader for the model archive diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h index 76d9ae2a2d54..c3620aec0021 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.h +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -51,7 +51,10 @@ class LayoutPlanner { kernelSchemas, const std::vector& persistentValues, const torch::nativert::LayoutPlannerSettings& settings); - ~LayoutPlanner(); +#if !defined(_MSC_VER) + TORCH_API // TODO Doesn't work on msvc. +#endif + ~LayoutPlanner(); LayoutPlanner(LayoutPlanner&& other) = delete; LayoutPlanner(const LayoutPlanner& other) = delete; diff --git a/torch/nativert/python/Bindings.cpp b/torch/nativert/python/Bindings.cpp new file mode 100644 index 000000000000..77939702a2d9 --- /dev/null +++ b/torch/nativert/python/Bindings.cpp @@ -0,0 +1,80 @@ +#include + +#include +#include + +#include + +namespace py = pybind11; + +template +using shared_ptr_class_ = py::class_>; + +namespace torch { +namespace nativert { + +using torch::nativert::detail::argsToIValue; + +void initModelRunnerPybind(py::module& m) { +#if !defined(OVRSOURCE) + shared_ptr_class_(m, "PyModelRunner") + .def( + py::init(), + py::arg("packagePath"), + py::arg("modelName")) + .def( + "run", + [](torch::nativert::ModelRunner& self, + py::args pyargs, + const py::kwargs& pykwargs) { + std::vector args; + for (const auto i : c10::irange(pyargs.size())) { + auto ivalue = + torch::jit::toIValue(pyargs[i], c10::AnyType::get()); + args.push_back(std::move(ivalue)); + } + std::unordered_map kwargs; + for (const auto& [key, pyarg] : pykwargs) { + auto ivalue = 
torch::jit::toIValue(pyarg, c10::AnyType::get()); + kwargs[py::str(key)] = std::move(ivalue); + } + c10::IValue ret = self.run(args, kwargs); + return torch::jit::createPyObjectForStack({ret}); + }) + .def( + "__call__", + [](torch::nativert::ModelRunner& self, + py::args pyargs, + const py::kwargs& pykwargs) { + std::vector args; + for (const auto i : c10::irange(pyargs.size())) { + auto ivalue = + torch::jit::toIValue(pyargs[i], c10::AnyType::get()); + args.push_back(std::move(ivalue)); + } + std::unordered_map kwargs; + for (const auto& [key, pyarg] : pykwargs) { + auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get()); + kwargs[py::str(key)] = std::move(ivalue); + } + c10::IValue ret = self.run(args, kwargs); + return torch::jit::createPyObjectForStack({ret}); + }) + .def( + "run_with_flat_inputs_and_outputs", + [](torch::nativert::ModelRunner& self, py::args pyargs) { + std::vector args; + for (const auto i : c10::irange(pyargs.size())) { + auto ivalue = + torch::jit::toIValue(pyargs[i], c10::AnyType::get()); + args.push_back(std::move(ivalue)); + } + + auto rets = self.runWithFlatInputsAndOutputs(std::move(args)); + return torch::jit::createPyObjectForStack(std::move(rets)); + }); +#endif // !defined(OVRSOURCE) +} + +} // namespace nativert +} // namespace torch diff --git a/torch/nativert/python/Bindings.h b/torch/nativert/python/Bindings.h new file mode 100644 index 000000000000..1821d211e27d --- /dev/null +++ b/torch/nativert/python/Bindings.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace py = pybind11; + +namespace torch { +namespace nativert { + +void initModelRunnerPybind(pybind11::module& m); + +} // namespace nativert +} // namespace torch