Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00.
[nativert] Add OSS version of ModelRunner (#159268)
Summary: Implement a ModelRunner from scratch with the minimum features for OSS only.
Test Plan: test_export -r NativeRT
Rollback Plan:
Differential Revision: D78979812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159268
Approved by: https://github.com/dolpm
Committed by: PyTorch MergeBot
Parent: c0c24b61ff
Commit: 8460131087
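For orientation, here is the round trip this commit enables, sketched in Python following the flow of test/export/test_nativert.py (added below). The Add module, tensor shapes, and the model name "forward" are illustrative only:

    import tempfile

    import torch
    from torch._C._nativert import PyModelRunner


    class Add(torch.nn.Module):  # toy model, for illustration
        def forward(self, x, y):
            return x + y


    ep = torch.export.export(Add(), (torch.randn(4), torch.randn(4)))
    with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
        # Package the exported program into a PT2 archive under the
        # model name "forward".
        torch.export.pt2_archive._package.package_pt2(
            f, exported_programs={"forward": ep}
        )
        path = f.name

    # Load the archive into the C++ runtime and execute it.
    runner = PyModelRunner(path, "forward")
    out = runner.run(torch.randn(4), torch.randn(4))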
@@ -679,6 +679,7 @@ cc_library(
        [
            "torch/*.h",
            "torch/csrc/**/*.h",
            "torch/nativert/**/*.h",
            "torch/csrc/distributed/c10d/**/*.hpp",
            "torch/lib/libshm/*.h",
        ],
@@ -944,6 +944,7 @@ def define_buck_targets(
        [
            ("torch/csrc/api/include", "torch/**/*.h"),
            ("", "torch/csrc/**/*.h"),
            ("", "torch/nativert/**/*.h"),
            ("", "torch/headeronly/**/*.h"),
            ("", "torch/script.h"),
            ("", "torch/library.h"),
@@ -593,6 +593,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)

libtorch_nativert_sources = [
    "torch/nativert/ModelRunner.cpp",
    "torch/nativert/graph/Graph.cpp",
    "torch/nativert/graph/GraphPasses.cpp",
    "torch/nativert/graph/GraphSignature.cpp",
@@ -986,6 +987,7 @@ libtorch_python_core_sources = [
    "torch/csrc/utils/verbose.cpp",
    "torch/csrc/cpu/Module.cpp",
    "torch/csrc/instruction_counter/Module.cpp",
    "torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources

libtorch_python_distributed_core_sources = [
@@ -5,8 +5,10 @@ file(GLOB_RECURSE NATIVERT_ALL_TEST_FILES "${NATIVERT_TEST_ROOT}/test_*.cpp")
# Build the cpp gtest binary containing the cpp-only tests.
set(NATIVERT_TEST_SRCS
  ${NATIVERT_ALL_TEST_FILES}
  ${TORCH_ROOT}/torch/nativert/ModelRunner.cpp
  ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
  ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
  ${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
  ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
  ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
  ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp
@@ -1543,6 +1543,9 @@ graph():
            torch.export.export(M(), (torch.randn(7),), strict=strict)

    def test_cond_branches_return_constant_int(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class M(torch.nn.Module):
            def forward(self, x):
                idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
@@ -1588,6 +1591,7 @@ class GraphModule(torch.nn.Module):
        )
        self.assertEqual(m(*args), ep.module()(*args))

    @testing.expectedFailureCppRuntimeNonStrict
    def test_cond_access_identical_symint_closure(self):
        class Example2(torch.nn.Module):
            def forward(self, x, trigger, target):
@@ -2278,6 +2282,9 @@ def forward(self, x, y):
        ep = export(model, inputs)

    def test_subclasses_parameterization(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -2330,6 +2337,7 @@ graph():

        self.assertEqual(res, ref_out)

    @testing.expectedFailureCppRuntimeNonStrict
    def test_subclasses_parameterization_nested(self):
        class Foo(torch.nn.Module):
            def __init__(self):
@@ -4970,6 +4978,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        )

    def test_simple_unbacked_view(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class Foo(torch.nn.Module):
            def forward(self, x):
                u0 = x.item()
@@ -5301,6 +5312,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        ep.module()(torch.randn(6, 3), torch.randn(7, 4))

    def test_map(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class Module(torch.nn.Module):
            def forward(self, xs, y, z):
                def body(x, y, z):
@@ -7947,6 +7961,9 @@ def forward(self, x):
        self.assertEqual(ref_out, ep.module()(ref_x, mod))

    def test_unbacked_noncontig_lin(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -8011,6 +8028,9 @@ def forward(self, x):
        self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))

    def test_sym_or_sym_and(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        from torch.fx.experimental.symbolic_shapes import sym_and, sym_or

        class Foo(torch.nn.Module):
@@ -8326,6 +8346,9 @@ def forward(self, b_a_buffer, x):

    @requires_cuda
    def test_export_associative_scan_lifted_buffers(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        device = torch.device("cuda")
        combine_mode = "pointwise"
@@ -13242,6 +13265,9 @@ def forward(self, x, y):
        self.assertTrue(placeholders[2].meta["val"].requires_grad)

    def test_unbacked_expand(self):
        if "cpp_runtime_nonstrict" in self.id():
            self.skipTest("TODO Unexpected success in OSS but not in fbcode.")

        class Foo(torch.nn.Module):
            def forward(self, xs):
                u0, u1, u2 = xs.tolist()
test/export/test_nativert.py (new file, 199 lines)
@@ -0,0 +1,199 @@
# Owner(s): ["oncall: export"]


import copy
import pathlib
import tempfile
import unittest

import torch
from torch._C._nativert import PyModelRunner
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils import _pytree as pytree


try:
    from . import test_export, testing
except ImportError:
    import test_export
    import testing

from torch.export import export


test_classes = {}


def _use_real_inputs(ep):
    ep = copy.copy(ep)

    has_fake_tensor = False

    def _to_real_tensor(t):
        if isinstance(t, torch.nn.Parameter):
            return torch.nn.Parameter(_to_real_tensor(t.data))
        if isinstance(t, FakeTensor):
            nonlocal has_fake_tensor
            has_fake_tensor = True
            return torch.randn(t.shape, device=t.device, requires_grad=t.requires_grad)
        return t

    new_example_inputs = pytree.tree_map_only(
        (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.example_inputs
    )
    if has_fake_tensor:
        ep.example_inputs = new_example_inputs

        ep = ep._update(
            ep.graph_module,
            ep.graph_signature,
            state_dict=pytree.tree_map_only(
                (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.state_dict
            ),
            constants=pytree.tree_map_only(
                (torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.constants
            ),
        )
    return ep


def _is_supported_types(arg) -> bool:
    if isinstance(arg, list):
        return (
            all(_is_supported_types(a) for a in arg)
            and len({type(a) for a in arg}) <= 1
        )
    elif isinstance(arg, tuple):
        return all(_is_supported_types(a) for a in arg)
    elif isinstance(arg, dict):
        return (
            all(_is_supported_types(a) for a in arg.values())
            and len({type(a) for a in arg.values()}) <= 1
        )
    elif isinstance(arg, (torch.Tensor, int, float, bool, str)):
        return True
    elif arg is None:
        return True
    else:
        return False


def run_with_nativert(ep):
    # Downstream tests might mutate the exported program in subtle ways, so
    # we need to make a copy here.
    ep_infer = copy.deepcopy(ep)
    ep_infer = _use_real_inputs(ep_infer.run_decompositions())
    MODEL_NAME = "forward"

    # TODO Does named tempfile have collision?
    with tempfile.NamedTemporaryFile(delete=False) as f:
        torch.export.pt2_archive._package.package_pt2(
            f, exported_programs={MODEL_NAME: ep_infer}
        )
        filename = f.name

    try:
        ep_args, ep_kwargs = ep_infer.example_inputs
        ep_args_copied, ep_kwargs_copied = (
            copy.deepcopy(ep_args),
            copy.deepcopy(ep_kwargs),
        )
        torch.manual_seed(0)
        try:
            flat_expected = pytree.tree_leaves(
                ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
            )
        except Exception as e:
            raise unittest.case.SkipTest(str(e)) from e

        model_runner = PyModelRunner(filename, MODEL_NAME)
        torch.manual_seed(0)
        if _is_supported_types((ep_args, ep_kwargs)):
            results = model_runner.run(*ep_args, **ep_kwargs)
        else:
            results = model_runner.run_with_flat_inputs_and_outputs(
                *pytree.tree_leaves((ep_args, ep_kwargs))
            )
        flat_results = pytree.tree_leaves(results)
        assert len(flat_results) == len(flat_expected)
        for result, expected in zip(flat_results, flat_expected):
            assert type(result) == type(expected)
            if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
                assert result.shape == expected.shape
                assert result.dtype == expected.dtype
                assert result.device == expected.device
                torch.testing.assert_close(result, expected, equal_nan=True)
            else:
                assert result == expected
    except RuntimeError as e:
        # Users need to register the pytree type on the C++ side, which
        # cannot be tested in a Python unittest.
        if "Unknown pytree node type" in str(e):
            pass
        else:
            raise e
    finally:
        pathlib.Path(filename).unlink(missing_ok=True)
    return ep


def mocked_nativert_export_strict(*args, **kwargs):
    if "strict" in kwargs:
        ep = export(*args, **kwargs)
    else:
        ep = export(*args, **kwargs, strict=True)

    run_with_nativert(ep)
    return ep


def mocked_nativert_export_nonstrict(*args, **kwargs):
    if "strict" in kwargs:
        ep = export(*args, **kwargs)
    else:
        ep = export(*args, **kwargs, strict=False)

    run_with_nativert(ep)
    return ep


def make_dynamic_cls(cls, strict=False):
    cls_prefix = "NativeRT"

    if strict:
        test_class = testing.make_test_cls_with_mocked_export(
            cls,
            cls_prefix,
            test_export.CPP_RUNTIME_STRICT_SUFFIX,
            mocked_nativert_export_strict,
            xfail_prop="_expected_failure_cpp_runtime",
            test_only_if_no_xfail=True,
        )
    else:
        test_class = testing.make_test_cls_with_mocked_export(
            cls,
            cls_prefix,
            test_export.CPP_RUNTIME_NONSTRICT_SUFFIX,
            mocked_nativert_export_nonstrict,
            xfail_prop="_expected_failure_cpp_runtime_non_strict",
            test_only_if_no_xfail=True,
        )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__


tests = [
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test, strict=True)
    make_dynamic_cls(test, strict=False)
del test

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
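Note how the harness above only calls model_runner.run() directly when every input is a tensor, scalar, or a homogeneous list/tuple/dict of those; anything else is flattened to leaves and routed through run_with_flat_inputs_and_outputs. A small illustration of the flattening this relies on, using standard torch.utils._pytree behavior (the values are illustrative):

    import torch
    from torch.utils import _pytree as pytree

    args = ((torch.randn(2), {"scale": 3.0}), {})
    # tree_leaves strips the container structure, preserving traversal
    # order: [tensor, 3.0]
    leaves = pytree.tree_leaves(args)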
@@ -138,6 +138,8 @@
#include <torch/csrc/itt.h>
#endif

#include <torch/nativert/python/Bindings.h>

namespace py = pybind11;

static PyObject* module;
@@ -2780,6 +2782,8 @@ Call this whenever a new thread is created in order to propagate values from
#ifdef USE_KINETO
  torch::global_kineto_init();
#endif
  auto nativert_module = py_module.def_submodule("_nativert");
  torch::nativert::initModelRunnerPybind(nativert_module);
  return module;
  END_HANDLE_TH_ERRORS
}
@@ -669,6 +669,16 @@ void Unpickler::readGlobal(
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      if (module_name == "torch" && class_name == "Tensor") {
        // This is a special case when we are unpickling a subclassed tensor
        // with type torch.nn.Buffer. We didn't frequently run into this because
        // torch.nn.Buffer is introduced later in PyTorch 2 and this type IValue
        // will not be used in C++.
        rebuildTensor(false);
        stack_.emplace_back(int64_t(globals_.size() - 1));
        this->skip_next_read_global = 0;
        return;
      }
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
@@ -773,6 +783,10 @@ void Unpickler::readGlobal(
    // Unpickle a Tensor with Python attributes or
    // a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" && (class_name == "_rebuild_parameter")) {
    // Unpickle a Parameter
    rebuildParameter();
  } else if (
      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
@@ -1024,6 +1038,18 @@ void Unpickler::rebuildTensorFromTypeV2() {
  });
}

void Unpickler::rebuildParameter() {
  globals_.emplace_back([this] {
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto result = args_elems.at(tup_idx++).toTensor();
    auto requires_grad = args_elems.at(tup_idx++).toBool();
    result.requires_grad_(requires_grad);
    stack_.emplace_back(std::move(result));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
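The new rebuildParameter() mirrors the Python-side reduce function: pickled Parameters are encoded as a call to torch._utils._rebuild_parameter(data, requires_grad, backward_hooks), and the C++ handler above consumes the first two arguments. A sketch of the Python counterpart it has to match:

    import torch
    from torch._utils import _rebuild_parameter

    # This is the call the pickle stream encodes for a torch.nn.Parameter.
    p = _rebuild_parameter(torch.randn(2), True, {})
    assert isinstance(p, torch.nn.Parameter) and p.requires_grad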
@@ -137,6 +137,7 @@ class TORCH_API Unpickler {
      const std::string& module_name,
      const std::string& class_name);
  void rebuildTensor(bool quantized);
  void rebuildParameter();
  void rebuildTensorFromTypeV2();
  void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED
torch/nativert/ModelRunner.cpp (new file, 139 lines)
@@ -0,0 +1,139 @@
#include <torch/nativert/ModelRunner.h>

#include <variant>

#include <nlohmann/json.hpp>

#include <caffe2/serialize/file_adapter.h>
#include <torch/csrc/export/pt2_archive_constants.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/graph/GraphPasses.h>
#include <torch/nativert/graph/Serialization.h>

namespace torch::nativert {

using torch::nativert::jsonToGraph;
using torch::nativert::detail::itreeSpecLoads;

namespace {
std::shared_ptr<Weights> loadWeightsDefault(
    Graph& graph,
    caffe2::serialize::PyTorchStreamReader& reader,
    std::string_view modelName) {
  auto weightsPath = fmt::format(
      "{}{}.pt", torch::_export::archive_spec::WEIGHTS_DIR, modelName);
  auto constantsPath = fmt::format(
      "{}{}.pt", torch::_export::archive_spec::CONSTANTS_DIR, modelName);
  TORCH_CHECK(
      reader.hasRecord(weightsPath), weightsPath, " not found in package");
  TORCH_CHECK(
      reader.hasRecord(constantsPath), constantsPath, " not found in package");
  const auto& [weightsData, weightsSize] = reader.getRecord(weightsPath);
  auto weights =
      torch::jit::pickle_load_obj(
          std::string_view{static_cast<char*>(weightsData.get()), weightsSize})
          .toGenericDict();
  const auto& [constantsData, constantsSize] = reader.getRecord(constantsPath);
  auto constants =
      torch::jit::pickle_load_obj(
          std::string_view{
              static_cast<char*>(constantsData.get()), constantsSize})
          .toGenericDict();
  std::unordered_map<std::string, c10::IValue> stateDict;
  std::unordered_map<std::string, c10::IValue> constantsDict;
  for (const auto& item : weights) {
    stateDict[item.key().toStringRef()] = item.value();
  }
  for (const auto& item : constants) {
    constantsDict[item.key().toStringRef()] = item.value();
  }
  return std::make_shared<Weights>(&graph, stateDict, constantsDict);
}
} // namespace

ModelRunner::ModelRunner(
    const std::string& packagePath,
    const std::string& modelName) {
  auto pytorchStreamReader =
      std::make_shared<caffe2::serialize::PyTorchStreamReader>(
          std::make_unique<caffe2::serialize::FileAdapter>(packagePath));
  std::string modelFilePath = fmt::format(
      torch::_export::archive_spec::MODELS_FILENAME_FORMAT, modelName);
  LOG(INFO) << "Loading model from: " << modelFilePath;

  TORCH_CHECK(
      pytorchStreamReader->hasRecord(modelFilePath),
      modelFilePath,
      " not found in package");
  const auto& [modelData, modelSize] =
      pytorchStreamReader->getRecord(modelFilePath);
  const std::string modelSerialized{
      reinterpret_cast<char*>(modelData.get()), modelSize};

  exportedProgram_ = nlohmann::json::parse(modelSerialized)
                         .template get<torch::_export::ExportedProgram>();

  TORCH_CHECK(exportedProgram_.get_graph_module()
                  .get_module_call_graph()[0]
                  .get_fqn()
                  .empty());

  graph_ = jsonToGraph(exportedProgram_.get_graph_module());

  std::vector<const Value*> userInputs(
      graph_->userInputs().begin(), graph_->userInputs().end());
  const auto& signatureOpt = exportedProgram_.get_graph_module()
                                 .get_module_call_graph()[0]
                                 .get_signature();
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  const auto& signature = signatureOpt.value();
  inputSpec_ = itreeSpecLoads(signature.get_in_spec(), userInputs);

  const auto& userOutputs = graph_->userOutputs();
  std::vector<const Value*> updatedUserOutput(userOutputs.size(), nullptr);
  for (size_t i = 0; i < userOutputs.size(); ++i) {
    if (const auto* valuePtr = std::get_if<Value*>(&userOutputs[i])) {
      updatedUserOutput[i] = *valuePtr;
    }
  }
  outputSpec_ = itreeSpecLoads(signature.get_out_spec(), updatedUserOutput);

  torch::nativert::Placement placement;

  graph_->applyDevicePlacement(placement);
  selectScalarOverload(graph_.get());

  auto weights = loadWeightsDefault(*graph_, *pytorchStreamReader, modelName);

  weights->validateAllWeightsLoaded();

  torch::nativert::ExecutorConfig config;

  executor_ = std::make_unique<Executor>(
      config, graph_, std::move(weights), pytorchStreamReader);
}

c10::IValue ModelRunner::run(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  TORCH_CHECK(executor_, "ModelRunner not initialized");

  // ModelRunner is only used for inference
  c10::InferenceMode mode;

  return itreeUnflatten(
      executor_->execute(args, kwargs, inputSpec_), outputSpec_);
}

std::vector<c10::IValue> ModelRunner::runWithFlatInputsAndOutputs(
    std::vector<c10::IValue> flatInputs) {
  TORCH_CHECK(executor_, "ModelRunner not initialized");

  // ModelRunner is only used for inference
  c10::InferenceMode mode;

  return executor_->execute(std::move(flatInputs));
}

} // namespace torch::nativert
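ModelRunner resolves its records from the pt2_archive_constants formats (MODELS_FILENAME_FORMAT, WEIGHTS_DIR, CONSTANTS_DIR). Since a PT2 archive is a zip file, a quick way to see what a package actually contains (the `path` variable continues the earlier sketch; the record names shown are expectations, not guarantees):

    import zipfile

    with zipfile.ZipFile(path) as z:
        for name in z.namelist():
            # Expect entries along the lines of .../models/forward.json,
            # .../data/weights/forward.pt and .../data/constants/forward.pt;
            # the exact prefixes come from pt2_archive_constants.
            print(name)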
torch/nativert/ModelRunner.h (new file, 45 lines)
@@ -0,0 +1,45 @@
#pragma once

#include <fmt/format.h>

#include <c10/macros/Export.h>
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/detail/ITree.h>
#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/executor/Placement.h>

namespace torch::nativert {
class TORCH_API ModelRunner {
 public:
  ModelRunner(const std::string& packagePath, const std::string& modelName);

  ModelRunner(ModelRunner&&) = default;
  ModelRunner& operator=(ModelRunner&&) = default;
  ModelRunner(const ModelRunner&) = delete;
  ModelRunner& operator=(const ModelRunner&) = delete;
  ~ModelRunner() = default;

  c10::IValue run(
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs);

  /**
   * A low-level API which expects the user to always pass in flattened
   * inputs. Ownership of the entire input list must be transferred to
   * the executor via std::move or in-place construction.
   */
  std::vector<c10::IValue> runWithFlatInputsAndOutputs(
      std::vector<c10::IValue> flatInputs);

 private:
  // original non-delegated graph from torch.export()
  std::shared_ptr<Graph> graph_;

  std::unique_ptr<Executor> executor_;

  ITreeSpec inputSpec_;
  ITreeSpec outputSpec_;

  torch::_export::ExportedProgram exportedProgram_;
};
} // namespace torch::nativert
@@ -25,13 +25,26 @@ WeightVersion Weights::globalVersion_ = 0;
 Weights::Weights(
     const Graph* graph,
     const std::optional<std::unordered_map<std::string, c10::IValue>>&
-        stateDict)
+        stateDict,
+    const std::optional<std::unordered_map<std::string, c10::IValue>>&
+        constants)
     : graph_(graph),
       weightsMeta_(graph->weightsMeta()),
       version_(globalVersion_++) {
   if (stateDict.has_value()) {
     loadStateDict(stateDict.value());
   }
+  if (constants.has_value()) {
+    for (const auto& [name, value] : constants.value()) {
+      if (value.isTensor()) {
+        allValues_[name] = value.toTensor();
+      } else if (value.isCustomClass()) {
+        customObjs_[name] = value;
+      } else {
+        TORCH_CHECK(false, "Unknown constant type: ", value.tagKind());
+      }
+    }
+  }
 }

 Weights::Weights(
@@ -20,10 +20,12 @@ using WeightVersion = int;
  */
 class Weights {
  public:
-  explicit Weights(
+  Weights(
       const Graph* graph,
       const std::optional<std::unordered_map<std::string, c10::IValue>>&
-          stateDict = std::nullopt);
+          stateDict = std::nullopt,
+      const std::optional<std::unordered_map<std::string, c10::IValue>>&
+          constants = std::nullopt);

   // Arguments
   // - pytorchStreamReader: the reader for the model archive
@@ -51,7 +51,10 @@ class LayoutPlanner {
         kernelSchemas,
     const std::vector<bool>& persistentValues,
     const torch::nativert::LayoutPlannerSettings& settings);
-  ~LayoutPlanner();
+#if !defined(_MSC_VER)
+  TORCH_API // TODO Doesn't work on msvc.
+#endif
+  ~LayoutPlanner();

   LayoutPlanner(LayoutPlanner&& other) = delete;
   LayoutPlanner(const LayoutPlanner& other) = delete;
torch/nativert/python/Bindings.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include <unordered_map>

#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/nativert/ModelRunner.h>

namespace py = pybind11;

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

namespace torch {
namespace nativert {

using torch::nativert::detail::argsToIValue;

void initModelRunnerPybind(py::module& m) {
#if !defined(OVRSOURCE)
  shared_ptr_class_<ModelRunner>(m, "PyModelRunner")
      .def(
          py::init<const std::string&, const std::string&>(),
          py::arg("packagePath"),
          py::arg("modelName"))
      .def(
          "run",
          [](torch::nativert::ModelRunner& self,
             py::args pyargs,
             const py::kwargs& pykwargs) {
            std::vector<c10::IValue> args;
            for (const auto i : c10::irange(pyargs.size())) {
              auto ivalue =
                  torch::jit::toIValue(pyargs[i], c10::AnyType::get());
              args.push_back(std::move(ivalue));
            }
            std::unordered_map<std::string, c10::IValue> kwargs;
            for (const auto& [key, pyarg] : pykwargs) {
              auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
              kwargs[py::str(key)] = std::move(ivalue);
            }
            c10::IValue ret = self.run(args, kwargs);
            return torch::jit::createPyObjectForStack({ret});
          })
      .def(
          "__call__",
          [](torch::nativert::ModelRunner& self,
             py::args pyargs,
             const py::kwargs& pykwargs) {
            std::vector<c10::IValue> args;
            for (const auto i : c10::irange(pyargs.size())) {
              auto ivalue =
                  torch::jit::toIValue(pyargs[i], c10::AnyType::get());
              args.push_back(std::move(ivalue));
            }
            std::unordered_map<std::string, c10::IValue> kwargs;
            for (const auto& [key, pyarg] : pykwargs) {
              auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
              kwargs[py::str(key)] = std::move(ivalue);
            }
            c10::IValue ret = self.run(args, kwargs);
            return torch::jit::createPyObjectForStack({ret});
          })
      .def(
          "run_with_flat_inputs_and_outputs",
          [](torch::nativert::ModelRunner& self, py::args pyargs) {
            std::vector<c10::IValue> args;
            for (const auto i : c10::irange(pyargs.size())) {
              auto ivalue =
                  torch::jit::toIValue(pyargs[i], c10::AnyType::get());
              args.push_back(std::move(ivalue));
            }

            auto rets = self.runWithFlatInputsAndOutputs(std::move(args));
            return torch::jit::createPyObjectForStack(std::move(rets));
          });
#endif // !defined(OVRSOURCE)
}

} // namespace nativert
} // namespace torch
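Both run and __call__ convert arguments with torch::jit::toIValue against AnyType, so only values that conversion understands cross the boundary here; anything else (e.g. custom pytree-registered containers) has to go through run_with_flat_inputs_and_outputs. Continuing the earlier sketch (runner and shapes are illustrative):

    import torch

    out1 = runner.run(torch.randn(4), torch.randn(4))
    # __call__ shares the same argument-conversion path as run.
    out2 = runner(torch.randn(4), torch.randn(4))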
torch/nativert/python/Bindings.h (new file, 13 lines)
@@ -0,0 +1,13 @@
#pragma once

#include <torch/csrc/utils/pybind.h>

namespace py = pybind11;

namespace torch {
namespace nativert {

void initModelRunnerPybind(pybind11::module& m);

} // namespace nativert
} // namespace torch