Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-14 22:25:03 +08:00
Compare commits: 1 commit (290a4b4fea), comparing ciflow/tru... with zhxchen17/...
@@ -587,6 +587,44 @@ jit_sources_full = [

libtorch_core_jit_sources = sorted(jit_sources_full)

libtorch_runtime_sources = [
    "torch/csrc/nativert/common/ConfigUtils.cpp",
    "torch/csrc/nativert/common/Conv.cpp",
    "torch/csrc/nativert/common/FileUtil.cpp",
    "torch/csrc/nativert/common/Pytree.cpp",
    "torch/csrc/nativert/common/String.cpp",
    "torch/csrc/nativert/executor/AOTInductorModelImpl.cpp",
    "torch/csrc/nativert/executor/SerialGraphExecutor.cpp",
    "torch/csrc/nativert/executor/AOTIDelegateExecutor.cpp",
    "torch/csrc/nativert/executor/Executor.cpp",
    "torch/csrc/nativert/executor/GraphExecutorBase.cpp",
    "torch/csrc/nativert/executor/ConstantFolder.cpp",
    "torch/csrc/nativert/executor/DelegateExecutor.cpp",
    "torch/csrc/nativert/executor/ExecutionFrame.cpp",
    "torch/csrc/nativert/executor/ExecutionPlanner.cpp",
    "torch/csrc/nativert/executor/ModelRunnerBase.cpp",
    "torch/csrc/nativert/executor/OpKernel.cpp",
    "torch/csrc/nativert/executor/ParallelGraphExecutor.cpp",
    "torch/csrc/nativert/executor/Placement.cpp",
    "torch/csrc/nativert/executor/Weights.cpp",
    "torch/csrc/nativert/graph/Graph.cpp",
    "torch/csrc/nativert/graph/GraphPasses.cpp",
    "torch/csrc/nativert/graph/GraphSignature.cpp",
    "torch/csrc/nativert/graph/Serialization.cpp",
    "torch/csrc/nativert/graph/TensorMeta.cpp",
    "torch/csrc/nativert/kernels/NativeKernels.cpp",
    "torch/csrc/nativert/kernels/AOTICallDelegateKernel.cpp",
    "torch/csrc/nativert/kernels/AOTIKernel.cpp",
    "torch/csrc/nativert/kernels/CallTorchBindKernel.cpp",
    "torch/csrc/nativert/kernels/GeneratedStaticDispatchKernels.cpp",
    "torch/csrc/nativert/kernels/AutoFunctionalizeKernel.cpp",
    "torch/csrc/nativert/kernels/C10Kernel.cpp",
    "torch/csrc/nativert/kernels/HigherOrderKernel.cpp",
    "torch/csrc/nativert/kernels/KernelFactory.cpp",
    "torch/csrc/nativert/kernels/KernelRegistry.cpp",
    "torch/csrc/nativert/ModelRunner.cpp",
]

torch_mobile_tracer_sources = [
    "torch/csrc/jit/mobile/model_tracer/tracer.cpp",
    "torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
@@ -619,7 +657,7 @@ libtorch_lite_cmake_sources = sorted(
    torch_mobile_core,
)

-libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources
+libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources + libtorch_runtime_sources

libtorch_extra_sources = libtorch_core_jit_sources + [
    "torch/csrc/autograd/TraceTypeManual.cpp",
@@ -935,6 +973,8 @@ libtorch_python_core_sources = [
    "torch/csrc/utils/verbose.cpp",
    "torch/csrc/cpu/Module.cpp",
    "torch/csrc/instruction_counter/Module.cpp",
    "torch/csrc/nativert/ModelRunnerPybind.cpp",
    "torch/csrc/nativert/package/pt2_archive_constants_pybind.cpp",
] + lazy_tensor_core_python_sources

libtorch_python_distributed_core_sources = [
219 test/export/test_nativert.py (new file)
@@ -0,0 +1,219 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch._C.nativert import (
|
||||
PyExecutorType,
|
||||
PyModelRunner,
|
||||
PyPlacement,
|
||||
PyRuntimeConfigs,
|
||||
)
|
||||
from torch.export.experimental.package import package_model
|
||||
from torch.export.experimental.package.pt2_archive import PT2ArchiveWriter
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
try:
|
||||
from .. import test_export, testing
|
||||
except ImportError:
|
||||
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
||||
import testing # @manual=fbcode//caffe2/test:test_export-library
|
||||
|
||||
from torch.export import export
|
||||
|
||||
|
||||
test_classes = {}
|
||||
|
||||
|
||||
def _use_real_inputs(ep):
|
||||
ep = copy.copy(ep)
|
||||
|
||||
has_fake_tensor = False
|
||||
|
||||
def _to_real_tensor(t):
|
||||
if isinstance(t, torch.nn.Parameter):
|
||||
return torch.nn.Parameter(_to_real_tensor(t.data))
|
||||
if isinstance(t, FakeTensor):
|
||||
nonlocal has_fake_tensor
|
||||
has_fake_tensor = True
|
||||
return torch.randn(t.shape, device=t.device, requires_grad=t.requires_grad)
|
||||
return t
|
||||
|
||||
new_example_inputs = pytree.tree_map_only(
|
||||
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.example_inputs
|
||||
)
|
||||
if has_fake_tensor:
|
||||
ep.example_inputs = new_example_inputs
|
||||
|
||||
ep = ep._update(
|
||||
ep.graph_module,
|
||||
ep.graph_signature,
|
||||
state_dict=pytree.tree_map_only(
|
||||
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.state_dict
|
||||
),
|
||||
constants=pytree.tree_map_only(
|
||||
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.constants
|
||||
),
|
||||
)
|
||||
return ep
|
||||
|
||||
|
||||
def _is_supported_types(arg) -> bool:
|
||||
if isinstance(arg, list):
|
||||
return (
|
||||
all(_is_supported_types(a) for a in arg)
|
||||
and len({type(a) for a in arg}) <= 1
|
||||
)
|
||||
elif isinstance(arg, tuple):
|
||||
return all(_is_supported_types(a) for a in arg)
|
||||
elif isinstance(arg, dict):
|
||||
return (
|
||||
all(_is_supported_types(a) for a in arg.values())
|
||||
and len({type(a) for a in arg.values()}) <= 1
|
||||
)
|
||||
elif isinstance(arg, (torch.Tensor, int, float, bool, str)):
|
||||
return True
|
||||
elif arg is None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def run_with_sigmoid(ep):
|
||||
# Downstream tests might mutate the exported program in subtle ways, so
|
||||
# we need to make a copy here.
|
||||
ep_infer = copy.deepcopy(ep)
|
||||
MODEL_NAME = "test_export"
|
||||
ep_infer = _use_real_inputs(ep_infer.run_decompositions())
|
||||
|
||||
# TODO: Can NamedTemporaryFile names collide?
|
||||
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||
with PT2ArchiveWriter(f) as archive_writer:
|
||||
package_model(
|
||||
ep_infer,
|
||||
MODEL_NAME,
|
||||
archive_writer,
|
||||
)
|
||||
|
||||
device = torch.device("cpu")
|
||||
placement = PyPlacement(device)
|
||||
|
||||
runtime_configs = PyRuntimeConfigs()
|
||||
|
||||
filename = f.name
|
||||
try:
|
||||
ep_args, ep_kwargs = ep_infer.example_inputs
|
||||
ep_args_copied, ep_kwargs_copied = (
|
||||
copy.deepcopy(ep_args),
|
||||
copy.deepcopy(ep_kwargs),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
try:
|
||||
flat_expected = pytree.tree_leaves(
|
||||
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
|
||||
)
|
||||
except Exception as e:
|
||||
raise unittest.case.SkipTest(e)
|
||||
|
||||
model_runner = PyModelRunner(
|
||||
filename,
|
||||
MODEL_NAME,
|
||||
PyExecutorType.INTERPRETER,
|
||||
runtime_configs,
|
||||
placement,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
if _is_supported_types((ep_args, ep_kwargs)):
|
||||
results = model_runner.run(*ep_args, **ep_kwargs)
|
||||
else:
|
||||
results = model_runner.run_with_flat_inputs_and_outputs(
|
||||
*pytree.tree_leaves((ep_args, ep_kwargs))
|
||||
)
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
assert len(flat_results) == len(flat_expected)
|
||||
for result, expected in zip(flat_results, flat_expected):
|
||||
assert type(result) == type(expected)
|
||||
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
|
||||
assert result.shape == expected.shape
|
||||
assert result.dtype == expected.dtype
|
||||
assert result.device == expected.device
|
||||
torch.testing.assert_close(result, expected, equal_nan=True)
|
||||
else:
|
||||
assert result == expected
|
||||
except RuntimeError as e:
|
||||
# Users need to register the pytree node type on the C++ side, which
# cannot be exercised from a Python unittest.
|
||||
if "Unknown pytree node type" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
pathlib.Path(filename).unlink(missing_ok=True)
|
||||
return ep
|
||||
|
||||
|
||||
def mocked_sigmoid_export_strict(*args, **kwargs):
|
||||
if "strict" in kwargs:
|
||||
ep = export(*args, **kwargs)
|
||||
else:
|
||||
ep = export(*args, **kwargs, strict=True)
|
||||
|
||||
run_with_sigmoid(ep)
|
||||
return ep
|
||||
|
||||
|
||||
def mocked_sigmoid_export_nonstrict(*args, **kwargs):
|
||||
if "strict" in kwargs:
|
||||
ep = export(*args, **kwargs)
|
||||
else:
|
||||
ep = export(*args, **kwargs, strict=False)
|
||||
|
||||
run_with_sigmoid(ep)
|
||||
return ep
|
||||
|
||||
|
||||
def make_dynamic_cls(cls, strict=False):
|
||||
cls_prefix = "Sigmoid"
|
||||
|
||||
if strict:
|
||||
test_class = testing.make_test_cls_with_mocked_export(
|
||||
cls,
|
||||
cls_prefix,
|
||||
test_export.CPP_RUNTIME_STRICT_SUFFIX,
|
||||
mocked_sigmoid_export_strict,
|
||||
xfail_prop="_expected_failure_cpp_runtime",
|
||||
test_only_if_no_xfail=True,
|
||||
)
|
||||
else:
|
||||
test_class = testing.make_test_cls_with_mocked_export(
|
||||
cls,
|
||||
cls_prefix,
|
||||
test_export.CPP_RUNTIME_NONSTRICT_SUFFIX,
|
||||
mocked_sigmoid_export_nonstrict,
|
||||
xfail_prop="_expected_failure_cpp_runtime_non_strict",
|
||||
test_only_if_no_xfail=True,
|
||||
)
|
||||
|
||||
test_classes[test_class.__name__] = test_class
|
||||
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
|
||||
globals()[test_class.__name__] = test_class
|
||||
test_class.__module__ = __name__
|
||||
|
||||
|
||||
tests = [
|
||||
test_export.TestExport,
|
||||
]
|
||||
for test in tests:
|
||||
make_dynamic_cls(test, strict=True)
|
||||
make_dynamic_cls(test, strict=False)
|
||||
del test
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@@ -129,6 +129,9 @@
#endif
#endif

#include <torch/csrc/nativert/ModelRunnerPybind.h>
#include <torch/csrc/nativert/package/pt2_archive_constants_pybind.h>

#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif
@@ -2586,6 +2589,13 @@ Call this whenever a new thread is created in order to propagate values from
#ifdef USE_KINETO
  torch::global_kineto_init();
#endif
  {
    auto py_module_runtime = py_module.def_submodule("nativert");
    torch::nativert::initModelRunnerPybind(py_module_runtime);
    auto py_module_constants =
        py_module_runtime.def_submodule("pt2_archive_constants");
    torch::nativert::initPt2ArchiveConstantsPybind(py_module_constants);
  }
  return module;
  END_HANDLE_TH_ERRORS
}
113 torch/csrc/nativert/ModelRunner.cpp (new file)
@@ -0,0 +1,113 @@
|
||||
#include "torch/csrc/nativert/ModelRunner.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <caffe2/serialize/file_adapter.h>
|
||||
|
||||
#include "torch/csrc/nativert/graph/GraphPasses.h"
|
||||
#include "torch/csrc/nativert/graph/Serialization.h"
|
||||
|
||||
namespace torch::nativert::core {
|
||||
|
||||
ModelRunner::ModelRunner(
|
||||
const std::string& packagePath,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement)
|
||||
: ModelRunner(
|
||||
std::make_unique<caffe2::serialize::FileAdapter>(packagePath),
|
||||
modelName,
|
||||
executorType,
|
||||
runtimeConfigs,
|
||||
placement) {}
|
||||
|
||||
ModelRunner::ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement)
|
||||
: ModelRunner(
|
||||
std::make_shared<caffe2::serialize::PyTorchStreamReader>(rai),
|
||||
modelName,
|
||||
executorType,
|
||||
runtimeConfigs,
|
||||
placement) {}
|
||||
|
||||
ModelRunner::ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement)
|
||||
: ModelRunner(
|
||||
std::move(pytorchStreamReader),
|
||||
modelName,
|
||||
executorType,
|
||||
runtimeConfigs,
|
||||
[=](const torch::nativert::Graph&) { return placement; }) {}
|
||||
|
||||
ModelRunner::ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const std::function<Placement(const torch::nativert::Graph& graph)>&
|
||||
buildPlacementFn)
|
||||
: ModelRunnerBase(
|
||||
pytorchStreamReader,
|
||||
modelName,
|
||||
executorType,
|
||||
runtimeConfigs,
|
||||
buildPlacementFn) {
|
||||
std::string modelSerialized = loadSerializedModel(pytorchStreamReader);
|
||||
|
||||
model_ = nlohmann::json::parse(modelSerialized)
|
||||
.template get<torch::_export::Model>();
|
||||
exportedProgram_ = model_.get_program().get_methods().at("forward");
|
||||
for (const auto& _ : model_.get_delegates()) {
|
||||
(void)_;
|
||||
TORCH_CHECK(false, "Delegates are not supported yet");
|
||||
// TODO delegates_.emplace(name, delegate.get_methods().at("forward"));
|
||||
}
|
||||
stateDictPath_ = model_.get_tensorPaths();
|
||||
constantPaths_ = model_.get_constantPaths();
|
||||
TORCH_CHECK_EQ(
|
||||
exportedProgram_.get_graph_module().get_module_call_graph()[0].get_fqn(),
|
||||
"");
|
||||
|
||||
inputSpec_ = treeSpecLoads(exportedProgram_.get_graph_module()
|
||||
.get_module_call_graph()[0]
|
||||
.get_signature()
|
||||
.value()
|
||||
.get_in_spec());
|
||||
outputSpec_ = treeSpecLoads(exportedProgram_.get_graph_module()
|
||||
.get_module_call_graph()[0]
|
||||
.get_signature()
|
||||
.value()
|
||||
.get_out_spec());
|
||||
|
||||
graph_ = jsonToGraph(
|
||||
model_.get_program().get_methods().at("forward").get_graph_module());
|
||||
|
||||
VLOG(1) << "Graph: \n" << *graph_;
|
||||
|
||||
placement_ = buildPlacementFn(*graph_);
|
||||
LOG(INFO) << "Placement: " << placement_;
|
||||
|
||||
graph_->applyDevicePlacement(placement_);
|
||||
selectScalarOverload(graph_.get());
|
||||
|
||||
loadNewWeights(pytorchStreamReader);
|
||||
|
||||
if (!runtimeConfigs.deferInitialization) {
|
||||
initialize(pytorchStreamReader);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> ModelRunner::deserializeDelegateGraph() const {
|
||||
return {}; // TODO
|
||||
}
|
||||
|
||||
} // namespace torch::nativert::core
|
||||
76 torch/csrc/nativert/ModelRunner.h (new file)
@@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/ModelRunnerBase.h"
|
||||
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert::core {
|
||||
class TORCH_API ModelRunner : public ModelRunnerBase {
|
||||
public:
|
||||
ModelRunner(
|
||||
const std::string& packagePath,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement = Placement());
|
||||
|
||||
ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement = Placement());
|
||||
|
||||
ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const Placement& placement = Placement());
|
||||
|
||||
ModelRunner(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
// functor to build the placement after the graph is loaded, but before
|
||||
// loading the weights.
|
||||
const std::function<Placement(const torch::nativert::Graph& graph)>&
|
||||
buildPlacementFn);
|
||||
|
||||
ModelRunner(ModelRunner&&) = default;
|
||||
ModelRunner& operator=(ModelRunner&&) = default;
|
||||
ModelRunner(const ModelRunner&) = delete;
|
||||
ModelRunner& operator=(const ModelRunner&) = delete;
|
||||
~ModelRunner() override = default;
|
||||
|
||||
std::vector<std::string> availableDelegates() const {
|
||||
std::vector<std::string> delegateNames;
|
||||
delegateNames.reserve(delegates_.size());
|
||||
for (const auto& [name, _] : delegates_) {
|
||||
delegateNames.push_back(name);
|
||||
}
|
||||
return delegateNames;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T*> getDelegates() {
|
||||
std::vector<T*> delegates;
|
||||
for (const auto& delegate : executor_->getDelegates()) {
|
||||
if (auto* d = dynamic_cast<T*>(delegate)) {
|
||||
delegates.push_back(d);
|
||||
}
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<Graph> deserializeDelegateGraph() const override;
|
||||
|
||||
torch::_export::Model model_;
|
||||
torch::_export::ExportedProgram exportedProgram_;
|
||||
std::unordered_map<std::string, torch::_export::ExportedProgram> delegates_;
|
||||
};
|
||||
} // namespace torch::nativert::core
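For reference, a minimal C++ sketch of driving the runner declared above. The run() overload is assumed to be inherited from ModelRunnerBase (its use mirrors the pybind wrapper below), and the package path and model name are placeholders, not part of the diff.

// Editor's sketch: constructing and invoking ModelRunner under the
// assumptions stated above.
#include <ATen/ATen.h>
#include <torch/csrc/nativert/ModelRunner.h>
#include <unordered_map>
#include <vector>

void runOnce() {
  using namespace torch::nativert;
  core::ModelRunner runner(
      "/tmp/model.pt2",                 // placeholder package path
      "test_export",                    // placeholder model name
      ExecutorType::INTERPRETER,
      BaseRuntimeConfigs(),
      Placement(c10::Device(c10::kCPU)));

  std::vector<c10::IValue> args{at::randn({2, 3})};
  std::unordered_map<std::string, c10::IValue> kwargs;
  c10::IValue out = runner.run(args, kwargs);  // assumed ModelRunnerBase API
  (void)out;
}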
|
||||
148 torch/csrc/nativert/ModelRunnerPybind.cpp (new file)
@@ -0,0 +1,148 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <caffe2/serialize/file_adapter.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/utils/pybind.h> // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
#include "c10/core/Device.h"
|
||||
|
||||
#include "torch/csrc/nativert/ModelRunner.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch::nativert;
|
||||
|
||||
template <typename T>
|
||||
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
|
||||
|
||||
namespace torch {
|
||||
namespace nativert {
|
||||
|
||||
void initModelRunnerPybind(py::module& m) {
|
||||
py::enum_<ExecutorType>(m, "PyExecutorType")
|
||||
.value("INTERPRETER", ExecutorType::INTERPRETER)
|
||||
.value("AOTINDUCTOR", ExecutorType::AOTINDUCTOR)
|
||||
.value("MTIA", ExecutorType::MTIA)
|
||||
.export_values();
|
||||
|
||||
py::class_<Placement>(m, "PyPlacement")
|
||||
.def(py::init<>())
|
||||
.def(py::init<std::optional<c10::Device>>(), py::arg("defaultDevice"))
|
||||
.def(
|
||||
py::init<
|
||||
const std::unordered_map<c10::Device, c10::Device>&,
|
||||
std::optional<c10::Device>>(),
|
||||
py::arg("deviceMap"),
|
||||
py::arg("defaultDevice") = std::nullopt)
|
||||
.def("get_mapped_device", &Placement::getMappedDevice);
|
||||
|
||||
py::class_<BaseRuntimeConfigs>(m, "PyRuntimeConfigs")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("isDebug", &BaseRuntimeConfigs::isDebug)
|
||||
.def_readwrite("validateInputs", &BaseRuntimeConfigs::validateInputs)
|
||||
.def_readwrite(
|
||||
"enableStaticCPUKernels", &BaseRuntimeConfigs::enableStaticCPUKernels)
|
||||
.def_readwrite(
|
||||
"deferInitialization", &BaseRuntimeConfigs::deferInitialization)
|
||||
.def_readwrite("platformArch", &BaseRuntimeConfigs::platformArch)
|
||||
.def_readwrite(
|
||||
"maxNumConcurrentThreads",
|
||||
&BaseRuntimeConfigs::maxNumConcurrentThreads)
|
||||
.def_readwrite("maxParallelOps", &BaseRuntimeConfigs::maxParallelOps);
|
||||
|
||||
shared_ptr_class_<core::ModelRunner>(m, "PyModelRunner")
|
||||
.def(
|
||||
py::init<
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
ExecutorType,
|
||||
const BaseRuntimeConfigs&,
|
||||
const Placement&>(),
|
||||
py::arg("packagePath"),
|
||||
py::arg("modelName"),
|
||||
py::arg("executorType"),
|
||||
py::arg("runtimeConfigs"),
|
||||
py::arg("placement") = Placement())
|
||||
.def(py::init<
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>,
|
||||
const std::string&,
|
||||
ExecutorType,
|
||||
const BaseRuntimeConfigs&,
|
||||
std::function<Placement(const Graph& graph)>&>())
|
||||
.def(
|
||||
"run",
|
||||
[](core::ModelRunner& self,
|
||||
py::args pyargs,
|
||||
const py::kwargs& pykwargs) {
|
||||
std::vector<c10::IValue> args;
|
||||
for (const auto i : c10::irange(pyargs.size())) {
|
||||
auto ivalue =
|
||||
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
|
||||
args.push_back(std::move(ivalue));
|
||||
}
|
||||
std::unordered_map<std::string, c10::IValue> kwargs;
|
||||
for (const auto& [key, pyarg] : pykwargs) {
|
||||
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
|
||||
kwargs[py::str(key)] = std::move(ivalue);
|
||||
}
|
||||
c10::IValue ret = self.run(args, kwargs);
|
||||
return torch::jit::createPyObjectForStack({ret});
|
||||
})
|
||||
.def(
|
||||
"__call__",
|
||||
[](core::ModelRunner& self,
|
||||
py::args pyargs,
|
||||
const py::kwargs& pykwargs) {
|
||||
std::vector<c10::IValue> args;
|
||||
for (const auto i : c10::irange(pyargs.size())) {
|
||||
auto ivalue =
|
||||
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
|
||||
args.push_back(std::move(ivalue));
|
||||
}
|
||||
std::unordered_map<std::string, c10::IValue> kwargs;
|
||||
for (const auto& [key, pyarg] : pykwargs) {
|
||||
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
|
||||
kwargs[py::str(key)] = std::move(ivalue);
|
||||
}
|
||||
c10::IValue ret = self.run(args, kwargs);
|
||||
return torch::jit::createPyObjectForStack({ret});
|
||||
})
|
||||
.def(
|
||||
"load_sample_inputs",
|
||||
[](core::ModelRunner& self,
|
||||
const std::string& packagePath,
|
||||
const Placement& placement = Placement()) {
|
||||
auto reader =
|
||||
std::make_shared<caffe2::serialize::PyTorchStreamReader>(
|
||||
std::make_unique<caffe2::serialize::FileAdapter>(
|
||||
packagePath));
|
||||
const auto [args, kwargs] = self.loadSampleInputs(reader);
|
||||
const auto val = argsToIValue(args, kwargs);
|
||||
return torch::jit::createPyObjectForStack({val});
|
||||
})
|
||||
.def(
|
||||
"run_with_flat_inputs_and_outputs",
|
||||
[](core::ModelRunner& self, py::args pyargs) {
|
||||
std::vector<c10::IValue> args;
|
||||
for (const auto i : c10::irange(pyargs.size())) {
|
||||
auto ivalue =
|
||||
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
|
||||
args.push_back(std::move(ivalue));
|
||||
}
|
||||
|
||||
auto rets = self.runWithFlatInputsAndOutputs(std::move(args));
|
||||
return torch::jit::createPyObjectForStack(std::move(rets));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace nativert
|
||||
} // namespace torch
|
||||
|
||||
// TODO Remove this once we fully migrate to OSS build.
|
||||
#ifdef FBCODE_CAFFE2
|
||||
PYBIND11_MODULE(model_runner_pybind, m) {
|
||||
initModelRunnerPybind(m);
|
||||
}
|
||||
#endif
|
||||
11 torch/csrc/nativert/ModelRunnerPybind.h (new file)
@@ -0,0 +1,11 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace torch {
|
||||
namespace nativert {
|
||||
|
||||
void initModelRunnerPybind(pybind11::module& m);
|
||||
|
||||
} // namespace nativert
|
||||
} // namespace torch
|
||||
125 torch/csrc/nativert/common/AutoTimer.h (new file)
@@ -0,0 +1,125 @@
|
||||
/*
|
||||
* Ported from folly/logging/AutoTimer.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
// Default logger
|
||||
enum class GoogleLoggerStyle { SECONDS, MILLISECONDS };
|
||||
template <GoogleLoggerStyle>
|
||||
struct GoogleLogger;
|
||||
|
||||
/**
|
||||
* Automatically times a block of code, printing a specified log message on
|
||||
* destruction or whenever the log() method is called. For example:
|
||||
*
|
||||
* AutoTimer t("Foo() completed");
|
||||
* doWork();
|
||||
* t.log("Do work finished");
|
||||
* doMoreWork();
|
||||
*
|
||||
* This would print something like:
|
||||
* "Do work finished in 1.2 seconds"
|
||||
* "Foo() completed in 4.3 seconds"
|
||||
*
|
||||
* You can customize what you use as the logger and clock. The logger needs
|
||||
* to have an operator()(std::string_view, std::chrono::duration<double>) that
|
||||
* gets a message and a duration. The clock needs to model Clock from
|
||||
* std::chrono.
|
||||
*
|
||||
* The default logger logs using glog. It only logs if the message is
|
||||
* non-empty, so you can also just use this class for timing, e.g.:
|
||||
*
|
||||
* AutoTimer t;
|
||||
* doWork()
|
||||
* const auto how_long = t.log();
|
||||
*/
|
||||
template <
|
||||
class Logger = GoogleLogger<GoogleLoggerStyle::MILLISECONDS>,
|
||||
class Clock = std::chrono::high_resolution_clock>
|
||||
class AutoTimer final {
|
||||
public:
|
||||
using DoubleSeconds = std::chrono::duration<double>;
|
||||
|
||||
explicit AutoTimer(
|
||||
std::string&& msg = "",
|
||||
const DoubleSeconds& minTimetoLog = DoubleSeconds::zero(),
|
||||
Logger&& logger = Logger())
|
||||
: destructionMessage_(std::move(msg)),
|
||||
minTimeToLog_(minTimetoLog),
|
||||
logger_(std::move(logger)) {}
|
||||
|
||||
// It doesn't really make sense to copy AutoTimer
|
||||
// Movable to make sure the helper method for creating an AutoTimer works.
|
||||
AutoTimer(const AutoTimer&) = delete;
|
||||
AutoTimer(AutoTimer&&) = default;
|
||||
AutoTimer& operator=(const AutoTimer&) = delete;
|
||||
AutoTimer& operator=(AutoTimer&&) = default;
|
||||
|
||||
~AutoTimer() {
|
||||
if (destructionMessage_) {
|
||||
log(destructionMessage_.value());
|
||||
}
|
||||
}
|
||||
|
||||
DoubleSeconds log(std::string_view msg = "") {
|
||||
return logImpl(Clock::now(), msg);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
DoubleSeconds logFormat(fmt::format_string<Args...> fmt, Args&&... args) {
|
||||
auto now = Clock::now();
|
||||
return logImpl(now, fmt::format(fmt, std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
private:
|
||||
// We take in the current time so that we don't measure time to call
|
||||
// to<std::string> or format() in the duration.
|
||||
DoubleSeconds logImpl(
|
||||
std::chrono::time_point<Clock> now,
|
||||
std::string_view msg) {
|
||||
auto duration = now - start_;
|
||||
if (duration >= minTimeToLog_) {
|
||||
logger_(msg, duration);
|
||||
}
|
||||
start_ = Clock::now(); // Don't measure logging time
|
||||
return duration;
|
||||
}
|
||||
|
||||
std::optional<std::string> destructionMessage_;
|
||||
std::chrono::time_point<Clock> start_ = Clock::now();
|
||||
DoubleSeconds minTimeToLog_;
|
||||
Logger logger_;
|
||||
};
|
||||
|
||||
template <GoogleLoggerStyle Style>
|
||||
struct GoogleLogger final {
|
||||
void operator()(
|
||||
std::string_view msg,
|
||||
const std::chrono::duration<double>& sec) const {
|
||||
if (msg.empty()) {
|
||||
return;
|
||||
}
|
||||
if (Style == GoogleLoggerStyle::SECONDS) {
|
||||
LOG(INFO) << msg << " in " << sec.count() << " seconds";
|
||||
} else if (Style == GoogleLoggerStyle::MILLISECONDS) {
|
||||
LOG(INFO) << msg << " in "
|
||||
<< std::chrono::duration_cast<
|
||||
std::chrono::duration<double, std::milli>>(sec)
|
||||
.count()
|
||||
<< " milliseconds";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
13 torch/csrc/nativert/common/ConfigUtils.cpp (new file)
@@ -0,0 +1,13 @@
|
||||
#include "torch/csrc/nativert/common/ConfigUtils.h"
|
||||
#include <string.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
std::optional<std::string> maybeGetEnv(std::string_view envVar) {
|
||||
const char* env = getenv(envVar.data());
|
||||
if (env == nullptr || strlen(env) == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::string(env);
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
7 torch/csrc/nativert/common/ConfigUtils.h (new file)
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace torch::nativert {
|
||||
std::optional<std::string> maybeGetEnv(std::string_view envVar);
|
||||
}
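A short sketch of using the helper above; the environment variable name is a made-up example, not part of the diff.

#include "torch/csrc/nativert/common/ConfigUtils.h"

// Reads a hypothetical NATIVERT_LOG_LEVEL variable; returns 0 when unset or empty.
int logLevelFromEnv() {
  auto v = torch::nativert::maybeGetEnv("NATIVERT_LOG_LEVEL");
  return v.has_value() ? std::stoi(*v) : 0;
}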
|
||||
47 torch/csrc/nativert/common/Conv.cpp (new file)
@@ -0,0 +1,47 @@
|
||||
#include "torch/csrc/nativert/common/Conv.h"
|
||||
|
||||
#include <charconv>
|
||||
#include <string>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
template <>
|
||||
std::optional<int64_t> tryTo<int64_t>(std::string_view symbol) {
|
||||
int64_t value;
|
||||
auto [ptr, ec] =
|
||||
std::from_chars(symbol.data(), symbol.data() + symbol.size(), value);
|
||||
if (ec != std::errc()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (ptr != symbol.data() + symbol.size()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::optional<double> tryTo<double>(std::string_view symbol) {
|
||||
double value;
|
||||
#ifdef __APPLE__
|
||||
char extra; // to detect any extra characters after the number
|
||||
// Try to parse the string using sscanf
|
||||
auto str = std::string{symbol};
|
||||
if (sscanf(str.c_str(), "%lf %c", &value, &extra) != 1) {
|
||||
// If sscanf returns anything other than 1, it means parsing failed or there
|
||||
// were extra characters
|
||||
return std::nullopt;
|
||||
}
|
||||
#else
|
||||
auto [ptr, ec] =
|
||||
std::from_chars(symbol.data(), symbol.data() + symbol.size(), value);
|
||||
if (ec != std::errc()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (ptr != symbol.data() + symbol.size()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
#endif
|
||||
return value;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
27 torch/csrc/nativert/common/Conv.h (new file)
@@ -0,0 +1,27 @@
#pragma once

#include <optional>
#include <string_view>

namespace torch::nativert {

template <typename T>
std::optional<T> tryTo(std::string_view symbol) = delete;

/*
 * Convert a string to an integer. Prefixes like "0x" and trailing whitespace
 * are not supported. Similarly, an integer string with trailing characters,
 * such as "123abc", is rejected.
 */
template <>
std::optional<int64_t> tryTo<int64_t>(std::string_view symbol);

/*
 * Convert a string to a double. Prefixes like "0x" and trailing whitespace
 * are not supported. Similarly, a numeric string with trailing characters,
 * such as "123abc", is rejected.
 */
template <>
std::optional<double> tryTo<double>(std::string_view symbol);

} // namespace torch::nativert
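A minimal sketch of the strict parsing behavior documented above; the inputs are illustrative and the block is not part of the diff.

#include "torch/csrc/nativert/common/Conv.h"

void demoTryTo() {
  using torch::nativert::tryTo;
  auto a = tryTo<int64_t>("123");     // contains 123
  auto b = tryTo<int64_t>("123abc");  // std::nullopt: trailing characters rejected
  auto c = tryTo<double>("1.5");      // contains 1.5
  (void)a; (void)b; (void)c;
}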
|
||||
153 torch/csrc/nativert/common/Enumerate.h (new file)
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Ported from folly/container/Enumerate.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
|
||||
/**
|
||||
* Similar to Python's enumerate(), enumerate() can be used to
|
||||
* iterate a range with a for-range loop, and it also allows to
|
||||
* retrieve the count of iterations so far. Can be used in constexpr
|
||||
* context.
|
||||
*
|
||||
* For example:
|
||||
*
|
||||
* for (auto&& [index, element] : enumerate(vec)) {
|
||||
* // index is a const reference to a size_t containing the iteration count.
|
||||
* // element is a reference to the type contained within vec, mutable
|
||||
* // unless vec is const.
|
||||
* }
|
||||
*
|
||||
* If the binding is const, the element reference is too.
|
||||
*
|
||||
* for (const auto&& [index, element] : enumerate(vec)) {
|
||||
* // element is always a const reference.
|
||||
* }
|
||||
*
|
||||
* It can also be used as follows:
|
||||
*
|
||||
* for (auto&& it : enumerate(vec)) {
|
||||
* // *it is a reference to the current element. Mutable unless vec is const.
|
||||
* // it->member can be used as well.
|
||||
* // it.index contains the iteration count.
|
||||
* }
|
||||
*
|
||||
* As before, const auto&& it can also be used.
|
||||
*/
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T>
|
||||
struct MakeConst {
|
||||
using type = const T;
|
||||
};
|
||||
template <class T>
|
||||
struct MakeConst<T&> {
|
||||
using type = const T&;
|
||||
};
|
||||
template <class T>
|
||||
struct MakeConst<T*> {
|
||||
using type = const T*;
|
||||
};
|
||||
|
||||
template <class Iterator>
|
||||
class Enumerator {
|
||||
public:
|
||||
constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {}
|
||||
|
||||
class Proxy {
|
||||
public:
|
||||
using difference_type = ssize_t;
|
||||
using value_type = typename std::iterator_traits<Iterator>::value_type;
|
||||
using reference = typename std::iterator_traits<Iterator>::reference;
|
||||
using pointer = typename std::iterator_traits<Iterator>::pointer;
|
||||
using iterator_category = std::input_iterator_tag;
|
||||
|
||||
C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e)
|
||||
: index(e.idx_), element(*e.it_) {}
|
||||
|
||||
// Non-const Proxy: Forward constness from Iterator.
|
||||
C10_ALWAYS_INLINE constexpr reference operator*() {
|
||||
return element;
|
||||
}
|
||||
C10_ALWAYS_INLINE constexpr pointer operator->() {
|
||||
return std::addressof(element);
|
||||
}
|
||||
|
||||
// Const Proxy: Force const references.
|
||||
C10_ALWAYS_INLINE constexpr typename MakeConst<reference>::type operator*()
|
||||
const {
|
||||
return element;
|
||||
}
|
||||
C10_ALWAYS_INLINE constexpr typename MakeConst<pointer>::type operator->()
|
||||
const {
|
||||
return std::addressof(element);
|
||||
}
|
||||
|
||||
public:
|
||||
const size_t index;
|
||||
reference element;
|
||||
};
|
||||
|
||||
C10_ALWAYS_INLINE constexpr Proxy operator*() const {
|
||||
return Proxy(*this);
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE constexpr Enumerator& operator++() {
|
||||
++it_;
|
||||
++idx_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename OtherIterator>
|
||||
C10_ALWAYS_INLINE constexpr bool operator==(
|
||||
const Enumerator<OtherIterator>& rhs) const {
|
||||
return it_ == rhs.it_;
|
||||
}
|
||||
|
||||
template <typename OtherIterator>
|
||||
C10_ALWAYS_INLINE constexpr bool operator!=(
|
||||
const Enumerator<OtherIterator>& rhs) const {
|
||||
return !(it_ == rhs.it_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename OtherIterator>
|
||||
friend class Enumerator;
|
||||
|
||||
Iterator it_;
|
||||
size_t idx_ = 0;
|
||||
};
|
||||
|
||||
template <class Range>
|
||||
class RangeEnumerator {
|
||||
Range r_;
|
||||
using BeginIteratorType = decltype(std::declval<Range>().begin());
|
||||
using EndIteratorType = decltype(std::declval<Range>().end());
|
||||
|
||||
public:
|
||||
constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward<Range>(r)) {}
|
||||
|
||||
constexpr Enumerator<BeginIteratorType> begin() {
|
||||
return Enumerator<BeginIteratorType>(r_.begin());
|
||||
}
|
||||
constexpr Enumerator<EndIteratorType> end() {
|
||||
return Enumerator<EndIteratorType>(r_.end());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class Range>
|
||||
constexpr detail::RangeEnumerator<Range> enumerate(Range&& r) {
|
||||
return detail::RangeEnumerator<Range>(std::forward<Range>(r));
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
189 torch/csrc/nativert/common/FileUtil.cpp (new file)
@@ -0,0 +1,189 @@
|
||||
#include "torch/csrc/nativert/common/FileUtil.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <cerrno>
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
|
||||
inline void incr(ssize_t) {}
|
||||
template <typename Offset>
|
||||
inline void incr(ssize_t n, Offset& offset) {
|
||||
offset += static_cast<Offset>(n);
|
||||
}
|
||||
|
||||
// Wrap call to read/pread/write/pwrite(fd, buf, count, offset?) to retry on
|
||||
// incomplete reads / writes. The variadic argument magic is there to support
|
||||
// an additional argument (offset) for pread / pwrite; see the incr() functions
|
||||
// above which do nothing if the offset is not present and increment it if it
|
||||
// is.
|
||||
template <class F, class... Offset>
|
||||
ssize_t wrapFull(F f, int fd, void* buf, size_t count, Offset... offset) {
|
||||
char* b = static_cast<char*>(buf);
|
||||
ssize_t totalBytes = 0;
|
||||
ssize_t r;
|
||||
do {
|
||||
r = f(fd, b, count, offset...);
|
||||
if (r == -1) {
|
||||
if (errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
totalBytes += r;
|
||||
b += r;
|
||||
count -= r;
|
||||
incr(r, offset...);
|
||||
} while (r != 0 && count); // 0 means EOF
|
||||
|
||||
return totalBytes;
|
||||
}
|
||||
|
||||
int filterCloseReturn(int r) {
|
||||
// Ignore EINTR. On Linux, close() may only return EINTR after the file
|
||||
// descriptor has been closed, so you must not retry close() on EINTR --
|
||||
// in the best case, you'll get EBADF, and in the worst case, you'll end up
|
||||
// closing a different file (one opened from another thread).
|
||||
//
|
||||
// Interestingly enough, the Single Unix Specification says that the state
|
||||
// of the file descriptor is unspecified if close returns EINTR. In that
|
||||
// case, the safe thing to do is also not to retry close() -- leaking a file
|
||||
// descriptor is definitely better than closing the wrong file.
|
||||
if (r == -1 && errno == EINTR) {
|
||||
return 0;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// The following wrapX() functions are private functions for wrapping file-io
|
||||
// against interrupt and partial op completions.
|
||||
|
||||
// Wrap call to f(args) in loop to retry on EINTR
|
||||
template <class F, class... Args>
|
||||
ssize_t wrapNoInt(F f, Args... args) {
|
||||
ssize_t r;
|
||||
do {
|
||||
r = f(args...);
|
||||
} while (r == -1 && errno == EINTR);
|
||||
return r;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int openNoInt(const char* name, int flags, mode_t mode) {
|
||||
// Android NDK bionic with FORTIFY has this definition:
|
||||
// https://android.googlesource.com/platform/bionic/+/9349b9e51b/libc/include/bits/fortify/fcntl.h
|
||||
// ```
|
||||
// __BIONIC_ERROR_FUNCTION_VISIBILITY
|
||||
// int open(const char* pathname, int flags, mode_t modes, ...) __overloadable
|
||||
// __errorattr(__open_too_many_args_error);
|
||||
// ```
|
||||
// This is originally to prevent open() with incorrect parameters.
|
||||
//
|
||||
// However, combined with folly's wrapNoInt, template deduction will fail.
|
||||
// In this case, we create a custom lambda to bypass the error.
|
||||
// The solution is referenced from
|
||||
// https://github.com/llvm/llvm-project/commit/0a0e411204a2baa520fd73a8d69b664f98b428ba
|
||||
//
|
||||
auto openWrapper = [&] { return open(name, flags, mode); };
|
||||
return int(wrapNoInt(openWrapper));
|
||||
}
|
||||
|
||||
int closeNoInt(int fd) {
|
||||
return filterCloseReturn(close(fd));
|
||||
}
|
||||
|
||||
ssize_t writeFull(int fd, const void* buf, size_t count) {
|
||||
return wrapFull(write, fd, const_cast<void*>(buf), count);
|
||||
}
|
||||
|
||||
ssize_t readFull(int fd, void* buf, size_t count) {
|
||||
return wrapFull(read, fd, buf, count);
|
||||
}
|
||||
|
||||
File::File(int fd, bool ownsFd) noexcept : fd_(fd), ownsFd_(ownsFd) {
|
||||
TORCH_CHECK(fd >= -1, "fd must be -1 or non-negative");
|
||||
TORCH_CHECK(fd != -1 || !ownsFd, "cannot own -1");
|
||||
}
|
||||
|
||||
File::File(std::string_view name, int flags, mode_t mode)
|
||||
: fd_(::open(name.data(), flags, mode)), ownsFd_(false) {
|
||||
if (fd_ == -1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"open(\"{}\", {}, 0{}) failed with errno {}.",
|
||||
name,
|
||||
flags,
|
||||
mode,
|
||||
errno));
|
||||
}
|
||||
ownsFd_ = true;
|
||||
}
|
||||
|
||||
File::File(File&& other) noexcept : fd_(other.fd_), ownsFd_(other.ownsFd_) {
|
||||
other.release();
|
||||
}
|
||||
|
||||
File& File::operator=(File&& other) {
|
||||
closeNoThrow();
|
||||
swap(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
File::~File() {
|
||||
auto fd = fd_;
|
||||
if (!closeNoThrow()) { // ignore most errors
|
||||
TORCH_CHECK(
|
||||
errno != EBADF,
|
||||
"closing fd ",
|
||||
fd,
|
||||
", it may already ",
|
||||
"have been closed. Another time, this might close the wrong FD.");
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ File File::temporary() {
|
||||
// make a temp file with tmpfile(), dup the fd, then return it in a File.
|
||||
FILE* tmpFile = tmpfile();
|
||||
if (!tmpFile) {
|
||||
throw std::runtime_error("tmpfile() failed");
|
||||
}
|
||||
auto guard = c10::make_scope_exit([&]() { fclose(tmpFile); });
|
||||
|
||||
int fd = ::dup(fileno(tmpFile));
|
||||
if (fd == -1) {
|
||||
throw std::runtime_error("dup() failed");
|
||||
}
|
||||
|
||||
return File(fd, true);
|
||||
}
|
||||
|
||||
int File::release() noexcept {
|
||||
int released = fd_;
|
||||
fd_ = -1;
|
||||
ownsFd_ = false;
|
||||
return released;
|
||||
}
|
||||
|
||||
void File::swap(File& other) noexcept {
|
||||
using std::swap;
|
||||
swap(fd_, other.fd_);
|
||||
swap(ownsFd_, other.ownsFd_);
|
||||
}
|
||||
|
||||
void File::close() {
|
||||
if (!closeNoThrow()) {
|
||||
throw std::runtime_error("close() failed");
|
||||
}
|
||||
}
|
||||
|
||||
bool File::closeNoThrow() {
|
||||
int r = ownsFd_ ? ::close(fd_) : 0;
|
||||
release();
|
||||
return r == 0;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
230 torch/csrc/nativert/common/FileUtil.h (new file)
@@ -0,0 +1,230 @@
|
||||
#pragma once
|
||||
|
||||
/*
|
||||
* Ported from folly/FileUtil.h
|
||||
*/
|
||||
#include <limits>
|
||||
#include <string_view>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include "c10/util/Exception.h"
|
||||
#include "c10/util/ScopeExit.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
class File {
|
||||
public:
|
||||
/**
|
||||
* Creates an empty File object, for late initialization.
|
||||
*/
|
||||
constexpr File() noexcept : fd_(-1), ownsFd_(false) {}
|
||||
|
||||
/**
|
||||
* Create a File object from an existing file descriptor.
|
||||
*
|
||||
* @param fd Existing file descriptor
|
||||
* @param ownsFd Takes ownership of the file descriptor if ownsFd is true.
|
||||
*/
|
||||
explicit File(int fd, bool ownsFd = false) noexcept;
|
||||
|
||||
/**
|
||||
* Open and create a file object. Throws on error.
|
||||
* Owns the file descriptor implicitly.
|
||||
*/
|
||||
explicit File(
|
||||
std::string_view name,
|
||||
int flags = O_RDONLY,
|
||||
mode_t mode = 0666);
|
||||
|
||||
~File();
|
||||
|
||||
/**
|
||||
* Create and return a temporary, owned file (uses tmpfile()).
|
||||
*/
|
||||
static File temporary();
|
||||
|
||||
/**
|
||||
* Return the file descriptor, or -1 if the file was closed.
|
||||
*/
|
||||
int fd() const {
|
||||
return fd_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns 'true' iff the file was successfully opened.
|
||||
*/
|
||||
explicit operator bool() const {
|
||||
return fd_ != -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* If we own the file descriptor, close the file and throw on error.
|
||||
* Otherwise, do nothing.
|
||||
*/
|
||||
void close();
|
||||
|
||||
/**
|
||||
* Closes the file (if owned). Returns true on success, false (and sets
|
||||
* errno) on error.
|
||||
*/
|
||||
bool closeNoThrow();
|
||||
|
||||
/**
|
||||
* Returns and releases the file descriptor; no longer owned by this File.
|
||||
* Returns -1 if the File object didn't wrap a file.
|
||||
*/
|
||||
int release() noexcept;
|
||||
|
||||
/**
|
||||
* Swap this File with another.
|
||||
*/
|
||||
void swap(File& other) noexcept;
|
||||
|
||||
// movable
|
||||
File(File&&) noexcept;
|
||||
File& operator=(File&&);
|
||||
|
||||
private:
|
||||
// unique
|
||||
File(const File&) = delete;
|
||||
File& operator=(const File&) = delete;
|
||||
|
||||
int fd_;
|
||||
bool ownsFd_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Convenience wrappers around some commonly used system calls. The *NoInt
|
||||
* wrappers retry on EINTR. The *Full wrappers retry on EINTR and also loop
|
||||
* until all data is written. Note that *Full wrappers weaken the thread
|
||||
* semantics of underlying system calls.
|
||||
*/
|
||||
int openNoInt(const char* name, int flags, mode_t mode = 0666);
|
||||
int closeNoInt(int fd);
|
||||
|
||||
/**
|
||||
* Similar to readFull below, these are wrappers around write() and
|
||||
* pwrite() that loop until all data is written.
|
||||
*
|
||||
* Generally, the write() / pwrite() system call may always write fewer bytes
|
||||
* than requested, just like read(). In certain cases (such as when writing to
|
||||
* a pipe), POSIX provides stronger guarantees, but not in the general case.
|
||||
* For example, Linux (even on a 64-bit platform) won't write more than 2GB in
|
||||
* one write() system call.
|
||||
*
|
||||
* Note that writevFull and pwritevFull require iov to be non-const, unlike
|
||||
* writev and pwritev. The contents of iov after these functions return
|
||||
* is unspecified.
|
||||
*
|
||||
* These functions return -1 on error, or the total number of bytes written
|
||||
* (which is always the same as the number of requested bytes) on success.
|
||||
*/
|
||||
ssize_t writeFull(int fd, const void* buf, size_t count);
|
||||
|
||||
/**
|
||||
* Wrapper around read() (and pread()) that, in addition to retrying on
|
||||
* EINTR, will loop until all data is read.
|
||||
*
|
||||
* This wrapper is only useful for blocking file descriptors (for non-blocking
|
||||
* file descriptors, you have to be prepared to deal with incomplete reads
|
||||
* anyway), and only exists because POSIX allows read() to return an incomplete
|
||||
* read if interrupted by a signal (instead of returning -1 and setting errno
|
||||
* to EINTR).
|
||||
*
|
||||
* Note that this wrapper weakens the thread safety of read(): the file pointer
|
||||
* is shared between threads, but the system call is atomic. If multiple
|
||||
* threads are reading from a file at the same time, you don't know where your
|
||||
* data came from in the file, but you do know that the returned bytes were
|
||||
* contiguous. You can no longer make this assumption if using readFull().
|
||||
* You should probably use pread() when reading from the same file descriptor
|
||||
* from multiple threads simultaneously, anyway.
|
||||
*
|
||||
* Note that readvFull and preadvFull require iov to be non-const, unlike
|
||||
* readv and preadv. The contents of iov after these functions return
|
||||
* is unspecified.
|
||||
*/
|
||||
[[nodiscard]] ssize_t readFull(int fd, void* buf, size_t count);
|
||||
|
||||
/**
|
||||
* Read entire file (if num_bytes is defaulted) or no more than
|
||||
* num_bytes (otherwise) into container *out. The container is assumed
|
||||
* to be contiguous, with element size equal to 1, and offer size(),
|
||||
* reserve(), and random access (e.g. std::vector<char>, std::string,
|
||||
* fbstring).
|
||||
*
|
||||
* Returns: true on success or false on failure. In the latter case
|
||||
* errno will be set appropriately by the failing system primitive.
|
||||
*/
|
||||
template <class Container>
|
||||
bool readFile(
|
||||
int fd,
|
||||
Container& out,
|
||||
size_t num_bytes = std::numeric_limits<size_t>::max()) {
|
||||
static_assert(
|
||||
sizeof(out[0]) == 1,
|
||||
"readFile: only containers with byte-sized elements accepted");
|
||||
|
||||
size_t soFar = 0; // amount of bytes successfully read
|
||||
auto guard = c10::make_scope_exit([&]() {
|
||||
assert(out.size() >= soFar); // resize better doesn't throw
|
||||
out.resize(soFar);
|
||||
});
|
||||
|
||||
// Obtain file size:
|
||||
struct stat buf;
|
||||
if (fstat(fd, &buf) == -1) {
|
||||
return false;
|
||||
}
|
||||
// Some files (notably under /proc and /sys on Linux) lie about
|
||||
// their size, so treat the size advertised by fstat only as a hint
|
||||
// but don't rely on it. In particular, if the size is zero, we
|
||||
// should attempt to read stuff. If not zero, we'll attempt to read
|
||||
// one extra byte.
|
||||
constexpr size_t initialAlloc = 1024 * 4;
|
||||
out.resize(std::min(
|
||||
buf.st_size > 0 ? (size_t(buf.st_size) + 1) : initialAlloc, num_bytes));
|
||||
|
||||
while (soFar < out.size()) {
|
||||
const auto actual = readFull(fd, &out[soFar], out.size() - soFar);
|
||||
if (actual == -1) {
|
||||
return false;
|
||||
}
|
||||
soFar += actual;
|
||||
if (soFar < out.size()) {
|
||||
// File exhausted
|
||||
break;
|
||||
}
|
||||
// Ew, allocate more memory. Use exponential growth to avoid
|
||||
// quadratic behavior. Cap size to num_bytes.
|
||||
out.resize(std::min(out.size() * 3 / 2, num_bytes));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Same as above, but takes in a file name instead of fd
|
||||
*/
|
||||
template <class Container>
|
||||
bool readFile(
|
||||
const char* file_name,
|
||||
Container& out,
|
||||
size_t num_bytes = std::numeric_limits<size_t>::max()) {
|
||||
TORCH_CHECK(file_name);
|
||||
|
||||
const auto fd = openNoInt(file_name, O_RDONLY | O_CLOEXEC);
|
||||
if (fd == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto guard = c10::make_scope_exit([&]() {
|
||||
// Ignore errors when closing the file
|
||||
closeNoInt(fd);
|
||||
});
|
||||
|
||||
return readFile(fd, out, num_bytes);
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
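A small sketch of the readFile convenience wrapper declared above; the path is a placeholder and the block is not part of the diff.

#include "torch/csrc/nativert/common/FileUtil.h"
#include <string>

// Slurps an entire file into a string; returns false (with errno set) on failure.
bool loadBlob(std::string& out) {
  return torch::nativert::readFile("/tmp/example.bin", out);  // placeholder path
}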
|
||||
189 torch/csrc/nativert/common/IntrusiveList.h (new file)
@@ -0,0 +1,189 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
template <typename T>
|
||||
class IntrusiveList;
|
||||
|
||||
class IntrusiveListHook {
|
||||
template <typename P, typename T>
|
||||
friend class ListIterator;
|
||||
|
||||
template <typename T>
|
||||
friend class IntrusiveList;
|
||||
|
||||
IntrusiveListHook* next_;
|
||||
IntrusiveListHook* prev_;
|
||||
|
||||
void link_before(IntrusiveListHook* next_node) {
|
||||
next_ = next_node;
|
||||
prev_ = next_node->prev_;
|
||||
next_node->prev_ = this;
|
||||
prev_->next_ = this;
|
||||
}
|
||||
|
||||
public:
|
||||
IntrusiveListHook() : next_(this), prev_(this) {}
|
||||
|
||||
IntrusiveListHook(const IntrusiveListHook&) = delete;
|
||||
IntrusiveListHook& operator=(const IntrusiveListHook&) = delete;
|
||||
IntrusiveListHook(IntrusiveListHook&&) = delete;
|
||||
IntrusiveListHook& operator=(IntrusiveListHook&&) = delete;
|
||||
|
||||
void unlink() {
|
||||
TORCH_CHECK(is_linked());
|
||||
next_->prev_ = prev_;
|
||||
prev_->next_ = next_;
|
||||
next_ = this;
|
||||
prev_ = this;
|
||||
}
|
||||
|
||||
~IntrusiveListHook() {
|
||||
if (is_linked()) {
|
||||
unlink();
|
||||
}
|
||||
}
|
||||
|
||||
bool is_linked() const {
|
||||
return next_ != this;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename P, typename T>
|
||||
class ListIterator {
|
||||
static_assert(std::is_same_v<std::remove_const_t<P>, IntrusiveListHook>);
|
||||
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
|
||||
P* ptr_;
|
||||
|
||||
friend class IntrusiveList<T>;
|
||||
|
||||
public:
|
||||
using iterator_category = std::bidirectional_iterator_tag;
|
||||
using value_type = std::conditional_t<std::is_const_v<P>, const T, T>;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = value_type*;
|
||||
using reference = value_type&;
|
||||
|
||||
explicit ListIterator(P* ptr) : ptr_(ptr) {}
|
||||
|
||||
ListIterator(const ListIterator&) = default;
|
||||
ListIterator& operator=(const ListIterator&) = default;
|
||||
|
||||
template <
|
||||
typename Q,
|
||||
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
|
||||
ListIterator(const ListIterator<Q, T>& rhs) : ptr_(rhs.ptr_) {}
|
||||
|
||||
template <
|
||||
typename Q,
|
||||
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
|
||||
ListIterator& operator=(const ListIterator<Q, T>& rhs) {
|
||||
ptr_ = rhs.ptr_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename Q>
|
||||
bool operator==(const ListIterator<Q, T>& other) const {
|
||||
return ptr_ == other.ptr_;
|
||||
}
|
||||
|
||||
template <typename Q>
|
||||
bool operator!=(const ListIterator<Q, T>& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
auto& operator*() const {
|
||||
return static_cast<reference>(*ptr_);
|
||||
}
|
||||
|
||||
ListIterator& operator++() {
|
||||
TORCH_CHECK(ptr_);
|
||||
ptr_ = ptr_->next_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ListIterator& operator--() {
|
||||
TORCH_CHECK(ptr_);
|
||||
ptr_ = ptr_->prev_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
auto* operator->() const {
|
||||
return static_cast<pointer>(ptr_);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IntrusiveList {
|
||||
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
|
||||
|
||||
public:
|
||||
using iterator = ListIterator<IntrusiveListHook, T>;
|
||||
using const_iterator = ListIterator<const IntrusiveListHook, T>;
|
||||
|
||||
auto begin() const {
|
||||
return ++const_iterator{&head_};
|
||||
}
|
||||
|
||||
auto begin() {
|
||||
return ++iterator{&head_};
|
||||
}
|
||||
|
||||
auto end() const {
|
||||
return const_iterator{&head_};
|
||||
}
|
||||
|
||||
auto end() {
|
||||
return iterator{&head_};
|
||||
}
|
||||
|
||||
auto rbegin() const {
|
||||
return std::reverse_iterator{end()};
|
||||
}
|
||||
|
||||
auto rbegin() {
|
||||
return std::reverse_iterator{end()};
|
||||
}
|
||||
|
||||
auto rend() const {
|
||||
return std::reverse_iterator{begin()};
|
||||
}
|
||||
|
||||
auto rend() {
|
||||
return std::reverse_iterator{begin()};
|
||||
}
|
||||
|
||||
auto iterator_to(const T& n) const {
|
||||
return const_iterator{&n};
|
||||
}
|
||||
|
||||
auto iterator_to(T& n) {
|
||||
return iterator{&n};
|
||||
}
|
||||
|
||||
iterator insert(iterator pos, T& n) {
|
||||
n.link_before(pos.ptr_);
|
||||
return iterator{&n};
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
size_t ret = 0;
|
||||
for ([[maybe_unused]] auto& _ : *this) {
|
||||
ret++;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
~IntrusiveList() {
|
||||
while (head_.is_linked()) {
|
||||
head_.next_->unlink();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
IntrusiveListHook head_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
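A brief usage sketch of the list above, assuming element types derive from IntrusiveListHook as the static_asserts require; nodes are linked in place, so they must outlive the list or unlink themselves. Not part of the diff.

#include "torch/csrc/nativert/common/IntrusiveList.h"

struct Node : torch::nativert::IntrusiveListHook {
  explicit Node(int v) : value(v) {}
  int value;
};

void demoList() {
  torch::nativert::IntrusiveList<Node> list;
  Node a{1}, b{2};
  list.insert(list.end(), a);  // links the node; no copy is made
  list.insert(list.end(), b);
  for (auto& n : list) {
    (void)n.value;  // visits a then b
  }
}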
|
||||
48 torch/csrc/nativert/common/MPMCQueue.h (new file)
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* A simple multi-producer, multi-consumer queue rolled on our own.
*
* This is a mutex-guarded wrapper around std::deque that provides
* non-blocking readIfNotEmpty and writeIfNotFull operations.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <type_traits>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
// TODO (zhxchen17) Add wrapper for concurrentqueue.
|
||||
template <typename T>
|
||||
class MPMCQueue {
|
||||
static_assert(!std::is_reference_v<T>);
|
||||
|
||||
public:
|
||||
explicit MPMCQueue(size_t capacity) : capacity_(capacity) {}
|
||||
|
||||
bool readIfNotEmpty(T& out) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (storage_.empty()) {
|
||||
return false;
|
||||
}
|
||||
out = std::move(storage_.front());
|
||||
storage_.pop_front();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool writeIfNotFull(T&& in) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (storage_.size() == capacity_) {
|
||||
return false;
|
||||
}
|
||||
storage_.push_back(std::move(in));
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
std::deque<T> storage_;
|
||||
const size_t capacity_;
|
||||
};
|
||||
} // namespace torch::nativert
|
||||
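For context, a minimal usage sketch of the queue above with one producer and one consumer thread; the wiring and counts are illustrative, and only MPMCQueue itself comes from this file:

#include <thread>

#include "torch/csrc/nativert/common/MPMCQueue.h"

// Bounded hand-off between one producer and one consumer. Both sides must
// poll, since readIfNotEmpty/writeIfNotFull never block.
void mpmcQueueExample() {
  torch::nativert::MPMCQueue<int> queue(/*capacity=*/8);

  std::thread producer([&] {
    for (int i = 0; i < 100; ++i) {
      // Retry until there is room in the queue.
      while (!queue.writeIfNotFull(int{i})) {
        std::this_thread::yield();
      }
    }
  });

  std::thread consumer([&] {
    int value = 0;
    for (int received = 0; received < 100;) {
      if (queue.readIfNotEmpty(value)) {
        ++received;  // `value` holds the next item.
      } else {
        std::this_thread::yield();
      }
    }
  });

  producer.join();
  consumer.join();
}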
torch/csrc/nativert/common/Pytree.cpp (new file, 439 lines)
@@ -0,0 +1,439 @@
|
||||
#include "torch/csrc/nativert/common/Pytree.h"
|
||||
#include "torch/csrc/nativert/common/RecordFunction.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <string_view>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/util/Synchronized.h>
|
||||
#include <nlohmann/json.hpp> // @manual=fbsource//third-party/nlohmann-json:nlohmann-json
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
inline constexpr int kDefaultTreeSpecSerializationProtocol = 1;
|
||||
|
||||
c10::IValue dynamicToIValue(const nlohmann::json& obj) {
|
||||
if (obj.is_string()) {
|
||||
return obj.get<std::string>();
|
||||
} else if (obj.is_number_integer()) {
|
||||
return obj.get<int64_t>();
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
|
||||
}
|
||||
}
|
||||
|
||||
void treeFlatten(
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec,
|
||||
std::vector<c10::IValue>& leaves) {
|
||||
if (spec.isLeaf()) {
|
||||
leaves.push_back(tree);
|
||||
return;
|
||||
}
|
||||
auto flattenFn = spec.nodeDefCache().flattenFn;
|
||||
flattenFn(tree, spec, leaves);
|
||||
}
|
||||
|
||||
class PytreeNodeRegistry {
|
||||
public:
|
||||
PytreeNodeRegistry() {
|
||||
// Pre-register the built-in pytree container types (tuple, namedtuple, list, dict).
|
||||
registerNode(
|
||||
"builtins.tuple",
|
||||
NodeDef{
|
||||
[](const c10::IValue& tree,
|
||||
const TreeSpec& spec,
|
||||
std::vector<c10::IValue>& leaves) {
|
||||
const auto& tuple = tree.toTupleRef().elements();
|
||||
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
|
||||
for (size_t i = 0; i < tuple.size(); i++) {
|
||||
treeFlatten(tuple[i], spec.children(i), leaves);
|
||||
}
|
||||
},
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
||||
return c10::ivalue::Tuple::create(std::move(flats));
|
||||
},
|
||||
[](TreeMapNoReturnFn fn,
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
const auto& tuple = tree.toTupleRef().elements();
|
||||
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
|
||||
for (size_t i = 0; i < tuple.size(); i++) {
|
||||
leafApply(fn, tuple[i], spec.children(i));
|
||||
}
|
||||
}});
|
||||
const auto& tupleNodeDef = getNodeDef("builtins.tuple");
|
||||
registerNode(
|
||||
"collections.namedtuple",
|
||||
NodeDef{
|
||||
tupleNodeDef.flattenFn,
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
|
||||
return c10::ivalue::Tuple::create(std::move(flats));
|
||||
},
|
||||
tupleNodeDef.leafApplyFn,
|
||||
[](std::string_view context) { return nlohmann::json{context}; }});
|
||||
registerNode(
|
||||
"builtins.list",
|
||||
NodeDef{
|
||||
[](const c10::IValue& tree,
|
||||
const TreeSpec& spec,
|
||||
std::vector<c10::IValue>& leaves) {
|
||||
auto list = tree.toList();
|
||||
for (size_t i = 0; i < list.size(); i++) {
|
||||
treeFlatten(list[i], spec.children(i), leaves);
|
||||
}
|
||||
},
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
||||
c10::List<c10::IValue> list(c10::AnyType::get());
|
||||
list.reserve(flats.size());
|
||||
for (auto& flat : flats) {
|
||||
list.push_back(std::move(flat));
|
||||
}
|
||||
return list;
|
||||
},
|
||||
[](TreeMapNoReturnFn fn,
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
auto list = tree.toList();
|
||||
for (size_t i = 0; i < list.size(); i++) {
|
||||
leafApply(fn, list[i], spec.children(i));
|
||||
}
|
||||
}});
|
||||
registerNode(
|
||||
"torch.fx.immutable_collections.immutable_list",
|
||||
getNodeDef("builtins.list"));
|
||||
registerNode(
|
||||
"builtins.dict",
|
||||
NodeDef{
|
||||
[](const c10::IValue& tree,
|
||||
const TreeSpec& spec,
|
||||
std::vector<c10::IValue>& leaves) {
|
||||
auto dict = tree.toGenericDict();
|
||||
const auto& context = spec.context();
|
||||
TORCH_CHECK_EQ(dict.size(), context.size());
|
||||
size_t i = 0;
|
||||
for (const auto& keyObj : context) {
|
||||
auto key = dynamicToIValue(keyObj);
|
||||
auto it = dict.find(key);
|
||||
|
||||
if (it != dict.end()) {
|
||||
treeFlatten(it->value(), spec.children(i), leaves);
|
||||
} else {
|
||||
// when we have a dict with missing keys, we fill the missing
|
||||
// leaves with c10::IValue()
|
||||
for (size_t j = 0; j < spec.children(i).numLeaves(); ++j) {
|
||||
leaves.emplace_back();
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
},
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
c10::Dict<c10::IValue, c10::IValue> dict(
|
||||
c10::AnyType::get(), c10::AnyType::get());
|
||||
TORCH_CHECK(obj.is_array());
|
||||
TORCH_CHECK_EQ(obj.size(), flats.size());
|
||||
dict.reserve(flats.size());
|
||||
for (size_t i = 0; i < flats.size(); i++) {
|
||||
dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
|
||||
}
|
||||
return dict;
|
||||
},
|
||||
[](TreeMapNoReturnFn fn,
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
auto dict = tree.toGenericDict();
|
||||
const auto& context = spec.context();
|
||||
|
||||
TORCH_CHECK(
|
||||
dict.size() <= context.size(),
|
||||
"input dict has more keys than treeSepc");
|
||||
|
||||
size_t i = 0;
|
||||
for (const auto& keyObj : context) {
|
||||
auto key = dynamicToIValue(keyObj);
|
||||
auto it = dict.find(key);
|
||||
if (it != dict.end()) {
|
||||
leafApply(fn, it->value(), spec.children(i));
|
||||
} else {
|
||||
// when we have a dict with missing keys, we run fn
|
||||
// on leaves with value of c10::IValue()
|
||||
for (size_t j = 0; j < spec.children(i).numLeaves(); ++j) {
|
||||
fn(c10::IValue());
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
}});
|
||||
registerNode(
|
||||
"torch.fx.immutable_collections.immutable_dict",
|
||||
getNodeDef("builtins.dict"));
|
||||
}
|
||||
bool hasNodeDef(std::string_view typeName) const {
|
||||
return registry_.find(std::string{typeName}) != registry_.end();
|
||||
}
|
||||
const NodeDef& getNodeDef(std::string_view typeName) const {
|
||||
return registry_.at(std::string{typeName});
|
||||
}
|
||||
void registerNode(std::string_view typeName, NodeDef nodeDef) {
|
||||
TORCH_CHECK(!hasNodeDef(typeName));
|
||||
registry_.emplace(typeName, std::move(nodeDef));
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, NodeDef> registry_;
|
||||
};
|
||||
|
||||
c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
|
||||
static auto* registry = new c10::Synchronized<PytreeNodeRegistry>();
|
||||
return *registry;
|
||||
}
|
||||
|
||||
TreeSpec makeTreeSpec(const nlohmann::json& obj) {
|
||||
TORCH_CHECK(obj.is_object());
|
||||
TORCH_CHECK(obj.find("type") != obj.end());
|
||||
if (obj["type"].is_null()) {
|
||||
TORCH_CHECK_EQ(obj["children_spec"].size(), 0);
|
||||
TORCH_CHECK(obj["context"].is_null());
|
||||
return TreeSpec{};
|
||||
}
|
||||
const auto& name = obj["type"].get<std::string>();
|
||||
NodeDef nodeDefCache;
|
||||
getPytreeNodeRegistry().withLock([&](auto& registry) {
|
||||
TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
|
||||
nodeDefCache = registry.getNodeDef(name);
|
||||
});
|
||||
auto context = nodeDefCache.contextLoadFn(obj["context"].get<std::string>());
|
||||
const auto& childrenSpec = obj["children_spec"];
|
||||
TORCH_CHECK(childrenSpec.is_array());
|
||||
std::vector<TreeSpec> children;
|
||||
for (const auto& child : childrenSpec) {
|
||||
children.push_back(makeTreeSpec(child));
|
||||
}
|
||||
return TreeSpec(name, context, std::move(children), std::move(nodeDefCache));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
|
||||
getPytreeNodeRegistry().withLock([&](auto& registry) {
|
||||
registry.registerNode(typeName, std::move(nodeDef));
|
||||
});
|
||||
}
|
||||
|
||||
TreeSpec treeSpecLoads(std::string_view json) {
|
||||
const auto obj = nlohmann::json::parse(json);
|
||||
TORCH_CHECK(obj.is_array());
|
||||
TORCH_CHECK_EQ(obj.size(), 2);
|
||||
TORCH_CHECK_EQ(obj[0].get<int64_t>(), kDefaultTreeSpecSerializationProtocol);
|
||||
return makeTreeSpec(obj[1]);
|
||||
}
|
||||
|
||||
c10::IValue treeUnflatten(
|
||||
std::vector<c10::IValue> leaves,
|
||||
const TreeSpec& spec) {
|
||||
RecordFunction recordFunction("nativert::treeUnflatten");
|
||||
|
||||
TORCH_CHECK_EQ(leaves.size(), spec.numLeaves());
|
||||
if (spec.isLeaf()) {
|
||||
return std::move(leaves[0]);
|
||||
}
|
||||
auto unflattenFn = spec.nodeDefCache().unflattenFn;
|
||||
if (spec.allLeaves()) {
|
||||
return unflattenFn(std::move(leaves), spec.context());
|
||||
}
|
||||
size_t start = 0;
|
||||
std::vector<c10::IValue> childrenPytrees;
|
||||
for (const auto& child : spec.children()) {
|
||||
if (child.isLeaf()) {
|
||||
childrenPytrees.push_back(std::move(leaves[start]));
|
||||
start++;
|
||||
continue;
|
||||
}
|
||||
size_t numLeaves = child.numLeaves();
|
||||
std::vector<c10::IValue> slice(
|
||||
std::make_move_iterator(leaves.begin() + start),
|
||||
std::make_move_iterator(leaves.begin() + start + numLeaves));
|
||||
childrenPytrees.push_back(treeUnflatten(std::move(slice), child));
|
||||
start += numLeaves;
|
||||
}
|
||||
return unflattenFn(std::move(childrenPytrees), spec.context());
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> treeFlatten(
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
std::vector<c10::IValue> leaves;
|
||||
leaves.reserve(spec.numLeaves());
|
||||
treeFlatten(tree, spec, leaves);
|
||||
return leaves;
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> treeFlattenFromArgs(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec) {
|
||||
RecordFunction recordFunction("nativert::treeFlattenFromArgs");
|
||||
|
||||
TORCH_CHECK(!spec.isLeaf());
|
||||
TORCH_CHECK_EQ(spec.children().size(), 2);
|
||||
|
||||
std::vector<c10::IValue> leaves;
|
||||
leaves.reserve(spec.numLeaves());
|
||||
const auto& specArgs = spec.children(0);
|
||||
TORCH_CHECK(!specArgs.isLeaf());
|
||||
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
treeFlatten(args[i], specArgs.children(i), leaves);
|
||||
}
|
||||
|
||||
const auto& specKwargs = spec.children(1);
|
||||
TORCH_CHECK(!specKwargs.isLeaf());
|
||||
TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size());
|
||||
for (size_t i = 0; i < specKwargs.context().size(); i++) {
|
||||
treeFlatten(
|
||||
kwargs.at(specKwargs.context()[i].get<std::string>()),
|
||||
specKwargs.children(i),
|
||||
leaves);
|
||||
}
|
||||
return leaves;
|
||||
}
|
||||
|
||||
void leafApplyFromArgs(
|
||||
TreeMapNoReturnFn fn,
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec) {
|
||||
RecordFunction recordFunction("nativert::leafApplyFromArgs");
|
||||
|
||||
TORCH_CHECK(!spec.isLeaf());
|
||||
TORCH_CHECK_EQ(spec.children().size(), 2);
|
||||
|
||||
std::vector<c10::IValue> leaves;
|
||||
leaves.reserve(spec.numLeaves());
|
||||
const auto& specArgs = spec.children(0);
|
||||
TORCH_CHECK(!specArgs.isLeaf());
|
||||
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
leafApply(fn, args[i], specArgs.children(i));
|
||||
}
|
||||
|
||||
const auto& specKwargs = spec.children(1);
|
||||
TORCH_CHECK(!specKwargs.isLeaf());
|
||||
TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size());
|
||||
for (size_t i = 0; i < specKwargs.context().size(); i++) {
|
||||
leafApply(
|
||||
fn,
|
||||
kwargs.at(specKwargs.context()[i].get<std::string>()),
|
||||
specKwargs.children(i));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> treeFlattenToTensorList(
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
auto flats = treeFlatten(tree, spec);
|
||||
std::vector<at::Tensor> tensors;
|
||||
tensors.reserve(flats.size());
|
||||
for (const auto& flat : flats) {
|
||||
tensors.push_back(flat.toTensor());
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
c10::IValue
|
||||
treeMap(TreeMapFn f, const c10::IValue& tree, const TreeSpec& spec) {
|
||||
const auto flats = treeFlatten(tree, spec);
|
||||
std::vector<c10::IValue> mapped;
|
||||
mapped.reserve(flats.size());
|
||||
for (const auto& flat : flats) {
|
||||
mapped.push_back(f(flat));
|
||||
}
|
||||
return treeUnflatten(std::move(mapped), spec);
|
||||
}
|
||||
|
||||
c10::IValue argsToIValue(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
||||
c10::Dict<c10::IValue, c10::IValue> dict(
|
||||
c10::StringType::get(), c10::AnyType::get());
|
||||
for (const auto& [key, arg] : kwargs) {
|
||||
dict.insert(key, arg);
|
||||
}
|
||||
return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
|
||||
}
|
||||
|
||||
std::
|
||||
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
|
||||
treeMapArgs(
|
||||
TreeMapFn f,
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec) {
|
||||
const auto val = argsToIValue(args, kwargs);
|
||||
const auto mapVal = treeMap(f, val, spec);
|
||||
auto mapArgs =
|
||||
mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
|
||||
std::unordered_map<std::string, c10::IValue> mapKwargs;
|
||||
for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
|
||||
mapKwargs.emplace(entry.key().toStringRef(), entry.value());
|
||||
}
|
||||
return {std::move(mapArgs), std::move(mapKwargs)};
|
||||
}
|
||||
|
||||
void leafApply(
|
||||
TreeMapNoReturnFn fn,
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec) {
|
||||
if (spec.isLeaf()) {
|
||||
fn(tree);
|
||||
return;
|
||||
}
|
||||
auto leafApplyFn = spec.nodeDefCache().leafApplyFn;
|
||||
leafApplyFn(fn, tree, spec);
|
||||
}
|
||||
|
||||
nlohmann::json defaultContextLoadFn(std::string_view context) {
|
||||
return nlohmann::json::parse(context);
|
||||
}
|
||||
|
||||
c10::TypePtr TreeSpec::toAtenType() const {
|
||||
if (isLeaf()) {
|
||||
return c10::AnyType::get();
|
||||
} else if (uniformName_ == "builtins.tuple") {
|
||||
std::vector<c10::TypePtr> childrenType;
|
||||
for (const auto& childrenSpec : children_) {
|
||||
childrenType.emplace_back(childrenSpec.toAtenType());
|
||||
}
|
||||
return c10::TupleType::create(std::move(childrenType));
|
||||
} else if (
|
||||
uniformName_ == "builtins.list" ||
|
||||
uniformName_ == "torch.fx.immutable_collections.immutable_list") {
|
||||
if (children_.empty()) {
|
||||
return c10::ListType::create(c10::AnyType::get());
|
||||
} else {
|
||||
return c10::ListType::create(children_[0].toAtenType());
|
||||
}
|
||||
} else if (
|
||||
uniformName_ == "builtins.dict" ||
|
||||
uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
|
||||
if (children_.empty()) {
|
||||
return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
|
||||
} else {
|
||||
return c10::DictType::create(
|
||||
dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported uniform name: ", uniformName_.value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
torch/csrc/nativert/common/Pytree.h (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class TreeSpec;
|
||||
|
||||
using TreeFlattenFn =
|
||||
void (*)(const c10::IValue&, const TreeSpec&, std::vector<c10::IValue>&);
|
||||
using TreeUnflattenFn =
|
||||
c10::IValue (*)(std::vector<c10::IValue>, const nlohmann::json&);
|
||||
|
||||
using ContextLoadFn = nlohmann::json (*)(std::string_view);
|
||||
|
||||
using TreeMapFn = c10::function_ref<c10::IValue(const c10::IValue&)>;
|
||||
using TreeMapNoReturnFn = c10::function_ref<void(const c10::IValue&)>;
|
||||
|
||||
using LeafApplyFn =
|
||||
void (*)(TreeMapNoReturnFn, const c10::IValue&, const TreeSpec&);
|
||||
|
||||
nlohmann::json defaultContextLoadFn(std::string_view);
|
||||
|
||||
struct NodeDef {
|
||||
TreeFlattenFn flattenFn;
|
||||
TreeUnflattenFn unflattenFn;
|
||||
LeafApplyFn leafApplyFn;
|
||||
|
||||
ContextLoadFn contextLoadFn = defaultContextLoadFn;
|
||||
};
|
||||
|
||||
class TreeSpec {
|
||||
public:
|
||||
// Leaf node.
|
||||
TreeSpec() : numLeaves_(1) {}
|
||||
|
||||
// Non leaf node.
|
||||
TreeSpec(
|
||||
std::string_view uniformName,
|
||||
nlohmann::json context,
|
||||
std::vector<TreeSpec> children,
|
||||
NodeDef nodeDefCache)
|
||||
: uniformName_(uniformName),
|
||||
context_(std::move(context)),
|
||||
children_(std::move(children)),
|
||||
nodeDefCache_(nodeDefCache),
|
||||
numLeaves_(0) {
|
||||
for (auto& child : children_) {
|
||||
numLeaves_ += child.numLeaves();
|
||||
allLeaves_ &= child.isLeaf();
|
||||
}
|
||||
}
|
||||
|
||||
bool isLeaf() const {
|
||||
return !uniformName_;
|
||||
}
|
||||
|
||||
std::string_view uniformName() const {
|
||||
TORCH_CHECK(uniformName_);
|
||||
return uniformName_.value();
|
||||
}
|
||||
|
||||
const nlohmann::json& context() const {
|
||||
return context_;
|
||||
}
|
||||
|
||||
const auto& children() const {
|
||||
return children_;
|
||||
}
|
||||
|
||||
const TreeSpec& children(size_t i) const {
|
||||
return children_[i];
|
||||
}
|
||||
|
||||
const NodeDef& nodeDefCache() const {
|
||||
return nodeDefCache_;
|
||||
}
|
||||
|
||||
size_t numLeaves() const {
|
||||
return numLeaves_;
|
||||
}
|
||||
|
||||
bool allLeaves() const {
|
||||
return allLeaves_;
|
||||
}
|
||||
|
||||
c10::TypePtr toAtenType() const;
|
||||
|
||||
private:
|
||||
// Only non-leaf nodes have names.
|
||||
// Examples of uniform name: "builtins.tuple", "builtins.dict".
|
||||
std::optional<std::string> uniformName_;
|
||||
nlohmann::json context_;
|
||||
std::vector<TreeSpec> children_;
|
||||
|
||||
// Cached fields.
|
||||
NodeDef nodeDefCache_;
|
||||
size_t numLeaves_;
|
||||
bool allLeaves_ = true;
|
||||
};
|
||||
|
||||
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef);
|
||||
|
||||
// The serialized JSON tree spec should be dumped directly from
// treespec_dumps() in torch.utils._pytree.
|
||||
TreeSpec treeSpecLoads(std::string_view json);
|
||||
|
||||
c10::IValue treeUnflatten(
|
||||
std::vector<c10::IValue> leaves,
|
||||
const TreeSpec& spec);
|
||||
|
||||
std::vector<c10::IValue> treeFlatten(
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec);
|
||||
|
||||
std::vector<c10::IValue> treeFlattenFromArgs(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec);
|
||||
|
||||
std::vector<at::Tensor> treeFlattenToTensorList(
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec);
|
||||
|
||||
c10::IValue treeMap(TreeMapFn f, const c10::IValue& tree, const TreeSpec& spec);
|
||||
|
||||
c10::IValue TORCH_API argsToIValue(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs);
|
||||
|
||||
std::
|
||||
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
|
||||
treeMapArgs(
|
||||
TreeMapFn f,
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec);
|
||||
|
||||
void leafApply(
|
||||
TreeMapNoReturnFn f,
|
||||
const c10::IValue& tree,
|
||||
const TreeSpec& spec);
|
||||
|
||||
void leafApplyFromArgs(
|
||||
TreeMapNoReturnFn fn,
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& spec);
|
||||
|
||||
} // namespace torch::nativert
|
||||
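To make the flatten/unflatten contract above concrete, here is a minimal round-trip sketch. The spec JSON mirrors what torch.utils._pytree.treespec_dumps() emits for a two-element tuple of leaves; treat the exact string as illustrative rather than normative:

#include <ATen/ATen.h>

#include "torch/csrc/nativert/common/Pytree.h"

void pytreeRoundTripExample() {
  using namespace torch::nativert;

  // Protocol 1, a tuple node with two leaf children.
  constexpr std::string_view kTupleSpec = R"([1, {
    "type": "builtins.tuple",
    "context": "null",
    "children_spec": [
      {"type": null, "context": null, "children_spec": []},
      {"type": null, "context": null, "children_spec": []}
    ]}])";

  TreeSpec spec = treeSpecLoads(kTupleSpec);

  std::vector<c10::IValue> elements{at::ones({2, 2}), at::zeros({3})};
  c10::IValue tree = c10::ivalue::Tuple::create(std::move(elements));

  // Flatten into leaves, then rebuild the same nested structure.
  std::vector<c10::IValue> leaves = treeFlatten(tree, spec);
  c10::IValue rebuilt = treeUnflatten(std::move(leaves), spec);
  TORCH_CHECK(rebuilt.toTupleRef().elements().size() == 2);
}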
torch/csrc/nativert/common/RecordFunction.h (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/library.h> //@manual=//caffe2:libtorch
|
||||
#include "torch/csrc/autograd/record_function_ops.h" //@manual=//caffe2:libtorch
|
||||
namespace torch::nativert {
|
||||
|
||||
/**
|
||||
* RAII-style wrapper that behaves similarly to torch.profiler.record_function.
|
||||
*/
|
||||
class RecordFunction {
|
||||
public:
|
||||
RecordFunction() = delete;
|
||||
RecordFunction(const RecordFunction&) = default;
|
||||
RecordFunction& operator=(const RecordFunction&) = default;
|
||||
RecordFunction(RecordFunction&&) = default;
|
||||
RecordFunction& operator=(RecordFunction&&) = default;
|
||||
|
||||
explicit RecordFunction(const std::string& name) {
|
||||
recordFunction_ =
|
||||
torch::autograd::profiler::record_function_enter_new(name);
|
||||
}
|
||||
|
||||
~RecordFunction() {
|
||||
recordFunction_->record.end();
|
||||
}
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction>
|
||||
recordFunction_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
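A short sketch of how the guard above is meant to be used; loadAndRun is a hypothetical caller, not part of this change:

#include "torch/csrc/nativert/common/RecordFunction.h"

// The guard opens a profiler range on construction and closes it when the
// scope ends, so the whole body shows up as one event in profiler traces.
void loadAndRun() {
  torch::nativert::RecordFunction guard("nativert::loadAndRun");
  // ... run the model ...
}  // range ends here when `guard` is destroyed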
torch/csrc/nativert/common/Semaphore.h (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
// To use the moodycamel semaphore, we need to include the concurrentqueue
// header first (guarding against a pre-existing BLOCK_SIZE macro clash).
// Hiding this implementation detail here.
|
||||
#ifdef BLOCK_SIZE
|
||||
#pragma push_macro("BLOCK_SIZE")
|
||||
#undef BLOCK_SIZE
|
||||
#include "torch/csrc/nativert/common/concurrentqueue.h"
|
||||
#pragma pop_macro("BLOCK_SIZE")
|
||||
#else
|
||||
#include "torch/csrc/nativert/common/concurrentqueue.h"
|
||||
#endif
|
||||
|
||||
#include "torch/csrc/nativert/common/lightweightsemaphore.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
// A textbook semaphore wrapper. Nothing fancy.
// In the future, we can consider using C++20's std::counting_semaphore.
|
||||
class Semaphore {
|
||||
moodycamel::LightweightSemaphore impl_;
|
||||
|
||||
public:
|
||||
void release() {
|
||||
impl_.signal();
|
||||
}
|
||||
|
||||
void release(size_t n) {
|
||||
impl_.signal(n);
|
||||
}
|
||||
|
||||
void acquire() {
|
||||
impl_.wait();
|
||||
}
|
||||
};
|
||||
} // namespace torch::nativert
|
||||
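A minimal sketch of the wrapper above used as a wakeup signal between two threads (hypothetical wiring):

#include <thread>

#include "torch/csrc/nativert/common/Semaphore.h"

void semaphoreExample() {
  torch::nativert::Semaphore sem;
  int shared = 0;

  std::thread producer([&] {
    shared = 42;    // publish the work item
    sem.release();  // wake the consumer
  });

  std::thread consumer([&] {
    sem.acquire();  // blocks until release() is called
    // `shared` is now safe to read.
  });

  producer.join();
  consumer.join();
}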
torch/csrc/nativert/common/String.cpp (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::vector<std::string_view> split(std::string_view target, char delimiter) {
|
||||
std::vector<std::string_view> atoms;
|
||||
std::string_view buffer = target;
|
||||
while (buffer.size() > 0) {
|
||||
auto i = buffer.find(delimiter);
|
||||
if (i == std::string_view::npos) {
|
||||
atoms.push_back(buffer);
|
||||
buffer.remove_prefix(buffer.size());
|
||||
} else {
|
||||
atoms.push_back(buffer.substr(0, i));
|
||||
buffer.remove_prefix(i + 1);
|
||||
}
|
||||
}
|
||||
return atoms;
|
||||
}
|
||||
|
||||
std::string join(
|
||||
std::string_view delimiter,
|
||||
const std::vector<std::string>& keys) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < keys.size(); i++) {
|
||||
result << keys[i];
|
||||
if (i != keys.size() - 1) {
|
||||
result << delimiter;
|
||||
}
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
bool starts_with(std::string_view str, std::string_view prefix) {
|
||||
return str.size() >= prefix.size() &&
|
||||
str.compare(0, prefix.size(), prefix) == 0;
|
||||
}
|
||||
|
||||
bool ends_with(std::string_view str, std::string_view suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
torch/csrc/nativert/common/String.h (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::vector<std::string_view> split(std::string_view target, char delimiter);
|
||||
|
||||
std::string join(
|
||||
std::string_view delimiter,
|
||||
const std::vector<std::string>& keys);
|
||||
|
||||
// These helpers should be replaced by std::string_view::starts_with and
// std::string_view::ends_with once C++20 is available.
bool starts_with(std::string_view target, std::string_view prefix);
bool ends_with(std::string_view target, std::string_view suffix);
|
||||
|
||||
} // namespace torch::nativert
|
||||
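A quick sketch of the helpers declared above (the values in comments are what the implementations in String.cpp produce):

#include <string>
#include <vector>

#include "torch/csrc/nativert/common/String.h"

void stringHelpersExample() {
  using namespace torch::nativert;

  std::vector<std::string_view> parts = split("foo.bar.baz", '.');
  // parts == {"foo", "bar", "baz"}

  std::vector<std::string> keys{"foo", "bar", "baz"};
  std::string joined = join(".", keys);  // "foo.bar.baz"

  bool a = starts_with(joined, "foo");   // true
  bool b = ends_with(joined, ".baz");    // true
  (void)a;
  (void)b;
}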
torch/csrc/nativert/common/concurrentqueue.h (new file, 4449 lines; diff suppressed because it is too large)
torch/csrc/nativert/common/lightweightsemaphore.h (new file, 427 lines)
@@ -0,0 +1,427 @@
|
||||
// Provides an efficient implementation of a semaphore (LightweightSemaphore).
|
||||
// This is an extension of Jeff Preshing's semaphore implementation (licensed
|
||||
// under the terms of its separate zlib license) that has been adapted and
|
||||
// extended by Cameron Desrochers.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <cstddef> // For std::size_t
|
||||
#include <type_traits> // For std::make_signed<T>
|
||||
|
||||
#if defined(_WIN32)
|
||||
// Avoid including windows.h in a header; we only need a handful of
|
||||
// items, so we'll redeclare them here (this is relatively safe since
|
||||
// the API generally has to remain stable between Windows versions).
|
||||
// I know this is an ugly hack but it still beats polluting the global
|
||||
// namespace with thousands of generic names or adding a .cpp for nothing.
|
||||
extern "C" {
|
||||
struct _SECURITY_ATTRIBUTES;
|
||||
__declspec(dllimport) void* __stdcall CreateSemaphoreW(
|
||||
_SECURITY_ATTRIBUTES* lpSemaphoreAttributes,
|
||||
long lInitialCount,
|
||||
long lMaximumCount,
|
||||
const wchar_t* lpName);
|
||||
__declspec(dllimport) int __stdcall CloseHandle(void* hObject);
|
||||
__declspec(dllimport) unsigned long __stdcall WaitForSingleObject(
|
||||
void* hHandle,
|
||||
unsigned long dwMilliseconds);
|
||||
__declspec(dllimport) int __stdcall ReleaseSemaphore(
|
||||
void* hSemaphore,
|
||||
long lReleaseCount,
|
||||
long* lpPreviousCount);
|
||||
}
|
||||
#elif defined(__MACH__)
|
||||
#include <mach/mach.h> // @manual
|
||||
#elif defined(__MVS__)
|
||||
#include <zos-semaphore.h> // @manual
|
||||
#elif defined(__unix__)
|
||||
#include <semaphore.h>
|
||||
|
||||
#if defined(__GLIBC_PREREQ) && defined(_GNU_SOURCE)
|
||||
#if __GLIBC_PREREQ(2, 30)
|
||||
#define MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace moodycamel {
|
||||
namespace details {
|
||||
|
||||
// Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's
|
||||
// portable + lightweight semaphore implementations, originally from
|
||||
// https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h
|
||||
// LICENSE:
|
||||
// Copyright (c) 2015 Jeff Preshing
|
||||
//
|
||||
// This software is provided 'as-is', without any express or implied
|
||||
// warranty. In no event will the authors be held liable for any damages
|
||||
// arising from the use of this software.
|
||||
//
|
||||
// Permission is granted to anyone to use this software for any purpose,
|
||||
// including commercial applications, and to alter it and redistribute it
|
||||
// freely, subject to the following restrictions:
|
||||
//
|
||||
// 1. The origin of this software must not be misrepresented; you must not
|
||||
// claim that you wrote the original software. If you use this software
|
||||
// in a product, an acknowledgement in the product documentation would be
|
||||
// appreciated but is not required.
|
||||
// 2. Altered source versions must be plainly marked as such, and must not be
|
||||
// misrepresented as being the original software.
|
||||
// 3. This notice may not be removed or altered from any source distribution.
|
||||
#if defined(_WIN32)
|
||||
class Semaphore {
|
||||
private:
|
||||
void* m_hSema;
|
||||
|
||||
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
|
||||
public:
|
||||
Semaphore(int initialCount = 0) {
|
||||
assert(initialCount >= 0);
|
||||
const long maxLong = 0x7fffffff;
|
||||
m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr);
|
||||
assert(m_hSema);
|
||||
}
|
||||
|
||||
~Semaphore() {
|
||||
CloseHandle(m_hSema);
|
||||
}
|
||||
|
||||
bool wait() {
|
||||
const unsigned long infinite = 0xffffffff;
|
||||
return WaitForSingleObject(m_hSema, infinite) == 0;
|
||||
}
|
||||
|
||||
bool try_wait() {
|
||||
return WaitForSingleObject(m_hSema, 0) == 0;
|
||||
}
|
||||
|
||||
bool timed_wait(std::uint64_t usecs) {
|
||||
return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) == 0;
|
||||
}
|
||||
|
||||
void signal(int count = 1) {
|
||||
while (!ReleaseSemaphore(m_hSema, count, nullptr))
|
||||
;
|
||||
}
|
||||
};
|
||||
#elif defined(__MACH__)
|
||||
//---------------------------------------------------------
|
||||
// Semaphore (Apple iOS and OSX)
|
||||
// Can't use POSIX semaphores due to
|
||||
// http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html
|
||||
//---------------------------------------------------------
|
||||
class Semaphore {
|
||||
private:
|
||||
semaphore_t m_sema;
|
||||
|
||||
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
|
||||
public:
|
||||
Semaphore(int initialCount = 0) {
|
||||
assert(initialCount >= 0);
|
||||
kern_return_t rc = semaphore_create(
|
||||
mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount);
|
||||
assert(rc == KERN_SUCCESS);
|
||||
(void)rc;
|
||||
}
|
||||
|
||||
~Semaphore() {
|
||||
semaphore_destroy(mach_task_self(), m_sema);
|
||||
}
|
||||
|
||||
bool wait() {
|
||||
return semaphore_wait(m_sema) == KERN_SUCCESS;
|
||||
}
|
||||
|
||||
bool try_wait() {
|
||||
return timed_wait(0);
|
||||
}
|
||||
|
||||
bool timed_wait(std::uint64_t timeout_usecs) {
|
||||
mach_timespec_t ts;
|
||||
ts.tv_sec = static_cast<unsigned int>(timeout_usecs / 1000000);
|
||||
ts.tv_nsec = static_cast<int>((timeout_usecs % 1000000) * 1000);
|
||||
|
||||
// added in OSX 10.10:
|
||||
// https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html
|
||||
kern_return_t rc = semaphore_timedwait(m_sema, ts);
|
||||
return rc == KERN_SUCCESS;
|
||||
}
|
||||
|
||||
void signal() {
|
||||
while (semaphore_signal(m_sema) != KERN_SUCCESS)
|
||||
;
|
||||
}
|
||||
|
||||
void signal(int count) {
|
||||
while (count-- > 0) {
|
||||
while (semaphore_signal(m_sema) != KERN_SUCCESS)
|
||||
;
|
||||
}
|
||||
}
|
||||
};
|
||||
#elif defined(__unix__) || defined(__MVS__)
|
||||
//---------------------------------------------------------
|
||||
// Semaphore (POSIX, Linux, zOS)
|
||||
//---------------------------------------------------------
|
||||
class Semaphore {
|
||||
private:
|
||||
sem_t m_sema;
|
||||
|
||||
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
|
||||
|
||||
public:
|
||||
Semaphore(int initialCount = 0) {
|
||||
assert(initialCount >= 0);
|
||||
int rc = sem_init(&m_sema, 0, static_cast<unsigned int>(initialCount));
|
||||
assert(rc == 0);
|
||||
(void)rc;
|
||||
}
|
||||
|
||||
~Semaphore() {
|
||||
sem_destroy(&m_sema);
|
||||
}
|
||||
|
||||
bool wait() {
|
||||
// http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error
|
||||
int rc;
|
||||
do {
|
||||
rc = sem_wait(&m_sema);
|
||||
} while (rc == -1 && errno == EINTR);
|
||||
return rc == 0;
|
||||
}
|
||||
|
||||
bool try_wait() {
|
||||
int rc;
|
||||
do {
|
||||
rc = sem_trywait(&m_sema);
|
||||
} while (rc == -1 && errno == EINTR);
|
||||
return rc == 0;
|
||||
}
|
||||
|
||||
bool timed_wait(std::uint64_t usecs) {
|
||||
struct timespec ts;
|
||||
const int usecs_in_1_sec = 1000000;
|
||||
const int nsecs_in_1_sec = 1000000000;
|
||||
#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
|
||||
clock_gettime(CLOCK_MONOTONIC, &ts);
|
||||
#else
|
||||
clock_gettime(CLOCK_REALTIME, &ts);
|
||||
#endif
|
||||
ts.tv_sec += (time_t)(usecs / usecs_in_1_sec);
|
||||
ts.tv_nsec += (long)(usecs % usecs_in_1_sec) * 1000;
|
||||
// sem_timedwait bombs if you have more than 1e9 in tv_nsec
|
||||
// so we have to clean things up before passing it in
|
||||
if (ts.tv_nsec >= nsecs_in_1_sec) {
|
||||
ts.tv_nsec -= nsecs_in_1_sec;
|
||||
++ts.tv_sec;
|
||||
}
|
||||
|
||||
int rc;
|
||||
do {
|
||||
#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
|
||||
rc = sem_clockwait(&m_sema, CLOCK_MONOTONIC, &ts);
|
||||
#else
|
||||
rc = sem_timedwait(&m_sema, &ts);
|
||||
#endif
|
||||
} while (rc == -1 && errno == EINTR);
|
||||
return rc == 0;
|
||||
}
|
||||
|
||||
void signal() {
|
||||
while (sem_post(&m_sema) == -1)
|
||||
;
|
||||
}
|
||||
|
||||
void signal(int count) {
|
||||
while (count-- > 0) {
|
||||
while (sem_post(&m_sema) == -1)
|
||||
;
|
||||
}
|
||||
}
|
||||
};
|
||||
#else
|
||||
#error Unsupported platform! (No semaphore wrapper available)
|
||||
#endif
|
||||
|
||||
} // end namespace details
|
||||
|
||||
//---------------------------------------------------------
|
||||
// LightweightSemaphore
|
||||
//---------------------------------------------------------
|
||||
class LightweightSemaphore {
|
||||
public:
|
||||
typedef std::make_signed<std::size_t>::type ssize_t;
|
||||
|
||||
private:
|
||||
std::atomic<ssize_t> m_count;
|
||||
details::Semaphore m_sema;
|
||||
int m_maxSpins;
|
||||
|
||||
bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) {
|
||||
ssize_t oldCount;
|
||||
int spin = m_maxSpins;
|
||||
while (--spin >= 0) {
|
||||
oldCount = m_count.load(std::memory_order_relaxed);
|
||||
if ((oldCount > 0) &&
|
||||
m_count.compare_exchange_strong(
|
||||
oldCount,
|
||||
oldCount - 1,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed))
|
||||
return true;
|
||||
std::atomic_signal_fence(
|
||||
std::memory_order_acquire); // Prevent the compiler from collapsing
|
||||
// the loop.
|
||||
}
|
||||
oldCount = m_count.fetch_sub(1, std::memory_order_acquire);
|
||||
if (oldCount > 0)
|
||||
return true;
|
||||
if (timeout_usecs < 0) {
|
||||
if (m_sema.wait())
|
||||
return true;
|
||||
}
|
||||
if (timeout_usecs > 0 && m_sema.timed_wait((std::uint64_t)timeout_usecs))
|
||||
return true;
|
||||
// At this point, we've timed out waiting for the semaphore, but the
|
||||
// count is still decremented indicating we may still be waiting on
|
||||
// it. So we have to re-adjust the count, but only if the semaphore
|
||||
// wasn't signaled enough times for us too since then. If it was, we
|
||||
// need to release the semaphore too.
|
||||
while (true) {
|
||||
oldCount = m_count.load(std::memory_order_acquire);
|
||||
if (oldCount >= 0 && m_sema.try_wait())
|
||||
return true;
|
||||
if (oldCount < 0 &&
|
||||
m_count.compare_exchange_strong(
|
||||
oldCount,
|
||||
oldCount + 1,
|
||||
std::memory_order_relaxed,
|
||||
std::memory_order_relaxed))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ssize_t waitManyWithPartialSpinning(
|
||||
ssize_t max,
|
||||
std::int64_t timeout_usecs = -1) {
|
||||
assert(max > 0);
|
||||
ssize_t oldCount;
|
||||
int spin = m_maxSpins;
|
||||
while (--spin >= 0) {
|
||||
oldCount = m_count.load(std::memory_order_relaxed);
|
||||
if (oldCount > 0) {
|
||||
ssize_t newCount = oldCount > max ? oldCount - max : 0;
|
||||
if (m_count.compare_exchange_strong(
|
||||
oldCount,
|
||||
newCount,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed))
|
||||
return oldCount - newCount;
|
||||
}
|
||||
std::atomic_signal_fence(std::memory_order_acquire);
|
||||
}
|
||||
oldCount = m_count.fetch_sub(1, std::memory_order_acquire);
|
||||
if (oldCount <= 0) {
|
||||
if ((timeout_usecs == 0) || (timeout_usecs < 0 && !m_sema.wait()) ||
|
||||
(timeout_usecs > 0 &&
|
||||
!m_sema.timed_wait((std::uint64_t)timeout_usecs))) {
|
||||
while (true) {
|
||||
oldCount = m_count.load(std::memory_order_acquire);
|
||||
if (oldCount >= 0 && m_sema.try_wait())
|
||||
break;
|
||||
if (oldCount < 0 &&
|
||||
m_count.compare_exchange_strong(
|
||||
oldCount,
|
||||
oldCount + 1,
|
||||
std::memory_order_relaxed,
|
||||
std::memory_order_relaxed))
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max > 1)
|
||||
return 1 + tryWaitMany(max - 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
public:
|
||||
LightweightSemaphore(ssize_t initialCount = 0, int maxSpins = 10000)
|
||||
: m_count(initialCount), m_maxSpins(maxSpins) {
|
||||
assert(initialCount >= 0);
|
||||
assert(maxSpins >= 0);
|
||||
}
|
||||
|
||||
bool tryWait() {
|
||||
ssize_t oldCount = m_count.load(std::memory_order_relaxed);
|
||||
while (oldCount > 0) {
|
||||
if (m_count.compare_exchange_weak(
|
||||
oldCount,
|
||||
oldCount - 1,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool wait() {
|
||||
return tryWait() || waitWithPartialSpinning();
|
||||
}
|
||||
|
||||
bool wait(std::int64_t timeout_usecs) {
|
||||
return tryWait() || waitWithPartialSpinning(timeout_usecs);
|
||||
}
|
||||
|
||||
// Acquires between 0 and (greedily) max, inclusive
|
||||
ssize_t tryWaitMany(ssize_t max) {
|
||||
assert(max >= 0);
|
||||
ssize_t oldCount = m_count.load(std::memory_order_relaxed);
|
||||
while (oldCount > 0) {
|
||||
ssize_t newCount = oldCount > max ? oldCount - max : 0;
|
||||
if (m_count.compare_exchange_weak(
|
||||
oldCount,
|
||||
newCount,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed))
|
||||
return oldCount - newCount;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Acquires at least one, and (greedily) at most max
|
||||
ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) {
|
||||
assert(max >= 0);
|
||||
ssize_t result = tryWaitMany(max);
|
||||
if (result == 0 && max > 0)
|
||||
result = waitManyWithPartialSpinning(max, timeout_usecs);
|
||||
return result;
|
||||
}
|
||||
|
||||
ssize_t waitMany(ssize_t max) {
|
||||
ssize_t result = waitMany(max, -1);
|
||||
assert(result > 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
void signal(ssize_t count = 1) {
|
||||
assert(count >= 0);
|
||||
ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release);
|
||||
ssize_t toRelease = -oldCount < count ? -oldCount : count;
|
||||
if (toRelease > 0) {
|
||||
m_sema.signal((int)toRelease);
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t availableApprox() const {
|
||||
ssize_t count = m_count.load(std::memory_order_relaxed);
|
||||
return count > 0 ? static_cast<std::size_t>(count) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace moodycamel
|
||||
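A small sketch of the batch-acquire API above. The include order mirrors Semaphore.h, since this vendored header is normally pulled in alongside concurrentqueue.h; the counts are illustrative:

#include "torch/csrc/nativert/common/concurrentqueue.h"
#include "torch/csrc/nativert/common/lightweightsemaphore.h"

void lightweightSemaphoreExample() {
  moodycamel::LightweightSemaphore sem(/*initialCount=*/0, /*maxSpins=*/1000);

  sem.signal(5);                    // five units become available
  auto taken = sem.waitMany(3);     // blocks for at least one, greedily takes 3
  auto rest = sem.tryWaitMany(10);  // non-blocking; takes the remaining 2
  (void)taken;
  (void)rest;
}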
torch/csrc/nativert/executor/AOTIDelegateExecutor.cpp (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/common/RecordFunction.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
|
||||
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::optional<at::ScalarType> parsePrecision(
|
||||
const std::optional<T>& precision) {
|
||||
if (precision) {
|
||||
return static_cast<at::ScalarType>(*precision);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AOTIDelegateExecutor::AOTIDelegateExecutor(
|
||||
const std::string& path,
|
||||
std::shared_ptr<Weights> weights,
|
||||
c10::Device device,
|
||||
const ExecutorConfig& executorConfig,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader) {
|
||||
std::string aotInductorModelFileName = path + "/aotinductor_pickle_data.json";
|
||||
|
||||
LOG(INFO) << "Loading aotinductor model from archive path: "
|
||||
<< aotInductorModelFileName;
|
||||
|
||||
CHECK(packageReader) << "Package reader cannot be null for lowered modules";
|
||||
CHECK(packageReader->hasRecord(aotInductorModelFileName))
|
||||
<< "Missing record " << aotInductorModelFileName;
|
||||
const auto& [aotInductorModelData, aotInductorModelSize] =
|
||||
packageReader->getRecord(aotInductorModelFileName);
|
||||
|
||||
const std::string aotInductorModelSerialized{
|
||||
reinterpret_cast<char*>(aotInductorModelData.get()),
|
||||
aotInductorModelSize};
|
||||
|
||||
LOG(INFO) << "Loaded aot_inductor_model: " << aotInductorModelSerialized;
|
||||
|
||||
auto aotInductorModel =
|
||||
nlohmann::json::parse(aotInductorModelSerialized)
|
||||
.template get<torch::_export::AOTInductorModelPickleData>();
|
||||
|
||||
std::string tmpDir = extractToTemporaryFolder(packageReader, path);
|
||||
LOG(INFO) << "Extracted aot_inductor model to: " << tmpDir;
|
||||
|
||||
std::string modelName = aotInductorModel.get_library_basename();
|
||||
std::string modelPath = tmpDir + "/" + modelName;
|
||||
std::string externKernelNodesPath =
|
||||
tmpDir + "/" + modelName.substr(0, modelName.size() - 3) + ".json";
|
||||
|
||||
LOG(INFO) << "Creating AOTInductorModelImpl with device " << device.str();
|
||||
|
||||
// We have to read the custom_objs_config.json,
|
||||
// because the weights->customObjs_ keys are not the same as the arg names
|
||||
// in the externKernelNodesPath.
|
||||
std::string customObjsJsonPath = tmpDir + "/custom_objs_config.json";
|
||||
|
||||
std::ifstream customObjsJsonFile(customObjsJsonPath);
|
||||
std::unordered_map<std::string, c10::IValue> custom_objs;
|
||||
if (!customObjsJsonFile.is_open()) {
|
||||
// BC-compatible with old files that don't have custom_objs_config.json
|
||||
LOG(INFO) << "Unable to open file " + customObjsJsonPath;
|
||||
} else {
|
||||
LOG(INFO) << "Load custom object mapping from: " << customObjsJsonPath;
|
||||
|
||||
nlohmann::json customObjsJson;
|
||||
customObjsJsonFile >> customObjsJson;
|
||||
|
||||
// Populate custom_objs with the custom object names from the json file,
|
||||
// and the c10::IValue from the weights.
|
||||
for (auto& [customObjName, file_name] : customObjsJson.items()) {
|
||||
custom_objs[customObjName] = weights->getCustomObjByFileName(
|
||||
std::string(archive_spec::CONSTANTS_DIR) +
|
||||
file_name.get<std::string>());
|
||||
LOG(INFO) << "Copy custom object to FbProxyExecutor: " << customObjName
|
||||
<< " from " << file_name;
|
||||
}
|
||||
}
|
||||
aotInductorModelImpl_ =
|
||||
std::make_unique<torch::aot_inductor::AOTInductorModelImpl>(
|
||||
modelPath,
|
||||
tmpDir,
|
||||
aotInductorModel.get_input_names(),
|
||||
aotInductorModel.get_output_names(),
|
||||
parsePrecision(aotInductorModel.get_floating_point_input_dtype()),
|
||||
parsePrecision(aotInductorModel.get_floating_point_output_dtype()),
|
||||
externKernelNodesPath,
|
||||
device.str(),
|
||||
/*num_runtimes*/ executorConfig.maxNumConcurrentThreads,
|
||||
/*custom_objs*/ std::move(custom_objs));
|
||||
|
||||
auto constantInfos = aotInductorModelImpl_->getConstantInfos();
|
||||
for (const auto& [name, constantInfo] : constantInfos) {
|
||||
if (weights->contains(constantInfo.originalFqn)) {
|
||||
weightsNameMap_[constantInfo.originalFqn] = name;
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "AOTI's Constant " << constantInfo.originalFqn
|
||||
<< " is not found in weights, it's likely a constant created by AOTI constant folding. "
|
||||
<< "Valid weight FQNs are " << weights->toString();
|
||||
}
|
||||
}
|
||||
|
||||
// AOTIDelegateExecutor doesn't need to call processWeights or
// commitWeights here because the Executor's ctor already invokes them.
|
||||
}
|
||||
|
||||
void AOTIDelegateExecutor::processWeights(std::shared_ptr<Weights> weights) {
|
||||
LOG(INFO) << "AOTIDelegateExecutor processing weights";
|
||||
std::unordered_map<std::string, torch::Tensor*> newWeights;
|
||||
for (const auto& [original_fqn, name] : weightsNameMap_) {
|
||||
newWeights.emplace(name, &weights->at(original_fqn));
|
||||
}
|
||||
|
||||
aotInductorModelImpl_->updateInactiveConstantBuffer(std::move(newWeights));
|
||||
aotInductorModelImpl_->runConstantFolding(/*use_inactive*/ true);
|
||||
}
|
||||
|
||||
void AOTIDelegateExecutor::commitWeights() {
|
||||
LOG(INFO) << "AOTIDelegateExecutor committing weights";
|
||||
aotInductorModelImpl_->swapConstantBuffers();
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> AOTIDelegateExecutor::run(
|
||||
std::vector<at::Tensor>& inputs) {
|
||||
RecordFunction func("nativert::AOTIDelegateExecutor::run");
|
||||
|
||||
std::vector<at::Tensor> outputs = aotInductorModelImpl_->forward(inputs);
|
||||
return outputs;
|
||||
}
|
||||
} // namespace torch::nativert
|
||||
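The processWeights/commitWeights pair above implements a two-phase (stage, then swap) weight update. Below is a hypothetical sketch of how an owner of several delegates might drive it; only the DelegateExecutor and Weights types come from this change, the surrounding wiring is assumed:

#include <memory>
#include <vector>

#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include "torch/csrc/nativert/executor/Weights.h"

void updateDelegateWeights(
    std::vector<std::unique_ptr<torch::nativert::DelegateExecutor>>& delegates,
    std::shared_ptr<torch::nativert::Weights> newWeights) {
  // Stage the new values into each delegate's inactive buffer first...
  for (auto& delegate : delegates) {
    delegate->processWeights(newWeights);
  }
  // ...then swap all delegates over once everything has been staged.
  for (auto& delegate : delegates) {
    delegate->commitWeights();
  }
}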
torch/csrc/nativert/executor/AOTIDelegateExecutor.h (new file, 37 lines)
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
|
||||
#include "torch/csrc/nativert/executor/AOTInductorModelImpl.h" // @manual=//sigmoid/core/executor:aoti_model_impl
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Weights;
|
||||
class Node;
|
||||
|
||||
class AOTIDelegateExecutor : public DelegateExecutor {
|
||||
public:
|
||||
explicit AOTIDelegateExecutor(
|
||||
const std::string& path,
|
||||
std::shared_ptr<Weights> weights,
|
||||
c10::Device device,
|
||||
const ExecutorConfig& executorConfig,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader);
|
||||
~AOTIDelegateExecutor() override {}
|
||||
|
||||
void processWeights(std::shared_ptr<Weights> weights) override;
|
||||
|
||||
void commitWeights() override;
|
||||
|
||||
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<torch::aot_inductor::AOTInductorModelImpl>
|
||||
aotInductorModelImpl_;
|
||||
|
||||
// key is weight's original fqn, value is weight's name in AOTI
|
||||
std::unordered_map<std::string, std::string> weightsNameMap_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
torch/csrc/nativert/executor/AOTInductorModelImpl.cpp (new file, 605 lines)
@@ -0,0 +1,605 @@
|
||||
#include "torch/csrc/nativert/executor/AOTInductorModelImpl.h" // @manual
|
||||
// TODO Always use OSS proxy executor.
|
||||
#ifdef FBCODE_CAFFE2
|
||||
#include "deeplearning/aot_inductor/fb/FbProxyExecutor.h"
|
||||
#else
|
||||
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h> // @manual
|
||||
#endif
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h> // @manual
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib> // for getenv
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ATen/Context.h" // @manual
|
||||
#if defined(__SIGRID_USE_GPU__)
|
||||
#include "ATen/cuda/CUDAContext.h" // @manual
|
||||
#include "c10/cuda/CUDAStream.h" // @manual
|
||||
#endif // __SIGRID_USE_GPU__
|
||||
|
||||
#include "torch/csrc/nativert/common/FileUtil.h"
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct GetLastArgType;
|
||||
|
||||
template <typename T>
|
||||
struct tag {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
struct GetLastArgType<Function(Args...)> {
|
||||
using last_arg_type = typename decltype((tag<Args>{}, ...))::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AOTInductorCallImpl;
|
||||
|
||||
template <typename... Args>
|
||||
struct AOTInductorCallImpl<
|
||||
AOTIRuntimeError(AOTInductorModelContainerHandle*, Args...)> {
|
||||
// Special version for ModelContainer creation
|
||||
void operator()(
|
||||
AOTIRuntimeError (*f)(AOTInductorModelContainerHandle*, Args...),
|
||||
AOTInductorModelContainerHandle* handle,
|
||||
Args... args) {
|
||||
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args...));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
struct AOTInductorCallImpl<
|
||||
AOTIRuntimeError(AOTInductorModelContainerHandle, Args...)> {
|
||||
using Function = AOTIRuntimeError(AOTInductorModelContainerHandle, Args...);
|
||||
template <typename... ArgsWithoutLastArgument>
|
||||
auto operator()(
|
||||
Function* f,
|
||||
AOTInductorModelContainerHandle handle,
|
||||
ArgsWithoutLastArgument... args) {
|
||||
std::remove_pointer_t<typename GetLastArgType<Function>::last_arg_type>
|
||||
result;
|
||||
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args..., &result));
|
||||
return result;
|
||||
}
|
||||
void operator()(
|
||||
AOTIRuntimeError (*f)(AOTInductorModelContainerHandle, Args...),
|
||||
AOTInductorModelContainerHandle handle,
|
||||
Args... args) {
|
||||
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args...));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
auto AOTInductorCall(
|
||||
Function* f,
|
||||
AOTInductorModelContainerHandle handle,
|
||||
Args... args) {
|
||||
return AOTInductorCallImpl<Function>()(f, handle, args...);
|
||||
}
|
||||
|
||||
template <typename Function>
|
||||
auto AOTInductorCallCreate(
|
||||
Function* f,
|
||||
AOTInductorModelContainerHandle* handle,
|
||||
size_t num_runtimes,
|
||||
bool is_cpu,
|
||||
const char* cubin_dir) {
|
||||
return AOTInductorCallImpl<Function>()(
|
||||
f, handle, num_runtimes, is_cpu, cubin_dir);
|
||||
}
|
||||
|
||||
template <typename Function>
|
||||
auto AOTInductorCallCreateWithDevice(
|
||||
Function* f,
|
||||
AOTInductorModelContainerHandle* handle,
|
||||
size_t num_runtimes,
|
||||
const char* device_str,
|
||||
const char* cubin_dir) {
|
||||
return AOTInductorCallImpl<Function>()(
|
||||
f, handle, num_runtimes, device_str, cubin_dir);
|
||||
}
|
||||
std::string getFileBasename(const std::string& filename) {
|
||||
const auto slash = filename.rfind('/');
|
||||
return slash != std::string::npos ? filename.substr(slash + 1) : filename;
|
||||
}
|
||||
|
||||
// TODO: can we simply use std::filesystem::exists?
|
||||
inline bool fileExists(const std::string& name) {
|
||||
const auto fd =
|
||||
torch::nativert::openNoInt(name.c_str(), O_RDONLY | O_CLOEXEC);
|
||||
if (fd == -1) {
|
||||
return false;
|
||||
}
|
||||
torch::nativert::closeNoInt(fd);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<ProxyExecutor> makeProxyExecutor(
|
||||
const std::string& filename,
|
||||
bool is_cpu,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
|
||||
#ifdef FBCODE_CAFFE2
|
||||
return std::make_unique<FbProxyExecutor>(
|
||||
filename, is_cpu, std::move(custom_objs));
|
||||
#else
|
||||
return std::make_unique<OSSProxyExecutor>(filename, is_cpu);
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// private static
|
||||
std::vector<std::string> AOTInductorModelImpl::library_search_paths_;
|
||||
|
||||
AOTInductorModelImpl::AOTInductorModelImpl(
|
||||
const std::string& model_path,
|
||||
std::optional<std::string> cubin_dir,
|
||||
std::vector<std::string> input_names,
|
||||
std::vector<std::string> output_names,
|
||||
std::optional<at::ScalarType> input_dtype,
|
||||
std::optional<at::ScalarType> output_dtype,
|
||||
std::optional<std::string> extern_kernel_nodes_path,
|
||||
const std::string& device_str,
|
||||
int64_t num_runtimes,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs)
|
||||
: // handle_(dlopen(model_path.c_str(), RTLD_LAZY | RTLD_LOCAL)),
|
||||
libraryBasename_(getFileBasename(model_path)),
|
||||
libraryPath_(model_path),
|
||||
inputNames_(std::move(input_names)),
|
||||
outputNames_(std::move(output_names)),
|
||||
floatingPointInputDtype_(input_dtype),
|
||||
floatingPointOutputDtype_(output_dtype),
|
||||
deviceStr_(device_str) {
|
||||
LOG(INFO) << "Loading .so lib from " << model_path
|
||||
<< " onto device: " << device_str;
|
||||
handle_.reset(dlopen(model_path.c_str(), RTLD_NOW | RTLD_LOCAL));
|
||||
TORCH_CHECK(
|
||||
handle_ != nullptr, "could not dlopen ", model_path, ": ", dlerror());
|
||||
TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive");
|
||||
|
||||
if (extern_kernel_nodes_path.has_value()) {
|
||||
const std::string& filename = extern_kernel_nodes_path.value();
|
||||
if (fileExists(filename)) {
|
||||
LOG(INFO) << "Loading extern_kernel_nodes .json file from " << filename;
|
||||
|
||||
proxyExecutor_ =
|
||||
makeProxyExecutor(filename, is_cpu(), std::move(custom_objs));
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__SIGRID_USE_GPU__)
|
||||
// It's not clear what stream we want to use yet. Create a new one.
|
||||
// We could alternatively use the default stream, but that could cause extra
|
||||
// synchronization.
|
||||
using StreamGuard = std::unique_ptr<
|
||||
std::remove_pointer_t<cudaStream_t>,
|
||||
decltype(&cudaStreamDestroy)>;
|
||||
|
||||
std::optional<StreamGuard> creation_stream_guard = [&] {
|
||||
if (is_cpu()) {
|
||||
return std::optional<StreamGuard>();
|
||||
}
|
||||
cudaStream_t creation_stream;
|
||||
TORCH_CHECK(
|
||||
cudaStreamCreateWithFlags(&creation_stream, cudaStreamNonBlocking) ==
|
||||
cudaSuccess);
|
||||
return std::make_optional<StreamGuard>(creation_stream, cudaStreamDestroy);
|
||||
}();
|
||||
#endif // __SIGRID_USE_GPU__
|
||||
|
||||
#define LOAD_SYMBOL(var, name_str) \
|
||||
var = reinterpret_cast<decltype(var)>(dlsym(handle_.get(), name_str)); \
|
||||
TORCH_CHECK(var, "could not dlsym " name_str);
|
||||
|
||||
LOAD_SYMBOL(deleteFunc_, "AOTInductorModelContainerDelete");
|
||||
LOAD_SYMBOL(runFunc_, "AOTInductorModelContainerRun");
|
||||
LOAD_SYMBOL(getOutputNameFunc_, "AOTInductorModelContainerGetOutputName");
|
||||
LOAD_SYMBOL(getCallSpecFunc_, "AOTInductorModelContainerGetCallSpec");
|
||||
|
||||
// We never call these functions again after the constructor returns, so
|
||||
// there's no point in caching them in member variables.
|
||||
decltype(&AOTInductorModelContainerCreate) createFunc;
|
||||
decltype(&AOTInductorModelContainerGetInputName) getInputNameFunc;
|
||||
decltype(&AOTInductorModelContainerGetNumInputs) getNumInputsFunc;
|
||||
decltype(&AOTInductorModelContainerGetNumOutputs) getNumOutputsFunc;
|
||||
LOAD_SYMBOL(createFunc, "AOTInductorModelContainerCreate");
|
||||
LOAD_SYMBOL(getInputNameFunc, "AOTInductorModelContainerGetInputName");
|
||||
LOAD_SYMBOL(getNumInputsFunc, "AOTInductorModelContainerGetNumInputs");
|
||||
LOAD_SYMBOL(getNumOutputsFunc, "AOTInductorModelContainerGetNumOutputs");
|
||||
#undef LOAD_SYMBOL
|
||||
|
||||
#define LOAD_SYMBOL_WARN(var, name_str) \
|
||||
var = reinterpret_cast<decltype(var)>(dlsym(handle_.get(), name_str)); \
|
||||
if (!var) { \
|
||||
LOG(WARNING) << "Could not dlsym " << name_str; \
|
||||
}
|
||||
|
||||
// "AOTInductorModelContainerCreateWithDevice" is only available in the binary
|
||||
// compiled after Jan.15.2024
|
||||
decltype(&AOTInductorModelContainerCreateWithDevice) createFuncWithDevice;
|
||||
LOAD_SYMBOL_WARN(
|
||||
createFuncWithDevice, "AOTInductorModelContainerCreateWithDevice");
|
||||
|
||||
LOAD_SYMBOL_WARN(
|
||||
getNumConstantsFunc_, "AOTInductorModelContainerGetNumConstants");
|
||||
LOAD_SYMBOL_WARN(
|
||||
getConstantNameFunc_, "AOTInductorModelContainerGetConstantName");
|
||||
LOAD_SYMBOL_WARN(
|
||||
getConstantOriginalFQNFunc_,
|
||||
"AOTInductorModelContainerGetConstantOriginalFQN");
|
||||
LOAD_SYMBOL_WARN(
|
||||
getConstantFromFoldedFunc_,
|
||||
"AOTInductorModelContainerGetConstantFromFolded");
|
||||
LOAD_SYMBOL_WARN(
|
||||
getConstantTypeFunc_, "AOTInductorModelContainerGetConstantType");
|
||||
LOAD_SYMBOL_WARN(
|
||||
getConstantDtypeFunc_, "AOTInductorModelContainerGetConstantDtype");
|
||||
LOAD_SYMBOL_WARN(
|
||||
runConstantFoldingFunc_, "AOTInductorModelContainerRunConstantFolding");
|
||||
LOAD_SYMBOL_WARN(
|
||||
updateConstantBufferFunc_,
|
||||
"AOTInductorModelContainerUpdateConstantBuffer");
|
||||
LOAD_SYMBOL_WARN(
|
||||
updateInactiveConstantBufferFunc_,
|
||||
"AOTInductorModelContainerUpdateInactiveConstantBuffer");
|
||||
LOAD_SYMBOL_WARN(
|
||||
swapConstantBufferFunc_, "AOTInductorModelContainerSwapConstantBuffer");
|
||||
#undef LOAD_SYMBOL_WARN
|
||||
|
||||
if (createFuncWithDevice) {
|
||||
AOTInductorCallCreateWithDevice(
|
||||
createFuncWithDevice,
|
||||
&containerHandle_,
|
||||
num_runtimes,
|
||||
deviceStr_.c_str(),
|
||||
cubin_dir ? cubin_dir->c_str() : nullptr);
|
||||
} else {
|
||||
AOTInductorCallCreate(
|
||||
createFunc,
|
||||
&containerHandle_,
|
||||
num_runtimes,
|
||||
is_cpu(),
|
||||
cubin_dir ? cubin_dir->c_str() : nullptr);
|
||||
}
|
||||
|
||||
const auto num_inputs = AOTInductorCall(getNumInputsFunc, containerHandle_);
|
||||
const auto num_outputs = AOTInductorCall(getNumOutputsFunc, containerHandle_);
|
||||
TORCH_CHECK(
|
||||
inputNames_.size() == num_inputs,
|
||||
"the size of input_names is ",
|
||||
inputNames_.size(),
|
||||
", but the model expects ",
|
||||
num_inputs);
|
||||
TORCH_CHECK(
|
||||
outputNames_.size() == num_outputs,
|
||||
"the size of output_names is ",
|
||||
outputNames_.size(),
|
||||
", but the model expects ",
|
||||
num_outputs);
|
||||
|
||||
for (const auto idx : c10::irange(num_inputs)) {
|
||||
inputNameToIndex_.emplace(
|
||||
AOTInductorCall(getInputNameFunc, containerHandle_, idx), idx);
|
||||
}
|
||||
for (const auto idx : c10::irange(num_outputs)) {
|
||||
outputNameToIndex_.emplace(
|
||||
AOTInductorCall(getOutputNameFunc_, containerHandle_, idx), idx);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> AOTInductorModelImpl::processInputs(
|
||||
std::vector<torch::Tensor>& python_inputs) {
|
||||
RECORD_USER_SCOPE("AOTInductorModel::ProcessInputs");
|
||||
const auto num_inputs = inputNameToIndex_.size();
|
||||
TORCH_CHECK(
|
||||
python_inputs.size() == num_inputs,
|
||||
"User passed ",
|
||||
python_inputs.size(),
|
||||
" inputs, but the model expects ",
|
||||
num_inputs);
|
||||
std::vector<torch::Tensor> inputs(python_inputs.size());
|
||||
for (int python_input_idx = 0; python_input_idx < inputNames_.size();
|
||||
python_input_idx++) {
|
||||
auto input_name = inputNames_[python_input_idx];
|
||||
auto& input = python_inputs[python_input_idx];
|
||||
if (floatingPointInputDtype_ != std::nullopt && input.is_floating_point()) {
|
||||
// Need to keep input alive; cannot just stash result of to()
|
||||
// call in a local!
|
||||
input = input.to(*floatingPointInputDtype_);
|
||||
}
|
||||
// FIXME: get the correct aot_input_idx once we figure out name-mapping
// in AOTInductor.
// Currently, we make the strong assumption that python_input_idx
// (fx inputs) is the same as aot_input_idx.
|
||||
// const auto aot_input_idx = input_name_to_index_.at(input_name);
|
||||
const auto aot_input_idx = python_input_idx;
|
||||
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
|
||||
inputs[aot_input_idx] = input;
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> AOTInductorModelImpl::processOutputs(
|
||||
std::vector<torch::Tensor>&& outputs) {
|
||||
if (floatingPointOutputDtype_.has_value()) {
|
||||
for (auto& output : outputs) {
|
||||
if (output.is_floating_point()) {
|
||||
output = output.to(*floatingPointOutputDtype_);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::move(outputs);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> AOTInductorModelImpl::forward(
|
||||
std::vector<torch::Tensor>& python_inputs) {
|
||||
RECORD_USER_SCOPE("AOTInductorModel::Forward");
|
||||
TORCH_CHECK(!python_inputs.empty());
|
||||
|
||||
std::vector<torch::Tensor> input_tensors = processInputs(python_inputs);
|
||||
|
||||
// For outputs, we only allocate a vector to hold returned tensor handles,
|
||||
// not allocating the actual output tensor storage here
|
||||
const auto num_outputs = outputNameToIndex_.size();
|
||||
std::vector<AtenTensorHandle> output_handles(num_outputs);
|
||||
|
||||
{
|
||||
auto input_handles =
|
||||
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
|
||||
input_tensors);
|
||||
#if defined(__SIGRID_USE_GPU__)
|
||||
const auto device = python_inputs[0].device();
|
||||
AOTInductorStreamHandle stream_handle = is_cpu() ? nullptr : [&] {
|
||||
const auto& cuda_stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
const auto stream_id = cuda_stream.stream();
|
||||
return reinterpret_cast<AOTInductorStreamHandle>(stream_id);
|
||||
}();
|
||||
#else
|
||||
AOTInductorStreamHandle stream_handle = nullptr;
|
||||
#endif // __SIGRID_USE_GPU__
|
||||
AOTIProxyExecutorHandle proxy_executor_handle =
|
||||
reinterpret_cast<AOTIProxyExecutorHandle>(proxyExecutor_.get());
|
||||
|
||||
RECORD_USER_SCOPE("AOTInductorModel::AOTInductorRuntime");
|
||||
AOTIRuntimeError run_result = runFunc_(
|
||||
containerHandle_,
|
||||
input_handles.data(),
|
||||
input_tensors.size(),
|
||||
output_handles.data(),
|
||||
output_handles.size(),
|
||||
stream_handle,
|
||||
proxy_executor_handle);
|
||||
if (run_result != AOTI_RUNTIME_SUCCESS) {
|
||||
std::stringstream ss;
|
||||
ss << "AOTInductorModel run failed with input spec: ";
|
||||
for (const auto& i : python_inputs) {
|
||||
ss << i.sizes() << ":" << i.dtype() << ", ";
|
||||
}
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
|
||||
return processOutputs(
|
||||
torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
|
||||
output_handles.data(), output_handles.size()));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const char*> AOTInductorModelImpl::get_call_spec() {
|
||||
std::vector<const char*> call_spec = {nullptr, nullptr};
|
||||
getCallSpecFunc_(containerHandle_, call_spec.data(), &call_spec[1]);
|
||||
return call_spec;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, ConstantInfo>
|
||||
AOTInductorModelImpl::getConstantInfos() const {
|
||||
TORCH_CHECK(
|
||||
getNumConstantsFunc_, "getNumConstantsFunc_ was not loaded from .so");
|
||||
TORCH_CHECK(
|
||||
getConstantNameFunc_, "getConstantNameFunc_ was not loaded from .so");
|
||||
TORCH_CHECK(
|
||||
getConstantOriginalFQNFunc_,
|
||||
"getConstantOriginalFQNFunc_ was not loaded from .so");
|
||||
TORCH_CHECK(
|
||||
getConstantDtypeFunc_, "getConstantDtypeFunc_ was not loaded from .so");
|
||||
|
||||
std::unordered_map<std::string, ConstantInfo> result;
|
||||
auto num_constants = AOTInductorCall(getNumConstantsFunc_, containerHandle_);
|
||||
for (size_t i = 0; i < num_constants; ++i) {
|
||||
const auto name =
|
||||
AOTInductorCall(getConstantNameFunc_, containerHandle_, i);
|
||||
const auto original_fqn =
|
||||
AOTInductorCall(getConstantOriginalFQNFunc_, containerHandle_, i);
|
||||
const auto dtype =
|
||||
AOTInductorCall(getConstantDtypeFunc_, containerHandle_, i);
|
||||
|
||||
ConstantType constant_type = ConstantType::Unknown;
|
||||
if (getConstantTypeFunc_) {
|
||||
constant_type = static_cast<ConstantType>(
|
||||
AOTInductorCall(getConstantTypeFunc_, containerHandle_, i));
|
||||
}
|
||||
if (getConstantFromFoldedFunc_ &&
|
||||
AOTInductorCall(getConstantFromFoldedFunc_, containerHandle_, i)) {
|
||||
continue;
|
||||
}
|
||||
TORCH_CHECK(original_fqn, "Cannot find original FQN of constant ", name);
|
||||
|
||||
result.emplace(
|
||||
name,
|
||||
ConstantInfo{
|
||||
static_cast<at::ScalarType>(dtype), original_fqn, constant_type});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::runConstantFolding(bool use_inactive) {
|
||||
if (!runConstantFoldingFunc_) {
|
||||
// We will just return if runtime constant folding doesn't exist.
// Only models compiled after Feb 2024 have this capability.
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(__SIGRID_USE_GPU__)
|
||||
AOTInductorStreamHandle stream_handle = is_cpu() ? nullptr : [&] {
|
||||
const auto& cuda_stream = at::cuda::getCurrentCUDAStream();
|
||||
const auto stream_id = cuda_stream.stream();
|
||||
return reinterpret_cast<AOTInductorStreamHandle>(stream_id);
|
||||
}();
|
||||
#else
|
||||
AOTInductorStreamHandle stream_handle = nullptr;
|
||||
#endif // __SIGRID_USE_GPU__
|
||||
AOTIProxyExecutorHandle proxy_executor_handle =
|
||||
reinterpret_cast<AOTIProxyExecutorHandle>(proxyExecutor_.get());
|
||||
|
||||
auto result = runConstantFoldingFunc_(
|
||||
containerHandle_, use_inactive, stream_handle, proxy_executor_handle);
|
||||
|
||||
TORCH_CHECK(
|
||||
result == AOTI_RUNTIME_SUCCESS, "Unable to run constant folding.");
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::updateConstantBuffer(
|
||||
std::unordered_map<std::string, torch::Tensor*>&& constants,
|
||||
bool use_inactive,
|
||||
bool validate_full_update) {
|
||||
TORCH_CHECK(
|
||||
updateConstantBufferFunc_,
|
||||
"updateConstantBufferFunc_ was not loaded from .so");
|
||||
|
||||
auto result = updateConstantBufferFunc_(
|
||||
containerHandle_,
|
||||
(AOTInductorConstantMapHandle)&constants,
|
||||
use_inactive,
|
||||
validate_full_update);
|
||||
TORCH_CHECK(
|
||||
result == AOTI_RUNTIME_SUCCESS, "Unable to update constant buffer");
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::updateInactiveConstantBuffer(
|
||||
std::unordered_map<std::string, torch::Tensor*>&& constants) {
|
||||
TORCH_CHECK(
|
||||
updateInactiveConstantBufferFunc_,
|
||||
"updateInactiveConstantBufferFunc_ was not loaded from .so");
|
||||
|
||||
auto result = updateInactiveConstantBufferFunc_(
|
||||
containerHandle_, (AOTInductorConstantMapHandle)&constants);
|
||||
TORCH_CHECK(
|
||||
result == AOTI_RUNTIME_SUCCESS,
|
||||
"Unable to update inactive constant buffer");
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::swapConstantBuffers() {
|
||||
TORCH_CHECK(
|
||||
swapConstantBufferFunc_,
|
||||
"swapConstantBufferFunc_ was not loaded from .so");
|
||||
|
||||
auto result = swapConstantBufferFunc_(containerHandle_);
|
||||
TORCH_CHECK(
|
||||
result == AOTI_RUNTIME_SUCCESS, "Unable to swap constant buffers");
|
||||
}
|
||||
|
||||
thread_local std::unordered_map<std::string, std::string>
|
||||
AOTInductorModelImpl::lib_name_to_path_;
|
||||
|
||||
thread_local bool AOTInductorModelImpl::deserialize_pickled_model_{true};
|
||||
|
||||
thread_local std::optional<std::string> AOTInductorModelImpl::cubin_dir_;
|
||||
|
||||
thread_local std::unordered_map<std::string, std::string>
|
||||
AOTInductorModelImpl::extern_kernels_spec_name_to_path_;
|
||||
|
||||
void AOTInductorModelImpl::registerLibraryNameToPathMap(
|
||||
std::unordered_map<std::string, std::string> map) {
|
||||
std::ostringstream ss;
|
||||
ss << "{\n";
|
||||
for (const auto& [k, v] : map) {
|
||||
ss << " " << k << " => " << v << ",\n";
|
||||
}
|
||||
ss << "}";
|
||||
|
||||
LOG(INFO) << "Registering .so lib paths: " << ss.str();
|
||||
lib_name_to_path_ = std::move(map);
|
||||
}
|
||||
|
||||
std::string AOTInductorModelImpl::getFullPathForLibraryName(
|
||||
const std::string& name) {
|
||||
auto path = lib_name_to_path_.find(name);
|
||||
std::ostringstream ss;
|
||||
ss << "{\n";
|
||||
for (const auto& [k, v] : lib_name_to_path_) {
|
||||
ss << " " << k << " => " << v << ",\n";
|
||||
}
|
||||
if ((path == lib_name_to_path_.end()) ||
|
||||
(!std::filesystem::exists(path->second))) {
|
||||
for (const auto& lib_path : library_search_paths_) {
|
||||
std::string fullpath =
|
||||
lib_path + std::filesystem::path::preferred_separator + name;
|
||||
if (std::filesystem::exists(fullpath)) {
|
||||
return fullpath;
|
||||
}
|
||||
ss << " searched for " << name << " at " << lib_path << ",\n";
|
||||
}
|
||||
}
|
||||
ss << "}";
|
||||
TORCH_CHECK(
|
||||
path != lib_name_to_path_.end(),
|
||||
"could not find full path for AOTInductor model .so named ",
|
||||
name,
|
||||
". available paths: ",
|
||||
ss.str());
|
||||
return path->second;
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::setCubinDir(std::optional<std::string> cubin_dir) {
|
||||
cubin_dir_ = cubin_dir;
|
||||
}
|
||||
|
||||
std::optional<std::string> AOTInductorModelImpl::getCubinDir() {
|
||||
return cubin_dir_;
|
||||
}
|
||||
|
||||
void AOTInductorModelImpl::registerExternKernelsSpecNameToPathMap(
|
||||
std::unordered_map<std::string, std::string> map) {
|
||||
std::ostringstream ss;
|
||||
ss << "{\n";
|
||||
for (const auto& [k, v] : map) {
|
||||
ss << " " << k << " => " << v << ",\n";
|
||||
}
|
||||
ss << "}";
|
||||
|
||||
LOG(INFO) << "Registering extern kernels spec paths: " << ss.str();
|
||||
extern_kernels_spec_name_to_path_ = std::move(map);
|
||||
}
|
||||
|
||||
std::optional<std::string>
|
||||
AOTInductorModelImpl::getFullPathForExternKernelsSpecName(
|
||||
const std::string& name) {
|
||||
auto it = extern_kernels_spec_name_to_path_.find(name);
|
||||
if (it == extern_kernels_spec_name_to_path_.end()) {
|
||||
LOG(INFO) << "Didn't find extern kernels spec file for " << name;
|
||||
return {};
|
||||
}
|
||||
if (!std::filesystem::exists(it->second)) {
|
||||
TORCH_CHECK(false, "Extern kernels spec file doesn't exist: ", it->second);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool AOTInductorModelImpl::getDeserializePickledModel() {
|
||||
return deserialize_pickled_model_;
|
||||
}
|
||||
|
||||
// Set a thread-local flag to disable actual loading from the .so file
// so that the same module can be reused later on.
|
||||
void AOTInductorModelImpl::setDeserializePickledModel(
|
||||
bool deserializePickledModel) {
|
||||
deserialize_pickled_model_ = deserializePickledModel;
|
||||
}
|
||||
|
||||
} // namespace torch::aot_inductor
|
||||
201
torch/csrc/nativert/executor/AOTInductorModelImpl.h
Normal file
@ -0,0 +1,201 @@
|
||||
#pragma once
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <torch/csrc/inductor/aoti_runtime/interface.h> // @manual
|
||||
#include <torch/csrc/inductor/aoti_runtime/model.h> // @manual
|
||||
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h> // @manual
|
||||
#include <torch/torch.h> // @manual=//caffe2:torch-cpp
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "c10/util/FbcodeMaps.h"
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
struct ConstantInfo {
|
||||
at::ScalarType dtype;
|
||||
std::string originalFqn;
|
||||
ConstantType type;
|
||||
};
|
||||
|
||||
class AOTInductorModelImpl {
|
||||
public:
|
||||
explicit AOTInductorModelImpl(
|
||||
const std::string& model_path,
|
||||
std::optional<std::string> cubin_dir,
|
||||
std::vector<std::string> input_names,
|
||||
std::vector<std::string> output_names,
|
||||
std::optional<at::ScalarType> input_dtype,
|
||||
std::optional<at::ScalarType> output_dtype,
|
||||
std::optional<std::string> extern_kernel_nodes_path,
|
||||
const std::string& device_str,
|
||||
int64_t num_runtimes = 2,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs =
|
||||
std::nullopt);
|
||||
|
||||
~AOTInductorModelImpl() {
|
||||
if (containerHandle_) {
|
||||
deleteFunc_(containerHandle_);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> forward(std::vector<torch::Tensor>& inputs);
|
||||
|
||||
std::vector<const char*> get_call_spec();
|
||||
|
||||
std::unordered_map<std::string, ConstantInfo> getConstantInfos() const;
|
||||
|
||||
void updateConstantBuffer(
|
||||
std::unordered_map<std::string, torch::Tensor*>&& constants,
|
||||
bool use_inactive,
|
||||
bool validate_full_update);
|
||||
|
||||
void updateInactiveConstantBuffer(
|
||||
std::unordered_map<std::string, torch::Tensor*>&& constants);
|
||||
|
||||
void runConstantFolding(bool use_inactive);
|
||||
|
||||
void swapConstantBuffers();
|
||||
|
||||
void profile(
|
||||
std::vector<torch::Tensor>& inputs,
|
||||
const std::string& filename,
|
||||
size_t num_iters);
|
||||
|
||||
// If we need to move or copy this object, then we should just
|
||||
// define a unique_ptr with deleter for the handle.
|
||||
AOTInductorModelImpl(const AOTInductorModelImpl&) = delete;
|
||||
AOTInductorModelImpl& operator=(const AOTInductorModelImpl&) = delete;
|
||||
|
||||
static void registerLibraryNameToPathMap(
|
||||
std::unordered_map<std::string, std::string> map);
|
||||
|
||||
static std::string getFullPathForLibraryName(const std::string& name);
|
||||
|
||||
static void setCubinDir(std::optional<std::string> cubin_dir);
|
||||
|
||||
static std::optional<std::string> getCubinDir();
|
||||
|
||||
static void registerExternKernelsSpecNameToPathMap(
|
||||
std::unordered_map<std::string, std::string> mapping);
|
||||
|
||||
static std::optional<std::string> getFullPathForExternKernelsSpecName(
|
||||
const std::string& name);
|
||||
|
||||
static bool getDeserializePickledModel();
|
||||
|
||||
static void setDeserializePickledModel(bool deserializePickledModel);
|
||||
|
||||
/*
|
||||
* Returns a path to .so file (either relative or absolute).
|
||||
*/
|
||||
const std::string& libraryPath() const {
|
||||
return libraryPath_;
|
||||
}
|
||||
|
||||
const std::string& libraryBasename() const {
|
||||
return libraryBasename_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& inputNames() const {
|
||||
return inputNames_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& outputNames() const {
|
||||
return outputNames_;
|
||||
}
|
||||
|
||||
const std::optional<at::ScalarType> floatingPointInputDtype() const {
|
||||
return floatingPointInputDtype_;
|
||||
}
|
||||
|
||||
const std::optional<at::ScalarType> floatingPointOutputDtype() const {
|
||||
return floatingPointOutputDtype_;
|
||||
}
|
||||
|
||||
static void add_library_search_path(const std::string& path) {
|
||||
library_search_paths_.push_back(path);
|
||||
}
|
||||
|
||||
bool is_cpu() const {
|
||||
return deviceStr_ == "cpu";
|
||||
}
|
||||
|
||||
private:
|
||||
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
|
||||
static std::vector<std::string> library_search_paths_;
|
||||
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
|
||||
static thread_local std::unordered_map<std::string, std::string>
|
||||
lib_name_to_path_;
|
||||
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
|
||||
static thread_local std::optional<std::string> cubin_dir_;
|
||||
|
||||
/*
|
||||
* Example:
|
||||
* {
|
||||
* "aaa.json": "/tmp/abcdef/aaa.json",
|
||||
* "bbb.json": "/tmp/abcdef/bbb.json",
|
||||
* }
|
||||
*/
|
||||
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
|
||||
static thread_local std::unordered_map<std::string, std::string>
|
||||
extern_kernels_spec_name_to_path_;
|
||||
|
||||
static thread_local bool deserialize_pickled_model_;
|
||||
|
||||
struct DlcloseDeleter {
|
||||
void operator()(void* p) const {
|
||||
if (p) {
|
||||
dlclose(p);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<torch::Tensor> processInputs(
|
||||
std::vector<torch::Tensor>& python_inputs);
|
||||
|
||||
std::vector<torch::Tensor> processOutputs(
|
||||
std::vector<torch::Tensor>&& outputs);
|
||||
|
||||
std::unique_ptr<void, DlcloseDeleter> handle_ = nullptr;
|
||||
AOTInductorModelContainerHandle containerHandle_;
|
||||
|
||||
decltype(&AOTInductorModelContainerDelete) deleteFunc_ = nullptr;
|
||||
decltype(&AOTInductorModelContainerRun) runFunc_ = nullptr;
|
||||
decltype(&AOTInductorModelContainerGetOutputName) getOutputNameFunc_ =
|
||||
nullptr;
|
||||
decltype(&AOTInductorModelContainerGetCallSpec) getCallSpecFunc_ = nullptr;
|
||||
decltype(&AOTInductorModelContainerGetNumConstants) getNumConstantsFunc_{
|
||||
nullptr};
|
||||
decltype(&AOTInductorModelContainerGetConstantName) getConstantNameFunc_{
|
||||
nullptr};
|
||||
decltype(&AOTInductorModelContainerGetConstantOriginalFQN)
|
||||
getConstantOriginalFQNFunc_{nullptr};
|
||||
decltype(&AOTInductorModelContainerGetConstantFromFolded)
|
||||
getConstantFromFoldedFunc_{nullptr};
|
||||
decltype(&AOTInductorModelContainerGetConstantType) getConstantTypeFunc_{
|
||||
nullptr};
|
||||
decltype(&AOTInductorModelContainerGetConstantDtype) getConstantDtypeFunc_{
|
||||
nullptr};
|
||||
decltype(&AOTInductorModelContainerRunConstantFolding)
|
||||
runConstantFoldingFunc_{nullptr};
|
||||
decltype(&AOTInductorModelContainerUpdateConstantBuffer)
|
||||
updateConstantBufferFunc_{nullptr};
|
||||
decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer)
|
||||
updateInactiveConstantBufferFunc_{nullptr};
|
||||
decltype(&AOTInductorModelContainerSwapConstantBuffer)
|
||||
swapConstantBufferFunc_{nullptr};
|
||||
|
||||
const std::string libraryBasename_;
|
||||
const std::string libraryPath_;
|
||||
const std::vector<std::string> inputNames_;
|
||||
const std::vector<std::string> outputNames_;
|
||||
const std::optional<at::ScalarType> floatingPointInputDtype_;
|
||||
const std::optional<at::ScalarType> floatingPointOutputDtype_;
|
||||
c10::FastMap<const char*, size_t> inputNameToIndex_;
|
||||
c10::FastMap<const char*, size_t> outputNameToIndex_;
|
||||
|
||||
std::unique_ptr<ProxyExecutor> proxyExecutor_;
|
||||
std::string deviceStr_;
|
||||
};
|
||||
} // namespace torch::aot_inductor
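// [Illustrative usage sketch, not part of this file.] A minimal example of how the
// wrapper above is intended to be driven: construct it with the path to the compiled
// .so plus the declared input/output names, then call forward() with one tensor per
// input. The model path and the "x"/"y" names below are placeholders.
namespace torch::aot_inductor::example {
inline std::vector<torch::Tensor> runCompiledModelExample(
    std::vector<torch::Tensor>& inputs) {
  AOTInductorModelImpl model(
      /*model_path=*/"/tmp/model/model.so",
      /*cubin_dir=*/std::nullopt,
      /*input_names=*/{"x"},
      /*output_names=*/{"y"},
      /*input_dtype=*/std::nullopt, // leave floating point inputs untouched
      /*output_dtype=*/std::nullopt, // leave floating point outputs untouched
      /*extern_kernel_nodes_path=*/std::nullopt,
      /*device_str=*/"cpu");
  // forward() converts the tensors to handles, invokes
  // AOTInductorModelContainerRun, and steals the returned handles back into
  // torch::Tensor outputs.
  return model.forward(inputs);
}
} // namespace torch::aot_inductor::example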
|
||||
166
torch/csrc/nativert/executor/ConstantFolder.cpp
Normal file
@ -0,0 +1,166 @@
|
||||
|
||||
|
||||
#include "torch/csrc/nativert/executor/ConstantFolder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
/*
side effects:
1. nodes deemed const-foldable are unlinked from the graph.
they are still owned by the graph (i.e., show up in graph.nodeOwner_)
but are not accessible through the node iterator.

2. kernels associated with const-foldable nodes are removed from the
'kernels' input

3. values deemed foldable are marked as such, and their producers are removed
*/
|
||||
|
||||
void ConstantFolder::unlinkConstants(
|
||||
std::vector<std::unique_ptr<OpKernel>>& kernels) {
|
||||
TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size())
|
||||
<< "graph node count and kernel count should be equal";
|
||||
|
||||
unlinked_ = true;
|
||||
|
||||
/* resolve all of the nodes that are const foldable */
|
||||
|
||||
c10::FastMap<Node*, uint32_t> nodeDynInputs;
|
||||
nodeDynInputs.reserve(graph_.nodes().size());
|
||||
|
||||
c10::FastMap<const Node*, std::unique_ptr<OpKernel>*> nodeKernels;
|
||||
nodeKernels.reserve(graph_.nodes().size());
|
||||
|
||||
const auto* input = &*graph_.nodes().begin();
|
||||
const auto* output = &*graph_.nodes().end();
|
||||
|
||||
{ // ignore prim.Input and prim.Output
|
||||
auto ct = 0;
|
||||
for (auto& n : graph_.nodes()) {
|
||||
if (&n == input || &n == output) {
|
||||
continue;
|
||||
}
|
||||
nodeDynInputs[&n] = n.numInputs();
|
||||
nodeKernels[&n] = &kernels[++ct];
|
||||
}
|
||||
}
|
||||
|
||||
const auto& inputsToWeights = graph_.signature().inputsToWeights();
|
||||
for (const auto& [inputName, weightName] : inputsToWeights) {
|
||||
for (auto* user : graph_.getValue(inputName)->users()) {
|
||||
if (user == input || user == output) {
|
||||
continue;
|
||||
}
|
||||
nodeDynInputs[user] -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// set of foldable nodes for dedupe purposes
|
||||
c10::FastSet<const Node*> foldable;
|
||||
|
||||
std::queue<Node*> constFoldableCandidates;
|
||||
for (auto& [node, ct] : nodeDynInputs) {
|
||||
if (ct++ /* will be decremented once dequeued */ == 0) {
|
||||
constFoldableCandidates.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
while (!constFoldableCandidates.empty()) {
|
||||
auto* candidate = constFoldableCandidates.front();
|
||||
constFoldableCandidates.pop();
|
||||
if (auto& ct = nodeDynInputs[candidate]; --ct == 0) {
|
||||
foldable.insert(candidate);
|
||||
foldables_.push_back(Foldable{
|
||||
.node = candidate, .kernel = std::move(*nodeKernels[candidate])});
|
||||
|
||||
candidate->unlink();
|
||||
|
||||
for (auto* user : candidate->users()) {
|
||||
if (user == output) {
|
||||
continue;
|
||||
}
|
||||
if (foldable.find(user) == foldable.end()) {
|
||||
constFoldableCandidates.push(user);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* out : candidate->outputs()) {
|
||||
auto* value = graph_.getValue(out->name());
|
||||
|
||||
value->setIsFolded();
|
||||
|
||||
// we only store folded values if there is a non-foldable user
|
||||
if (const auto& users = value->users();
|
||||
std::any_of(users.begin(), users.end(), [&](const auto* u) {
|
||||
return foldable.find(u) == foldable.end();
|
||||
})) {
|
||||
foldedOutputValueIds_.insert(value->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& f : foldables_) {
|
||||
VLOG(1) << "Const-folded node: " << *f.node;
|
||||
}
|
||||
|
||||
// remove moved (i.e., associated w/ const-folded nodes) kernels
|
||||
// from the input kernel vector
|
||||
kernels.erase(
|
||||
std::remove_if(
|
||||
kernels.begin(),
|
||||
kernels.end(),
|
||||
[](const auto& k) { return k == nullptr; }),
|
||||
kernels.end());
|
||||
|
||||
graph_.renumberValues();
|
||||
graph_.finalize();
|
||||
graph_.lint();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
/*
|
||||
side effects:
|
||||
1. weights whose users are ONLY const-foldable nodes will be removed
|
||||
from the 'weights' input
|
||||
*/
|
||||
|
||||
void ConstantFolder::evaluate(Weights& weights) {
|
||||
CHECK(unlinked_)
|
||||
<< "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants";
|
||||
|
||||
weights.validateAllWeightsLoaded();
|
||||
|
||||
ExecutionFrame frame(graph_);
|
||||
frame.setWeights(weights);
|
||||
|
||||
c10::FastMap<ValueId, c10::IValue> foldedValues;
|
||||
|
||||
for (const auto& f : foldables_) {
|
||||
f.kernel->compute(frame);
|
||||
|
||||
for (auto&& [i, out] : nativert::enumerate(f.node->outputs())) {
|
||||
if (foldedOutputValueIds_.find(out->id()) !=
|
||||
foldedOutputValueIds_.end()) {
|
||||
foldedValues[out->id()] = f.kernel->output(i, frame);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto it = std::make_move_iterator(foldedValues.begin());
|
||||
it != std::make_move_iterator(foldedValues.end());
|
||||
++it) {
|
||||
auto [v, iv] = std::move(*it);
|
||||
weights.setConstFoldedValue(v, std::move(iv));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
55
torch/csrc/nativert/executor/ConstantFolder.h
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
struct Foldable {
|
||||
Node* node;
|
||||
std::unique_ptr<OpKernel> kernel;
|
||||
};
|
||||
|
||||
class ConstantFolder {
|
||||
public:
|
||||
explicit ConstantFolder(Graph& graph) : graph_(graph) {}
|
||||
|
||||
/*
1. identify nodes without dynamic inputs, mark them as foldable

2. traverse the nodes deemed foldable as if they were being evaluated,
pushing nodes that become foldable after their inputs have been traversed.

unlink foldable nodes from the graph in the topological order in which
they were traversed, storing each node and its associated kernel (moved
from 'kernels') as a Foldable in ConstantFolder
*/
|
||||
void unlinkConstants(
|
||||
/* kernels for const-foldable nodes will be removed from this vector */
|
||||
std::vector<std::unique_ptr<OpKernel>>& kernels);
|
||||
|
||||
/*
|
||||
1. execute foldables_ on an execution frame initialized with the passed-in
|
||||
weights, calling Weights::setConstFoldedValue if the folded value is
|
||||
consumed by a non-foldable node
|
||||
*/
|
||||
void evaluate(Weights& weights);
|
||||
|
||||
private:
|
||||
Graph& graph_;
|
||||
// unlinked nodes sorted in their topological order
|
||||
// s.t., they can be evaluated sequentially
|
||||
std::vector<Foldable> foldables_;
|
||||
|
||||
bool unlinked_{false};
|
||||
|
||||
c10::FastSet<ValueId> foldedOutputValueIds_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
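// [Illustrative usage sketch, not part of this file.] How the two phases documented
// above are expected to be driven by a caller, assuming `graph`, a per-node `kernels`
// vector, and fully loaded `weights` were created elsewhere (e.g. by KernelFactory).
namespace torch::nativert::example {
inline void foldConstantsExample(
    Graph& graph,
    std::vector<std::unique_ptr<OpKernel>>& kernels,
    Weights& weights) {
  ConstantFolder folder(graph);
  // Phase 1: unlink const-foldable nodes and move their kernels out of `kernels`.
  folder.unlinkConstants(kernels);
  // Phase 2: run the unlinked kernels once so folded values consumed by
  // non-foldable nodes are cached on `weights` via setConstFoldedValue.
  folder.evaluate(weights);
}
} // namespace torch::nativert::example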
|
||||
52
torch/csrc/nativert/executor/DelegateExecutor.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include "c10/util/Logging.h"
|
||||
|
||||
#include "torch/csrc/nativert/common/FileUtil.h"
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::string extractToTemporaryFolder(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader,
|
||||
const std::string& targetPath) {
|
||||
char outputDir[] = "/tmp/model_XXXXXX";
|
||||
char* tempdir = mkdtemp(outputDir);
|
||||
TORCH_CHECK(
|
||||
tempdir != nullptr,
|
||||
"error creating temporary directory for compiled model. errno: ",
|
||||
errno);
|
||||
|
||||
std::vector<std::string> allRecords = packageReader->getAllRecords();
|
||||
|
||||
for (const auto& path : allRecords) {
|
||||
if (!starts_with(path, targetPath) || ends_with(path, "/")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
packageReader->hasRecord(path), path, " not present in model package");
|
||||
auto [dataPointer, dataSize] = packageReader->getRecord(path);
|
||||
|
||||
std::string fileName = path.substr(path.rfind('/') + 1);
|
||||
std::string extractedFilename = std::string(outputDir) + "/" + fileName;
|
||||
|
||||
VLOG(1) << "Extracting " << extractedFilename
|
||||
<< " from archive path: " << path << " size: " << dataSize;
|
||||
|
||||
File extracted(extractedFilename, O_CREAT | O_WRONLY, 0640);
|
||||
const auto bytesWritten =
|
||||
writeFull(extracted.fd(), dataPointer.get(), dataSize);
|
||||
TORCH_CHECK(
|
||||
bytesWritten != -1,
|
||||
"failure copying from archive path ",
|
||||
path,
|
||||
" to temporary file");
|
||||
}
|
||||
|
||||
return std::string(outputDir);
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
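// [Illustrative usage sketch, not part of this file.] Using the helper above to copy
// a delegate's payload out of a PT2 archive before dlopen'ing it. The archive path
// and the record prefix are placeholders.
namespace torch::nativert::example {
inline std::string extractDelegatePayloadExample() {
  auto reader = std::make_shared<caffe2::serialize::PyTorchStreamReader>(
      "/tmp/model.pt2");
  // Every record under the prefix is written into a fresh /tmp/model_XXXXXX dir;
  // the directory path is returned so the caller can locate the extracted files.
  return extractToTemporaryFolder(reader, "data/aotinductor/model/");
}
} // namespace torch::nativert::example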
|
||||
47
torch/csrc/nativert/executor/DelegateExecutor.h
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Weights;
|
||||
|
||||
std::string extractToTemporaryFolder(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader,
|
||||
const std::string& targetPath);
|
||||
|
||||
// This is the extension point for delegate backends.
|
||||
// Please refer to AOTIDelegateExecutor as an example.
|
||||
class DelegateExecutor {
|
||||
public:
|
||||
virtual ~DelegateExecutor() {}
|
||||
|
||||
// Runtime calls processWeights() to pass the weights to the delegate backend.
// Typically, a backend would perform some form of validation and processing,
// such as constant folding. The processed weights stay in the inactive
// state until commitWeights() is called.
//
// Weights tensors are co-owned by the runtime and the delegate backend.
// In the regular inference run() path, neither the runtime nor the delegate
// backend can modify the weight tensors.
// To support in-place weight updates, weight tensors are exposed by
// ModelRunner::getWeights() to an external caller. The external caller can
// then modify the weight tensors in-place. Such mutation instantly
// affects the weight tensors in the delegate backend.
// When a weight tensor is no longer used by the delegate backend, the backend
// must release it by decreasing a refcount. The runtime also releases the
// refcount for a weight tensor if it's no longer active. The underlying
// storage for weight tensors will be freed when the refcount reaches 0.
|
||||
virtual void processWeights(std::shared_ptr<Weights> weights) = 0;
|
||||
|
||||
// This call activates the processed weights.
|
||||
virtual void commitWeights() = 0;
|
||||
|
||||
virtual std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) = 0;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
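// [Illustrative usage sketch, not part of this file.] The weight-update sequence
// described in the comments above: new weights are first processed into the
// inactive slot, then committed to make them active.
namespace torch::nativert::example {
inline void updateDelegateWeightsExample(
    DelegateExecutor& delegate,
    std::shared_ptr<Weights> newWeights) {
  delegate.processWeights(std::move(newWeights)); // validate/process, stays inactive
  delegate.commitWeights(); // activate the processed weights
}
} // namespace torch::nativert::example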
|
||||
141
torch/csrc/nativert/executor/ExecutionFrame.cpp
Normal file
@ -0,0 +1,141 @@
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
ExecutionFrame::ExecutionFrame(const Graph& graph)
|
||||
: graph_(graph),
|
||||
allValues_(graph.numValues()),
|
||||
persistent_(graph.numValues()) {
|
||||
// load constant SymInts into execution frame
|
||||
for (const auto& [valueId, constSymintValue] :
|
||||
graph_.getConstantSymIntValues()) {
|
||||
setPersistentIValue(valueId, constSymintValue);
|
||||
}
|
||||
|
||||
for (const Node& node : graph_.nodes()) {
|
||||
if (node.target() == "torch.ops.higher_order.run_const_graph") {
|
||||
const auto& const_graph =
|
||||
std::get<std::unique_ptr<Graph>>(node.attributes().at(0).value);
|
||||
for (size_t i = 0; i < node.outputs().size(); ++i) {
|
||||
foldedConstIds_[std::string{const_graph->outputs().at(i)->name()}] =
|
||||
node.outputs()[i]->id();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
|
||||
: ExecutionFrame(graph) {
|
||||
setWeights(weights);
|
||||
}
|
||||
|
||||
void ExecutionFrame::setWeights(const Weights& weights) {
|
||||
weightVersion_ = weights.version();
|
||||
|
||||
const auto& inputsToWeights = graph_.signature().inputsToWeights();
|
||||
for (const auto& [inputName, weightName] : inputsToWeights) {
|
||||
const Value* value = graph_.getValue(inputName);
|
||||
setPersistentIValue(value->id(), weights.at(weightName));
|
||||
}
|
||||
|
||||
const auto& inputsToCustomObjs = graph_.signature().inputsToCustomObjs();
|
||||
for (const auto& [inputName, customObjName] : inputsToCustomObjs) {
|
||||
const Value* value = graph_.getValue(inputName);
|
||||
setPersistentIValue(value->id(), weights.getCustomObj(customObjName));
|
||||
}
|
||||
|
||||
for (const auto& [value, tensor] : weights.getFoldedConsts()) {
|
||||
setPersistentIValue(foldedConstIds_.at(value), tensor);
|
||||
}
|
||||
|
||||
for (const auto& [v, iv] : weights.getConstFoldedValues()) {
|
||||
setPersistentIValue(v, iv);
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionFrame::ExecutionFrame(
|
||||
const Graph& graph,
|
||||
size_t numValues,
|
||||
const std::vector<ValueId>&,
|
||||
const std::vector<ValueId>&)
|
||||
: graph_(graph) {
|
||||
allValues_.resize(numValues);
|
||||
}
|
||||
|
||||
void ExecutionFrame::setIValue(ValueId id, c10::IValue ivalue) {
|
||||
DCHECK(id < allValues_.size());
|
||||
allValues_[id] = std::move(ivalue);
|
||||
}
|
||||
|
||||
at::Tensor ExecutionFrame::getTensor(ValueId id) const {
|
||||
const auto& ivalue = getIValue(id);
|
||||
if (C10_LIKELY(ivalue.isTensor())) {
|
||||
return ivalue.toTensor();
|
||||
} else {
|
||||
throw std::runtime_error("getTensor called on non-tensor value");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> ExecutionFrame::getUserOutputs() const {
|
||||
std::vector<c10::IValue> ret;
|
||||
ret.reserve(graph_.userOutputs().size());
|
||||
for (const auto& outputValue : graph_.userOutputs()) {
|
||||
if (std::holds_alternative<Value*>(outputValue)) {
|
||||
Value* valuePtr = std::get<Value*>(outputValue);
|
||||
if (valuePtr) {
|
||||
const auto& id = valuePtr->id();
|
||||
ret.push_back(getIValue(id));
|
||||
}
|
||||
} else if (std::holds_alternative<Constant>(outputValue)) {
|
||||
const Constant& constValue = std::get<Constant>(outputValue);
|
||||
ret.push_back(constantToIValue(constValue));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
c10::List<c10::IValue> ExecutionFrame::getUserOutputsAsTensorList() const {
|
||||
c10::List<c10::IValue> ret(c10::TensorType::get());
|
||||
ret.reserve(graph_.userOutputs().size());
|
||||
for (const auto& outputValue : graph_.userOutputs()) {
|
||||
if (std::holds_alternative<Value*>(outputValue)) {
|
||||
Value* valuePtr = std::get<Value*>(outputValue);
|
||||
if (valuePtr) {
|
||||
const auto& id = valuePtr->id();
|
||||
ret.push_back(getIValue(id));
|
||||
}
|
||||
} else if (std::holds_alternative<Constant>(outputValue)) {
|
||||
const Constant& constValue = std::get<Constant>(outputValue);
|
||||
ret.push_back(constantToIValue(constValue));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> ExecutionFrame::getAllOutputs()
|
||||
const {
|
||||
std::unordered_map<std::string, at::Tensor> ret;
|
||||
for (const auto& outputValue : graph_.outputs()) {
|
||||
const auto& name = outputValue->name();
|
||||
const auto& id = outputValue->id();
|
||||
ret.emplace(name, getTensor(id));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> ExecutionFrame::getBufferMutations()
|
||||
const {
|
||||
// key is buffer name, value is tensor to be written to buffer
|
||||
std::unordered_map<std::string, at::Tensor> ret;
|
||||
const auto& buffersToMutate = graph_.signature().buffersToMutate();
|
||||
for (auto& [mutationOutputName, bufferName] : buffersToMutate) {
|
||||
const auto& id = graph_.getValue(mutationOutputName)->id();
|
||||
ret.emplace(bufferName, getTensor(id));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
131
torch/csrc/nativert/executor/ExecutionFrame.h
Normal file
@ -0,0 +1,131 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
/**
|
||||
* This class encapsulates the stateful values of an execution,
|
||||
* most notably, the tensor values passed between nodes, aka intermediate
|
||||
* activations.
|
||||
*/
|
||||
class ExecutionFrame {
|
||||
public:
|
||||
// Constructor for weight-less graph, used for higher order ops, e.g.
|
||||
// torch.cond
|
||||
explicit ExecutionFrame(const Graph& graph);
|
||||
|
||||
explicit ExecutionFrame(const Graph& graph, const Weights& weights);
|
||||
|
||||
// Constructor for testing purposes
|
||||
explicit ExecutionFrame(
|
||||
const Graph& graph,
|
||||
size_t numValues,
|
||||
const std::vector<ValueId>& graphInputIds,
|
||||
const std::vector<ValueId>& graphOutputIds);
|
||||
|
||||
~ExecutionFrame() {}
|
||||
|
||||
std::vector<c10::IValue> getUserOutputs() const;
|
||||
c10::List<c10::IValue> getUserOutputsAsTensorList() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> getBufferMutations() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> getAllOutputs() const;
|
||||
|
||||
const c10::IValue& getIValue(ValueId id, bool allowNone = true) const {
|
||||
const auto& iValue = allValues_[id];
|
||||
if (allowNone && iValue.isNone()) {
|
||||
return iValue;
|
||||
}
|
||||
DCHECK(!iValue.isNone());
|
||||
return iValue;
|
||||
}
|
||||
|
||||
c10::IValue& getIValue(ValueId id, bool allowNone = true) {
|
||||
auto& iValue = allValues_[id];
|
||||
if (allowNone && iValue.isNone()) {
|
||||
return iValue;
|
||||
}
|
||||
DCHECK(!iValue.isNone());
|
||||
return iValue;
|
||||
}
|
||||
|
||||
void setIValue(ValueId id, c10::IValue ivalue);
|
||||
|
||||
at::Tensor getTensor(ValueId id) const;
|
||||
|
||||
std::vector<at::Tensor> getTensorVector(ValueId id) const {
|
||||
return getIValue(id).toTensorVector();
|
||||
}
|
||||
|
||||
int64_t getSymInt(ValueId id) const {
|
||||
return getIValue(id).toInt();
|
||||
}
|
||||
|
||||
double getSymFloat(ValueId id) const {
|
||||
return getIValue(id).toDouble();
|
||||
}
|
||||
|
||||
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
|
||||
setIValue(id, std::move(ivalue));
|
||||
persistent_[id] = true;
|
||||
}
|
||||
|
||||
void releaseValue(ValueId id) {
|
||||
CHECK(!persistent_[id]) << "Cannot release persistent value";
|
||||
allValues_[id] = c10::IValue();
|
||||
}
|
||||
|
||||
void releaseUserOutputs() {
|
||||
for (const auto& outputValue : graph_.userOutputs()) {
|
||||
if (std::holds_alternative<Value*>(outputValue)) {
|
||||
Value* valuePtr = std::get<Value*>(outputValue);
|
||||
if (valuePtr) {
|
||||
const auto& id = valuePtr->id();
|
||||
if (!persistent_[id]) {
|
||||
releaseValue(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void setWork(int64_t workId, const c10::intrusive_ptr<c10d::Work>& work) {
|
||||
work_[workId] = work;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10d::Work> getWork(int64_t workId) const {
|
||||
CHECK(work_.find(workId) != work_.end())
|
||||
<< "Couldn't find work with Id: " << workId;
|
||||
return work_.at(workId);
|
||||
}
|
||||
|
||||
WeightVersion weightVersion() const {
|
||||
return weightVersion_;
|
||||
}
|
||||
|
||||
void setWeights(const Weights& weights);
|
||||
|
||||
private:
|
||||
const Graph& graph_;
|
||||
WeightVersion weightVersion_ = -1;
|
||||
|
||||
// All the intermediate values for the entire graph, including graph inputs
// and outputs. This table is fixed once constructed.
|
||||
std::vector<c10::IValue> allValues_;
|
||||
std::vector<bool> persistent_;
|
||||
|
||||
std::unordered_map<int64_t, c10::intrusive_ptr<c10d::Work>> work_;
|
||||
|
||||
std::unordered_map<std::string, ValueId> foldedConstIds_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
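// [Illustrative usage sketch, not part of this file.] The ExecutionFrame lifecycle
// described above: weights become persistent values at construction time, user
// inputs are bound by value id, and user outputs are read back after execution.
// `inputId` and `input` are placeholders supplied by the caller.
namespace torch::nativert::example {
inline std::vector<c10::IValue> runWithFrameExample(
    const Graph& graph,
    const Weights& weights,
    ValueId inputId,
    at::Tensor input) {
  ExecutionFrame frame(graph, weights); // weights/custom objs become persistent values
  frame.setIValue(inputId, std::move(input)); // non-persistent, may be freed later
  // ... a GraphExecutorBase implementation would run the node kernels here ...
  return frame.getUserOutputs();
}
} // namespace torch::nativert::example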
|
||||
117
torch/csrc/nativert/executor/ExecutionPlanner.cpp
Normal file
@ -0,0 +1,117 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::unique_ptr<ExecutionPlan> ExecutionPlanner::createPlan() {
|
||||
auto plan = std::make_unique<ExecutionPlan>();
|
||||
|
||||
// The current implementation assumes that nodes will be executed
// in the same order as in the thrift graph.
// In the future, we can plan the execution order, as long as it
// complies with topological order.
|
||||
|
||||
generateDeallocationPlan(*plan);
|
||||
return plan;
|
||||
}
|
||||
|
||||
/* static */ c10::FastSet<ValueId> ExecutionPlanner::staticValues(
|
||||
const Graph& graph) {
|
||||
c10::FastSet<ValueId> staticValues;
|
||||
// Filter lastUsedBy by graph inputs.
// Parameter/buffer values should not be freed.
// It's a policy decision whether to free user inputs. For now, we don't
// free user inputs.
// TODO: It should be fine to "free" the user inputs. If the user holds a ref
// to it, it won't be deallocated.
|
||||
for (const auto* input : graph.inputs()) {
|
||||
if (input) {
|
||||
const auto& id = input->id();
|
||||
staticValues.insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
// Filter lastUsedBy by graph outputs, as they still need to be returned
|
||||
for (const auto& output : graph.outputs()) {
|
||||
const auto& id = output->id();
|
||||
staticValues.insert(id);
|
||||
}
|
||||
|
||||
for (const auto& [id, _] : graph.getConstantSymIntValues()) {
|
||||
staticValues.insert(id);
|
||||
}
|
||||
|
||||
for (const Node& node : graph.nodes()) {
|
||||
if (node.target() == "torch.ops.higher_order.run_const_graph") {
|
||||
for (const auto& output : node.outputs()) {
|
||||
// Do not free the outputs of run_const_graph, as they are newly
|
||||
// produced folded constants
|
||||
staticValues.insert(output->id());
|
||||
}
|
||||
} else {
|
||||
for (const auto& input : node.inputs()) {
|
||||
if (input.value->isFolded()) {
|
||||
staticValues.insert(input.value->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return staticValues;
|
||||
}
|
||||
|
||||
void ExecutionPlanner::generateDeallocationPlan(ExecutionPlan& plan) {
|
||||
const auto& nodes = graph_.nodes();
|
||||
size_t numNodes = nodes.size();
|
||||
|
||||
std::unordered_map<ValueId, NodeIndex> lastUsedBy;
|
||||
|
||||
// Traverse from the last node to the first node
|
||||
// For each Value, find out which is the last node that uses it
|
||||
// the Value can freed after executing the node
|
||||
size_t nodeIdx = nodes.size() - 1;
|
||||
for (auto it = std::rbegin(nodes); it != std::rend(nodes); it++) {
|
||||
const auto& inputs = it->inputs();
|
||||
for (const auto& input : inputs) {
|
||||
const auto& id = input.value->id();
|
||||
if (lastUsedBy.find(id) == lastUsedBy.end()) {
|
||||
lastUsedBy.insert({id, nodeIdx});
|
||||
}
|
||||
}
|
||||
nodeIdx--;
|
||||
}
|
||||
|
||||
std::vector<std::vector<ValueId>> valuesToFree(numNodes);
|
||||
|
||||
const auto& statics = staticValues(graph_);
|
||||
for (auto& [id, nodeIndex] : lastUsedBy) {
|
||||
if (statics.find(id) == statics.end()) {
|
||||
valuesToFree[nodeIndex].push_back(id);
|
||||
}
|
||||
}
|
||||
|
||||
plan.valuesToFree = std::move(valuesToFree);
|
||||
|
||||
// print the deallocation plan
|
||||
VLOG(2) << plan;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan) {
|
||||
out << "****** Deallocation Plan ******\n";
|
||||
for (auto&& [i, values] : enumerate(plan.valuesToFree)) {
|
||||
out << "Node #" << i << ", valuesToFree = [";
|
||||
for (const auto& value : values) {
|
||||
out << value << ", ";
|
||||
}
|
||||
out << "]\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
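// [Illustrative usage sketch, not part of this file.] How the deallocation plan
// generated above is meant to be consumed: after executing node i, every value id
// in plan->valuesToFree[i] may be released (e.g. via ExecutionFrame::releaseValue).
// This sketch only logs the plan.
namespace torch::nativert::example {
inline void inspectDeallocationPlanExample(const Graph& graph) {
  ExecutionPlanner planner(graph);
  std::unique_ptr<ExecutionPlan> plan = planner.createPlan();
  for (auto&& [nodeIdx, values] : enumerate(plan->valuesToFree)) {
    for (const ValueId id : values) {
      VLOG(2) << "value " << id << " can be freed after node #" << nodeIdx;
    }
  }
}
} // namespace torch::nativert::example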
|
||||
41
torch/csrc/nativert/executor/ExecutionPlanner.h
Normal file
@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Graph;
|
||||
|
||||
// ExecutionPlan is the result produced by ExecutionPlanner.
// ATM, it only contains the value deallocation plan.
// In the future, it can include an execution order plan, an allocation plan for
// parameter/gradient alignment, a static memory plan for activation buffer reuse,
// etc.
|
||||
struct ExecutionPlan {
|
||||
// The i-th entry in this list holds the Values that can be freed *after*
// executing the i-th node.
|
||||
std::vector<std::vector<ValueId>> valuesToFree;
|
||||
};
|
||||
|
||||
class ExecutionPlanner {
|
||||
public:
|
||||
explicit ExecutionPlanner(const Graph& graph) : graph_(graph) {}
|
||||
|
||||
std::unique_ptr<ExecutionPlan> createPlan();
|
||||
// get list of values we can't free
|
||||
static c10::FastSet<ValueId> staticValues(const Graph& graph);
|
||||
|
||||
private:
|
||||
void generateDeallocationPlan(ExecutionPlan& plan);
|
||||
|
||||
// NYI
|
||||
void generatedMemoryPlan(ExecutionPlan& plan);
|
||||
|
||||
const Graph& graph_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan);
|
||||
|
||||
} // namespace torch::nativert
|
||||
273
torch/csrc/nativert/executor/Executor.cpp
Normal file
@ -0,0 +1,273 @@
|
||||
#include "torch/csrc/nativert/executor/Executor.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/AutoTimer.h"
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
Executor::Executor(
|
||||
ExecutorConfig executorConfig,
|
||||
std::shared_ptr<Graph> graph,
|
||||
std::shared_ptr<Weights> weights,
|
||||
const Placement& placement,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader)
|
||||
: executorConfig_(std::move(executorConfig)),
|
||||
graph_(std::move(graph)),
|
||||
placement_(placement),
|
||||
constantFolder_(
|
||||
executorConfig_.runConstFolding
|
||||
? std::optional<ConstantFolder>(*graph_)
|
||||
: std::nullopt),
|
||||
executionFrames_(executorConfig_.maxNumConcurrentThreads),
|
||||
numExecutionFrames_(0) {
|
||||
if (weights) {
|
||||
initialize(std::move(weights), std::move(pytorchStreamReader));
|
||||
}
|
||||
}
|
||||
|
||||
void Executor::initialize(
|
||||
std::shared_ptr<Weights> weights,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) {
|
||||
AutoTimer t("Initialization completed");
|
||||
|
||||
auto executionKernels = KernelFactory().initializeNodeKernels(
|
||||
*graph_,
|
||||
weights,
|
||||
executorConfig_,
|
||||
placement_,
|
||||
std::move(pytorchStreamReader));
|
||||
|
||||
if (constantFolder_.has_value()) {
|
||||
constantFolder_->unlinkConstants(executionKernels.nodeKernels);
|
||||
}
|
||||
|
||||
if (executorConfig_.maxParallelOps > 1) {
|
||||
graphExecutor_ = std::make_unique<ParallelGraphExecutor>(
|
||||
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
|
||||
} else {
|
||||
graphExecutor_ = std::make_unique<SerialGraphExecutor>(
|
||||
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
|
||||
}
|
||||
|
||||
delegateExecutors_ = std::move(executionKernels.delegateExecutors);
|
||||
constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions);
|
||||
|
||||
// initialize weights_
|
||||
processWeights(weights);
|
||||
atomicSwapWeights(std::move(weights));
|
||||
}
|
||||
|
||||
void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
|
||||
weights_.withLock([&](auto& w) { w = std::move(weights); });
|
||||
|
||||
// update weights in delegate executors
|
||||
for (auto& delegateExecutor : delegateExecutors_) {
|
||||
delegateExecutor->commitWeights();
|
||||
}
|
||||
}
|
||||
|
||||
void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
|
||||
for (auto& execution : constFoldingExecutions_) {
|
||||
ExecutionFrame constFoldingFrame(execution.executor->graph());
|
||||
std::vector<c10::IValue> inputs;
|
||||
inputs.reserve(graph_->signature().inputsToWeights().size());
|
||||
for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
|
||||
inputs.push_back(weights->at(name));
|
||||
}
|
||||
|
||||
auto outputs = execution.executor->execute(constFoldingFrame, inputs);
|
||||
for (const auto& [idx, value] :
|
||||
enumerate(execution.executor->graph().outputs())) {
|
||||
weights->updateFoldedConst(value->name(), outputs.at(idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Executor::processWeights(std::shared_ptr<Weights> weights) {
|
||||
maybeRunConstantFolding(weights);
|
||||
if (constantFolder_.has_value()) {
|
||||
constantFolder_->evaluate(*weights);
|
||||
}
|
||||
for (auto& delegateExecutor : delegateExecutors_) {
|
||||
delegateExecutor->processWeights(weights);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void validateInput(
|
||||
const std::string& inputName,
|
||||
const at::Tensor& inputTensor,
|
||||
const TensorMeta& tensorValueMeta) {
|
||||
CHECK(inputTensor.dtype() == tensorValueMeta.dtype())
|
||||
<< "Input tensor dtype mismatch for " << inputName << ", expecting "
|
||||
<< c10::toString(tensorValueMeta.dtype()) << " but got "
|
||||
<< inputTensor.dtype().name();
|
||||
|
||||
CHECK(inputTensor.device() == tensorValueMeta.device())
|
||||
<< "Input tensor device mismatch for " << inputName << ", expecting "
|
||||
<< tensorValueMeta.device().str() << " but got "
|
||||
<< inputTensor.device().str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// validate input tensor's dtype matches tensorMeta
|
||||
void Executor::validateInputs(const std::vector<c10::IValue>& inputs) const {
|
||||
const auto& inputValues = graph_->userInputs();
|
||||
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
|
||||
TORCH_CHECK(inputs.size() == inputValues.size(), "Input size mismatch");
|
||||
for (auto&& [i, actualInput] : enumerate(inputs)) {
|
||||
if (actualInput.isTensor()) {
|
||||
const auto& inputName = std::string(inputValues[i]->name());
|
||||
auto it = tensorValuesMeta.find(inputName);
|
||||
CHECK(it != tensorValuesMeta.end())
|
||||
<< "Couldn't find " << inputName << " in tensorValuesMeta";
|
||||
validateInput(inputName, actualInput.toTensor(), it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<ExecutionFrame> Executor::getExecutorFrameFromPool() {
|
||||
std::shared_ptr<Weights> weights;
|
||||
weights_.withLock([&](auto& w) { weights = w; });
|
||||
|
||||
std::unique_ptr<ExecutionFrame> frame;
|
||||
while (!executionFrames_.readIfNotEmpty(frame)) {
|
||||
int numFrames = numExecutionFrames_.load();
|
||||
if (numFrames < executorConfig_.maxNumConcurrentThreads) {
|
||||
if (numExecutionFrames_.compare_exchange_strong(
|
||||
numFrames, numFrames + 1)) {
|
||||
return std::make_unique<ExecutionFrame>(*graph_, *weights);
|
||||
}
|
||||
} else {
|
||||
sem_.acquire();
|
||||
}
|
||||
}
|
||||
|
||||
if (frame->weightVersion() != weights->version()) {
|
||||
frame->setWeights(*weights);
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
|
||||
void Executor::returnExecutorFrameToPool(
|
||||
std::unique_ptr<ExecutionFrame> frame) {
|
||||
if (executorConfig_.enableStaticCPUKernels) {
|
||||
frame->releaseUserOutputs();
|
||||
}
|
||||
CHECK(executionFrames_.writeIfNotFull(std::move(frame)))
|
||||
<< "ExecutionFrame pool full";
|
||||
sem_.release();
|
||||
}
|
||||
|
||||
std::unique_ptr<ExecutionFrame> Executor::executeInternal(
|
||||
std::vector<c10::IValue> inputs) {
|
||||
if (executorConfig_.validateInputs) {
|
||||
validateInputs(inputs);
|
||||
}
|
||||
|
||||
auto executionFrame = getExecutorFrameFromPool();
|
||||
try {
|
||||
graphExecutor_->execute(*executionFrame, std::move(inputs));
|
||||
} catch (...) {
|
||||
returnExecutorFrameToPool(std::move(executionFrame));
|
||||
throw;
|
||||
}
|
||||
|
||||
return executionFrame;
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> Executor::execute(std::vector<c10::IValue> inputs) {
|
||||
auto executionFrame = executeInternal(std::move(inputs));
|
||||
auto outputs = executionFrame->getUserOutputs();
|
||||
|
||||
returnExecutorFrameToPool(std::move(executionFrame));
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> Executor::execute(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& inputTreeSpec) {
|
||||
auto executionFrame = getExecutorFrameFromPool();
|
||||
|
||||
std::optional<std::vector<c10::IValue>> outputs;
|
||||
try {
|
||||
const auto& userInputs = graph_->userInputs();
|
||||
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
|
||||
TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numLeaves());
|
||||
|
||||
size_t input_idx = 0;
|
||||
auto executionFrameFillUserInputs = [&](const c10::IValue& leaf) {
|
||||
auto value = userInputs[input_idx];
|
||||
// skip if the value is not used
|
||||
if (value && value->users().size() > 0) {
|
||||
// validate input tensor's dtype and device matches tensorMeta
|
||||
if (executorConfig_.validateInputs && leaf.isTensor()) {
|
||||
const auto& inputName = std::string(value->name());
|
||||
auto it = tensorValuesMeta.find(inputName);
|
||||
CHECK(it != tensorValuesMeta.end())
|
||||
<< "Couldn't find " << inputName << " in tensorValuesMeta";
|
||||
validateInput(inputName, leaf.toTensor(), it->second);
|
||||
}
|
||||
executionFrame->setIValue(value->id(), leaf);
|
||||
}
|
||||
input_idx++;
|
||||
};
|
||||
leafApplyFromArgs(
|
||||
executionFrameFillUserInputs, args, kwargs, inputTreeSpec);
|
||||
outputs = graphExecutor_->executeWithPrefilledFrame(*executionFrame);
|
||||
} catch (...) {
|
||||
returnExecutorFrameToPool(std::move(executionFrame));
|
||||
throw;
|
||||
}
|
||||
|
||||
returnExecutorFrameToPool(std::move(executionFrame));
|
||||
return *outputs;
|
||||
}
|
||||
|
||||
ProfileMetrics Executor::benchmarkIndividualNodes(
|
||||
std::vector<std::vector<c10::IValue>> inputsList,
|
||||
const uint32_t warmupRuns,
|
||||
const uint32_t mainRuns) {
|
||||
CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
|
||||
CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";
|
||||
|
||||
for (const auto& inputs : inputsList) {
|
||||
for (const auto& input : inputs) {
|
||||
CHECK(input.isTensor() || input.isCustomClass())
|
||||
<< "For now, all graph inputs should be tensor, or custom class object, but got "
|
||||
<< input.tagKind();
|
||||
}
|
||||
if (executorConfig_.validateInputs) {
|
||||
validateInputs(inputs);
|
||||
}
|
||||
}
|
||||
auto executionFrame = getExecutorFrameFromPool();
|
||||
auto benchmarkResults = graphExecutor_->benchmarkIndividualNodes(
|
||||
*executionFrame, inputsList, warmupRuns, mainRuns);
|
||||
|
||||
returnExecutorFrameToPool(std::move(executionFrame));
|
||||
return benchmarkResults;
|
||||
}
|
||||
|
||||
std::vector<DelegateExecutor*> Executor::getDelegates() {
|
||||
std::vector<DelegateExecutor*> delegates;
|
||||
for (const auto& delegateExecutor : delegateExecutors_) {
|
||||
delegates.push_back(delegateExecutor.get());
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
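
A minimal caller-side sketch of the flattened execute() path defined above (illustration only, not part of the change); how the Graph and Weights objects are produced is assumed to be handled elsewhere, e.g. by a ModelRunner:

#include "torch/csrc/nativert/executor/Executor.h"

using namespace torch::nativert;

std::vector<c10::IValue> runFlat(
    std::shared_ptr<Graph> graph,
    std::shared_ptr<Weights> weights,
    at::Tensor input) {
  ExecutorConfig config;
  config.validateInputs = true;        // check dtype/device against tensorValuesMeta
  config.maxNumConcurrentThreads = 4;  // bounds the ExecutionFrame pool

  Executor executor(config, std::move(graph), std::move(weights));

  // Takes flattened user inputs and returns flattened user outputs; an
  // ExecutionFrame is borrowed from the pool and returned when done.
  return executor.execute({c10::IValue(std::move(input))});
}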
|
||||
torch/csrc/nativert/executor/Executor.h (new file, 119 lines)
@@ -0,0 +1,119 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/MPMCQueue.h"
|
||||
#include "torch/csrc/nativert/common/Semaphore.h"
|
||||
#include "torch/csrc/nativert/executor/ConstantFolder.h"
|
||||
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
#include "torch/csrc/nativert/executor/Placement.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
#include "torch/csrc/nativert/graph/GraphSignature.h"
|
||||
#include "torch/csrc/nativert/kernels/KernelFactory.h"
|
||||
|
||||
#include "torch/csrc/nativert/common/Pytree.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Weights;
|
||||
struct DistributedRunConfig;
|
||||
|
||||
/**
|
||||
* A very dumb executor. Basically just runs each node in order and contains a
|
||||
* giant unordered map for every intermediate, no optimizations applied.
|
||||
*/
|
||||
class Executor {
|
||||
public:
|
||||
// Constructor used for the inference path
|
||||
Executor(
|
||||
ExecutorConfig executorConfig,
|
||||
std::shared_ptr<Graph> graph,
|
||||
std::shared_ptr<Weights> weights,
|
||||
const Placement& placement = Placement(),
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader = nullptr);
|
||||
|
||||
std::shared_ptr<Weights> getWeights() {
|
||||
std::shared_ptr<Weights> ret;
|
||||
weights_.withLock([&](auto& w) { ret = w; });
|
||||
return ret;
|
||||
}
|
||||
|
||||
void processWeights(std::shared_ptr<Weights> weights);
|
||||
void atomicSwapWeights(std::shared_ptr<Weights> weights);
|
||||
|
||||
// This API only returns the flattened UserOutputs,
|
||||
// intended to be used for the inference path
|
||||
std::vector<c10::IValue> execute(std::vector<c10::IValue> inputs);
|
||||
|
||||
std::vector<c10::IValue> execute(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const TreeSpec& inputTreeSpec);
|
||||
|
||||
ProfileMetrics benchmarkIndividualNodes(
|
||||
std::vector<std::vector<c10::IValue>> inputsList,
|
||||
const uint32_t warmupRuns,
|
||||
const uint32_t mainRuns);
|
||||
|
||||
const GraphSignature& graphSignature() const {
|
||||
return graph_->signature();
|
||||
}
|
||||
|
||||
static std::string className() {
|
||||
return "Executor.v0";
|
||||
}
|
||||
|
||||
const ExecutorConfig& executorConfig() const {
|
||||
return executorConfig_;
|
||||
}
|
||||
|
||||
std::vector<DelegateExecutor*> getDelegates();
|
||||
|
||||
protected:
|
||||
ExecutorConfig executorConfig_;
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
// manages the parameters, buffers and tensor constants
|
||||
c10::Synchronized<std::shared_ptr<Weights>> weights_;
|
||||
|
||||
std::unique_ptr<ExecutionFrame> executeInternal(
|
||||
std::vector<c10::IValue> inputs);
|
||||
|
||||
void initialize(
|
||||
std::shared_ptr<Weights> weights,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader);
|
||||
|
||||
std::unique_ptr<ExecutionFrame> getExecutorFrameFromPool();
|
||||
void returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame> frame);
|
||||
|
||||
private:
|
||||
void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
|
||||
void validateInputs(const std::vector<c10::IValue>& inputs) const;
|
||||
|
||||
std::unique_ptr<GraphExecutorBase> graphExecutor_;
|
||||
|
||||
const Placement placement_;
|
||||
|
||||
// NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
|
||||
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;
|
||||
|
||||
std::vector<ConstFoldingExecution> constFoldingExecutions_;
|
||||
|
||||
std::optional<ConstantFolder> constantFolder_;
|
||||
|
||||
Semaphore sem_;
|
||||
MPMCQueue<std::unique_ptr<ExecutionFrame>> executionFrames_;
|
||||
std::atomic_int numExecutionFrames_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
torch/csrc/nativert/executor/ExecutorConfig.h (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
struct ExecutorConfig {
|
||||
bool validateInputs = false;
|
||||
|
||||
bool debugNan = false;
|
||||
|
||||
// allows up to this many concurrent threads.
|
||||
int64_t maxNumConcurrentThreads = 8;
|
||||
|
||||
// allows up to this many parallel ops.
|
||||
int64_t maxParallelOps = 1;
|
||||
|
||||
bool enableStaticCPUKernels = false;
|
||||
|
||||
bool enableStaticMemoryPlanning = false;
|
||||
|
||||
std::string modelName = "unknown";
|
||||
|
||||
bool runConstFolding = false;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
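
As a rough illustration of how these knobs compose (the values below are hypothetical, not defaults from this change):

#include "torch/csrc/nativert/executor/ExecutorConfig.h"

torch::nativert::ExecutorConfig makeServingConfig() {
  torch::nativert::ExecutorConfig config;
  config.validateInputs = true;          // validate input dtype/device on each call
  config.enableStaticCPUKernels = true;  // prefer static dispatch kernels on CPU
  config.maxNumConcurrentThreads = 16;   // upper bound on pooled ExecutionFrames
  config.maxParallelOps = 4;             // worker threads for parallel op execution
  config.modelName = "example_model";    // hypothetical model name
  return config;
}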
|
||||
torch/csrc/nativert/executor/GraphExecutorBase.cpp (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
#include "torch/csrc/nativert/common/RecordFunction.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <caffe2/core/timer.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
GraphExecutorBase::GraphExecutorBase(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig)
|
||||
: graph_(graph),
|
||||
nodeKernels_(std::move(nodeKernels)),
|
||||
executorConfig_(executorConfig),
|
||||
execPlan_(ExecutionPlanner{graph_}.createPlan()) {}
|
||||
|
||||
void GraphExecutorBase::fillUserInputs(
|
||||
ExecutionFrame& frame,
|
||||
std::vector<c10::IValue> inputs) {
|
||||
RecordFunction recordFunction("Executor::fillUserInputs");
|
||||
const auto& inputValues = graph_.userInputs();
|
||||
TORCH_CHECK_EQ(inputValues.size(), inputs.size());
|
||||
|
||||
// load user input tensor into execution frame
|
||||
for (size_t i = 0; i < inputValues.size(); i++) {
|
||||
if (inputValues[i]) {
|
||||
frame.setIValue(inputValues[i]->id(), std::move(inputs[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
|
||||
ExecutionFrame& executionFrame,
|
||||
std::vector<std::vector<c10::IValue>> inputsList,
|
||||
const uint32_t warmupRuns,
|
||||
const uint32_t mainRuns) {
|
||||
// TODO: add support for memory profiling
|
||||
TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1);
|
||||
|
||||
ProfileMetrics results;
|
||||
const auto numNodes = static_cast<uint32_t>(nodeKernels_.size());
|
||||
results.timePerNode.resize(numNodes, 0);
|
||||
if (inputsList.empty()) {
|
||||
auto i = 0;
|
||||
for (const auto& nodeKernel : nodeKernels_) {
|
||||
std::string target(nodeKernel->node()->target());
|
||||
results.timePerNode[i] = 0;
|
||||
results.timePerNodeType[target] = 0;
|
||||
results.instancesPerNodeType[target]++;
|
||||
if (nodeKernel->hasPrimKernel()) {
|
||||
results.primNodesCount++;
|
||||
results.primNodes.insert(target);
|
||||
} else if (nodeKernel->hasStaticDispatch()) {
|
||||
results.staticDispatchNodesCount++;
|
||||
results.staticDispatchNodes.insert(target);
|
||||
}
|
||||
i++;
|
||||
}
|
||||
results.totalNodesCount = numNodes;
|
||||
for (const auto& p : results.timePerNodeType) {
|
||||
const std::string& kind = p.first;
|
||||
results.percentPerNodeType[kind] = 0;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// Warmup
|
||||
for (uint32_t i = 0; i < warmupRuns; i++) {
|
||||
for (const auto& inputs : inputsList) {
|
||||
execute(executionFrame, inputs);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute kernels
|
||||
caffe2::Timer timer;
|
||||
for (uint32_t i = 0; i < mainRuns; i++) {
|
||||
for (auto inputs : inputsList) {
|
||||
const auto& inputValues = graph_.userInputs();
|
||||
|
||||
TORCH_CHECK_EQ(inputValues.size(), inputs.size());
|
||||
for (size_t j = 0; j < inputValues.size(); j++) {
|
||||
executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j]));
|
||||
}
|
||||
for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) {
|
||||
timer.Start();
|
||||
nodeKernels_[nodeIdx]->compute(executionFrame);
|
||||
float millis = timer.MilliSeconds();
|
||||
results.timePerNode[nodeIdx] += millis;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Summarize results
|
||||
const float numTotalIters =
|
||||
(static_cast<float>(mainRuns) * static_cast<float>(inputsList.size()));
|
||||
for (const auto i : c10::irange(numNodes)) {
|
||||
const Node* node = nodeKernels_[i]->node();
|
||||
std::string target(node->target());
|
||||
results.timePerNode[i] /= numTotalIters;
|
||||
results.timePerNodeType[target] += results.timePerNode[i];
|
||||
results.instancesPerNodeType[target]++;
|
||||
if (nodeKernels_[i]->hasPrimKernel()) {
|
||||
results.primNodes.insert(target);
|
||||
results.primNodesCount++;
|
||||
} else if (nodeKernels_[i]->hasStaticDispatch()) {
|
||||
results.staticDispatchNodes.insert(target);
|
||||
results.staticDispatchNodesCount++;
|
||||
}
|
||||
results.totalTime += results.timePerNode[i];
|
||||
}
|
||||
results.totalNodesCount = numNodes;
|
||||
for (const auto& r : results.timePerNodeType) {
|
||||
const std::string& target = r.first;
|
||||
results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
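
To make the aggregation above concrete: with mainRuns = 2 and inputsList.size() = 3, numTotalIters = 6, so a node that accumulated 12 ms of wall time across the main runs reports timePerNode = 2 ms/iter; percentPerNodeType then expresses each node type's share of totalTime, e.g. 2 ms out of 8 ms total is 25%.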
|
||||
torch/csrc/nativert/executor/GraphExecutorBase.h (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
#include "torch/csrc/nativert/graph/GraphSignature.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
struct ProfileMetrics {
|
||||
size_t primNodesCount{0};
|
||||
size_t staticDispatchNodesCount{0};
|
||||
size_t totalNodesCount{0};
|
||||
std::vector<float> timePerNode;
|
||||
std::unordered_map<std::string, float> timePerNodeType;
|
||||
std::unordered_map<std::string, float> percentPerNodeType;
|
||||
std::unordered_map<std::string, int> instancesPerNodeType;
|
||||
std::unordered_set<std::string> staticDispatchNodes;
|
||||
std::unordered_set<std::string> primNodes;
|
||||
float totalTime{0};
|
||||
};
|
||||
|
||||
/**
|
||||
* GraphExecutor is a lightweight abstraction to execute a graph with
|
||||
* execution frames without actually owning the graph or the weights. This is
|
||||
* introduced to decouple the state management of the top level runtime from the
|
||||
* kernel executions so that sub graphs from higher order ops can be supported.
|
||||
*/
|
||||
class GraphExecutorBase {
|
||||
public:
|
||||
GraphExecutorBase(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig);
|
||||
virtual ~GraphExecutorBase() = default;
|
||||
|
||||
const Graph& graph() const {
|
||||
return graph_;
|
||||
}
|
||||
|
||||
// This API only returns the flattened UserOutputs,
|
||||
// intended to be used for the inference path
|
||||
virtual std::vector<c10::IValue> execute(
|
||||
ExecutionFrame& frame,
|
||||
std::vector<c10::IValue> inputs) = 0;
|
||||
|
||||
virtual std::vector<c10::IValue> executeWithPrefilledFrame(
|
||||
ExecutionFrame& frame) = 0;
|
||||
|
||||
ProfileMetrics benchmarkIndividualNodes(
|
||||
ExecutionFrame& executionFrame,
|
||||
std::vector<std::vector<c10::IValue>> inputs,
|
||||
const uint32_t warmup_runs,
|
||||
const uint32_t main_runs);
|
||||
|
||||
std::vector<std::unique_ptr<OpKernel>> stealKernels() {
|
||||
return std::move(nodeKernels_);
|
||||
}
|
||||
|
||||
void setKernels(std::vector<std::unique_ptr<OpKernel>>&& kernels) {
|
||||
nodeKernels_ = std::move(kernels);
|
||||
}
|
||||
|
||||
protected:
|
||||
void fillUserInputs(ExecutionFrame& frame, std::vector<c10::IValue> inputs);
|
||||
|
||||
const Graph& graph_;
|
||||
|
||||
// cache of the constructed kernels to avoid reconstruction per execution
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels_;
|
||||
|
||||
const ExecutorConfig& executorConfig_;
|
||||
|
||||
std::unique_ptr<ExecutionPlan> execPlan_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
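
For orientation, a hypothetical minimal subclass (illustration only; the actual SerialGraphExecutor added by this change may differ) would run every kernel in node order against the frame and then read back the flattened user outputs:

#include "torch/csrc/nativert/executor/GraphExecutorBase.h"

namespace torch::nativert {

class NaiveSerialExecutor : public GraphExecutorBase {
 public:
  using GraphExecutorBase::GraphExecutorBase;

  std::vector<c10::IValue> execute(
      ExecutionFrame& frame,
      std::vector<c10::IValue> inputs) override {
    // Load the flattened user inputs into the frame, then run the graph.
    fillUserInputs(frame, std::move(inputs));
    return executeWithPrefilledFrame(frame);
  }

  std::vector<c10::IValue> executeWithPrefilledFrame(
      ExecutionFrame& frame) override {
    // Run every cached kernel in node order against the same frame.
    for (const auto& kernel : nodeKernels_) {
      kernel->compute(frame);
    }
    return frame.getUserOutputs();
  }
};

} // namespace torch::nativert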
|
||||
torch/csrc/nativert/executor/ModelRunnerBase.cpp (new file, 389 lines)
@@ -0,0 +1,389 @@
|
||||
|
||||
|
||||
#include "torch/csrc/nativert/executor/ModelRunnerBase.h"
|
||||
|
||||
#include "torch/csrc/nativert/common/RecordFunction.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
#include "torch/csrc/nativert/graph/TensorMeta.h"
|
||||
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
ModelRunnerBase::ModelRunnerBase(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
const std::function<Placement(const torch::nativert::Graph& graph)>&
|
||||
buildPlacementFn)
|
||||
: modelName_(modelName),
|
||||
executorType_(executorType),
|
||||
runtimeConfigs_(runtimeConfigs) {}
|
||||
|
||||
void ModelRunnerBase::setExecutorType(
|
||||
ExecutorType type,
|
||||
const std::string& platformArch) {
|
||||
LOG(INFO) << fmt::format(
|
||||
"Setting executor type to {} with platformArch='{}'", type, platformArch);
|
||||
executorType_ = type;
|
||||
if (type == ExecutorType::AOTINDUCTOR) {
|
||||
runtimeConfigs_.platformArch = platformArch;
|
||||
} else if (type == ExecutorType::MTIA) {
|
||||
// TODO: hardcoded for now (Sigmoid packages specify platformArch as
|
||||
// "mtia")
|
||||
runtimeConfigs_.platformArch = "mtia";
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& ModelRunnerBase::getModelName() const {
|
||||
return modelName_;
|
||||
}
|
||||
|
||||
std::shared_ptr<Weights> ModelRunnerBase::getWeights() {
|
||||
if (executor_ != nullptr) {
|
||||
return executor_->getWeights();
|
||||
} else if (newWeights_ != nullptr) {
|
||||
return newWeights_;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "ModelRunner is not initialized, and no weights are loaded.");
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Weights> ModelRunnerBase::loadNewWeights(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageStreamReader,
|
||||
std::function<bool(const std::string&)> skipSizeCheck,
|
||||
std::function<bool(const std::string&)> skipDtypeCheck) {
|
||||
LOG(INFO) << "ModelRunner loading new weights";
|
||||
newWeights_ = std::make_shared<Weights>(
|
||||
graph_.get(),
|
||||
packageStreamReader,
|
||||
stateDictPath_,
|
||||
archive_spec::WEIGHTS_DIR,
|
||||
constantPaths_,
|
||||
archive_spec::CONSTANTS_DIR,
|
||||
placement_,
|
||||
std::move(skipSizeCheck),
|
||||
std::move(skipDtypeCheck));
|
||||
|
||||
return newWeights_;
|
||||
}
|
||||
|
||||
void ModelRunnerBase::commitNewWeights() {
|
||||
CHECK(newWeights_) << "No new weights loaded";
|
||||
CHECK(executor_) << "ModelRunner not initialized";
|
||||
|
||||
executor_->processWeights(newWeights_);
|
||||
|
||||
executor_->atomicSwapWeights(std::move(newWeights_));
|
||||
|
||||
newWeights_ = nullptr;
|
||||
}
|
||||
|
||||
bool ModelRunnerBase::loadExtraFiles(
|
||||
ExtraFilesMap& extraFiles,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) {
|
||||
auto filesExist = false;
|
||||
for (const auto& kv : extraFiles) {
|
||||
const auto key = std::string{archive_spec::EXTRA_DIR} + kv.first;
|
||||
if (pytorchStreamReader->hasRecord(key)) {
|
||||
auto [metaPtr, metaSize] = pytorchStreamReader->getRecord(key);
|
||||
extraFiles[kv.first] =
|
||||
std::string(static_cast<char*>(metaPtr.get()), metaSize);
|
||||
filesExist = true;
|
||||
}
|
||||
}
|
||||
return filesExist;
|
||||
}
|
||||
|
||||
c10::IValue ModelRunnerBase::run(
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const RunConfigs& runConfigs) {
|
||||
return run({}, kwargs, runConfigs);
|
||||
}
|
||||
|
||||
c10::IValue ModelRunnerBase::run(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const RunConfigs& runConfigs) {
|
||||
RecordFunction recordFunction("nativert::ModelRunner::run");
|
||||
|
||||
CHECK(executor_) << "ModelRunner not initialized";
|
||||
|
||||
// ModelRunner is only used for inference
|
||||
c10::InferenceMode mode;
|
||||
|
||||
return treeUnflatten(
|
||||
executor_->execute(args, kwargs, inputSpec_), outputSpec_);
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> ModelRunnerBase::runWithFlatInputsAndOutputs(
|
||||
std::vector<c10::IValue>&& flatInputs,
|
||||
const RunConfigs& /* runConfigs */) {
|
||||
RecordFunction recordFunction(
|
||||
"nativert::ModelRunner::runWithFlatInputsAndOutputs");
|
||||
|
||||
CHECK(executor_) << "ModelRunner not initialized";
|
||||
|
||||
// ModelRunner is only used for inference
|
||||
c10::InferenceMode mode;
|
||||
|
||||
return executor_->execute(std::move(flatInputs));
|
||||
}
|
||||
|
||||
void ModelRunnerBase::benchmarkIndividualNodes(
|
||||
const std::vector<std::vector<c10::IValue>>& argsList,
|
||||
const std::vector<std::unordered_map<std::string, c10::IValue>>& kwargsList,
|
||||
const uint32_t warmupRuns,
|
||||
const uint32_t mainRuns,
|
||||
const bool printPerNodeTime,
|
||||
const RunConfigs& runConfigs) {
|
||||
std::vector<std::vector<c10::IValue>> flatInputsList;
|
||||
for (const auto& args : argsList) {
|
||||
if (!kwargsList.empty()) {
|
||||
for (const auto& kwargs : kwargsList) {
|
||||
flatInputsList.emplace_back(
|
||||
treeFlattenFromArgs(args, kwargs, inputSpec_));
|
||||
}
|
||||
} else {
|
||||
flatInputsList.emplace_back(treeFlattenFromArgs(args, {}, inputSpec_));
|
||||
}
|
||||
}
|
||||
c10::InferenceMode mode;
|
||||
auto results =
|
||||
executor_->benchmarkIndividualNodes(flatInputsList, warmupRuns, mainRuns);
|
||||
|
||||
if (printPerNodeTime) {
|
||||
size_t i = 0;
|
||||
for (const auto& node : graph_->nodes()) {
|
||||
LOG(INFO) << "Node #" << i << ": " << node.toString()
|
||||
<< "\n Time: " << results.timePerNode[i] << " ms/iter, ";
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, double>> sortedTimePerOp{
|
||||
results.timePerNodeType.begin(), results.timePerNodeType.end()};
|
||||
if (argsList.empty()) {
|
||||
// sort by number of instances per node type
|
||||
std::sort(
|
||||
sortedTimePerOp.begin(),
|
||||
sortedTimePerOp.end(),
|
||||
[&results](auto& left, auto& right) {
|
||||
return results.instancesPerNodeType[left.first] >
|
||||
results.instancesPerNodeType[right.first];
|
||||
});
|
||||
} else {
|
||||
// sort by time
|
||||
std::sort(
|
||||
sortedTimePerOp.begin(),
|
||||
sortedTimePerOp.end(),
|
||||
[](auto& left, auto& right) { return left.second > right.second; });
|
||||
}
|
||||
|
||||
LOG(INFO) << "Time per node type:" << '\n';
|
||||
std::ostringstream unsupportedNodeKinds;
|
||||
for (const auto& p : sortedTimePerOp) {
|
||||
const std::string& kind = p.first;
|
||||
const double ms = p.second;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << std::setw(15) << ms << " ms. " << std::setw(10)
|
||||
<< results.percentPerNodeType[kind] << "%. " << kind << " ("
|
||||
<< results.instancesPerNodeType[kind] << " nodes";
|
||||
if (results.primNodes.count(kind)) {
|
||||
oss << ", prim) \n";
|
||||
} else if (results.staticDispatchNodes.count(kind)) {
|
||||
oss << ", static dispatch) \n";
|
||||
} else {
|
||||
unsupportedNodeKinds << kind << ", ";
|
||||
oss << ")\n";
|
||||
}
|
||||
LOG(INFO) << oss.str();
|
||||
}
|
||||
LOG(INFO) << std::setw(15) << results.totalTime << " ms. in Total" << '\n';
|
||||
LOG(INFO) << "Number of nodes: " << graph_->nodes().size() << '\n';
|
||||
|
||||
auto unsupportedCount = results.totalNodesCount -
|
||||
results.staticDispatchNodesCount - results.primNodesCount;
|
||||
LOG(INFO) << "Total number of static dispatch nodes/total number of nodes: "
|
||||
<< results.staticDispatchNodesCount << "/"
|
||||
<< results.totalNodesCount << " ("
|
||||
<< 100.0 * static_cast<float>(results.staticDispatchNodesCount) /
|
||||
static_cast<float>(results.totalNodesCount)
|
||||
<< "%)" << '\n';
|
||||
LOG(INFO) << "Total number of prim nodes/total number of nodes: "
|
||||
<< results.primNodesCount << "/" << results.totalNodesCount << " ("
|
||||
<< 100.0 * static_cast<float>(results.primNodesCount) /
|
||||
static_cast<float>(results.totalNodesCount)
|
||||
<< "%)" << '\n';
|
||||
LOG(INFO)
|
||||
<< "Total number of nodes not covered by static dispatch/total number of nodes: "
|
||||
<< unsupportedCount << "/" << results.totalNodesCount << " ("
|
||||
<< 100.0 * static_cast<float>(unsupportedCount) /
|
||||
static_cast<float>(results.totalNodesCount)
|
||||
<< "%)" << "\n Uncovered node kinds: " << unsupportedNodeKinds.str()
|
||||
<< '\n';
|
||||
}
|
||||
|
||||
std::vector<std::optional<c10::Device>>
|
||||
ModelRunnerBase::getUserInputTargetDevices() const {
|
||||
std::vector<std::optional<c10::Device>> devices;
|
||||
for (const auto& tensorMeta : graph_->userInputsMeta()) {
|
||||
c10::Device targetDevice = placement_.getMappedDevice(tensorMeta.device());
|
||||
devices.push_back(targetDevice);
|
||||
}
|
||||
return devices;
|
||||
}
|
||||
|
||||
std::pair<
|
||||
std::vector<c10::IValue>,
|
||||
std::unordered_map<std::string, c10::IValue>>
|
||||
ModelRunnerBase::loadSampleInputs(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
|
||||
const Placement& placement) {
|
||||
LOG(INFO) << "Loading sample inputs in ModelRunner for model " << modelName_;
|
||||
|
||||
std::string sampleInputsPath =
|
||||
fmt::format(archive_spec::SAMPLE_INPUTS_FILENAME_FORMAT, modelName_);
|
||||
|
||||
CHECK(pytorchStreamReader->hasRecord(sampleInputsPath))
|
||||
<< sampleInputsPath << " is not found in package";
|
||||
size_t size = pytorchStreamReader->getRecordSize(sampleInputsPath);
|
||||
std::vector<char> buffers(size);
|
||||
size_t sizeRead =
|
||||
pytorchStreamReader->getRecord(sampleInputsPath, buffers.data(), size);
|
||||
CHECK(sizeRead == size);
|
||||
|
||||
c10::IValue value = torch::jit::pickle_load(buffers);
|
||||
|
||||
// Move userInputs on the target device
|
||||
std::vector<TensorMeta> userInputsMeta = graph_->userInputsMeta();
|
||||
size_t tensorInputId = 0;
|
||||
value = treeMap(
|
||||
[&](const c10::IValue& inputVal) -> c10::IValue {
|
||||
if (inputVal.isTensor()) {
|
||||
auto& tensorMeta = userInputsMeta[tensorInputId];
|
||||
|
||||
c10::Device targetDevice =
|
||||
placement.getMappedDevice(tensorMeta.device());
|
||||
auto r = inputVal.toTensor().to(targetDevice);
|
||||
|
||||
VLOG(1) << "input #" << tensorInputId << " has been placed on "
|
||||
<< targetDevice;
|
||||
tensorInputId++;
|
||||
return r;
|
||||
} else {
|
||||
return inputVal;
|
||||
}
|
||||
},
|
||||
value,
|
||||
inputSpec_);
|
||||
|
||||
CHECK(value.isTuple());
|
||||
CHECK(value.toTupleRef().elements().size() == 2);
|
||||
const auto& argsVal = value.toTupleRef().elements().at(0);
|
||||
const auto& kwargsVal = value.toTupleRef().elements().at(1);
|
||||
CHECK(argsVal.isTuple());
|
||||
CHECK(kwargsVal.isGenericDict());
|
||||
|
||||
std::vector<c10::IValue> args;
|
||||
for (const auto& arg : argsVal.toTupleRef().elements()) {
|
||||
args.push_back(arg);
|
||||
}
|
||||
std::unordered_map<std::string, c10::IValue> kwargs;
|
||||
for (const auto& entry : kwargsVal.toGenericDict()) {
|
||||
kwargs[entry.key().toStringRef()] = entry.value();
|
||||
}
|
||||
return {std::move(args), std::move(kwargs)};
|
||||
}
|
||||
|
||||
const std::vector<std::string>& ModelRunnerBase::getArgumentNames() const {
|
||||
return graph_->signature().userInputsVec();
|
||||
}
|
||||
|
||||
c10::ArrayRef<const Value*> ModelRunnerBase::getArguments() const {
|
||||
return graph_->userInputs();
|
||||
}
|
||||
|
||||
const TreeSpec& ModelRunnerBase::getOutputSpec() const {
|
||||
return outputSpec_;
|
||||
}
|
||||
|
||||
const TreeSpec& ModelRunnerBase::getInputSpec() const {
|
||||
return inputSpec_;
|
||||
}
|
||||
|
||||
std::string ModelRunnerBase::loadSerializedModel(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader)
|
||||
const {
|
||||
std::string modelFilePath =
|
||||
fmt::format(archive_spec::MODELS_FILENAME_FORMAT, modelName_);
|
||||
LOG(INFO) << "Loading model from: " << modelFilePath;
|
||||
|
||||
CHECK(pytorchStreamReader->hasRecord(modelFilePath))
|
||||
<< modelFilePath << " not found in package";
|
||||
const auto& [modelData, modelSize] =
|
||||
pytorchStreamReader->getRecord(modelFilePath);
|
||||
const std::string modelSerialized{
|
||||
reinterpret_cast<char*>(modelData.get()), modelSize};
|
||||
return modelSerialized;
|
||||
}
|
||||
|
||||
void ModelRunnerBase::initialize(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) {
|
||||
if (executor_ != nullptr) {
|
||||
LOG(WARNING)
|
||||
<< "ModelRunner already initialized, re-initialization is an no op.";
|
||||
return;
|
||||
}
|
||||
|
||||
initializeExecutor(newWeights_, std::move(pytorchStreamReader));
|
||||
newWeights_ = nullptr;
|
||||
}
|
||||
|
||||
void ModelRunnerBase::initializeExecutor(
|
||||
std::shared_ptr<Weights> weights,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) {
|
||||
CHECK(executor_ == nullptr) << "ModelRunner already initialized";
|
||||
weights->validateAllWeightsLoaded();
|
||||
|
||||
ExecutorConfig config;
|
||||
config.maxParallelOps = runtimeConfigs_.maxParallelOps;
|
||||
config.validateInputs = runtimeConfigs_.validateInputs;
|
||||
config.enableStaticCPUKernels = runtimeConfigs_.enableStaticCPUKernels;
|
||||
config.maxNumConcurrentThreads = runtimeConfigs_.maxNumConcurrentThreads;
|
||||
config.enableStaticMemoryPlanning =
|
||||
runtimeConfigs_.enableStaticMemoryPlanning;
|
||||
config.runConstFolding = runtimeConfigs_.enableRuntimeConstFolding;
|
||||
|
||||
if (executorType_ == ExecutorType::INTERPRETER) {
|
||||
executor_ = std::make_unique<Executor>(
|
||||
config, graph_, weights, placement_, pytorchStreamReader);
|
||||
} else if (executorType_ == ExecutorType::AOTINDUCTOR) {
|
||||
delegateGraph_ = deserializeDelegateGraph();
|
||||
delegateGraph_->applyDevicePlacement(placement_);
|
||||
VLOG(1) << "Delegate graph: \n" << *delegateGraph_;
|
||||
|
||||
executor_ = std::make_unique<Executor>(
|
||||
config, delegateGraph_, weights, placement_, pytorchStreamReader);
|
||||
} else if (executorType_ == ExecutorType::MTIA) {
|
||||
delegateGraph_ = deserializeDelegateGraph();
|
||||
delegateGraph_->applyDevicePlacement(placement_);
|
||||
VLOG(1) << "Delegate graph: \n" << *delegateGraph_;
|
||||
config.modelName = modelName_;
|
||||
executor_ = std::make_unique<Executor>(
|
||||
config, delegateGraph_, weights, placement_, pytorchStreamReader);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
torch/csrc/nativert/executor/ModelRunnerBase.h (new file, 260 lines)
@@ -0,0 +1,260 @@
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/Pytree.h"
|
||||
#include "torch/csrc/nativert/executor/Executor.h"
|
||||
#include "torch/csrc/nativert/executor/Placement.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
|
||||
|
||||
enum class ExecutorType {
|
||||
INTERPRETER = 0,
|
||||
AOTINDUCTOR = 1,
|
||||
MTIA = 2,
|
||||
};
|
||||
|
||||
struct BaseRuntimeConfigs {
|
||||
bool isDebug = false;
|
||||
|
||||
bool validateInputs = false;
|
||||
|
||||
// use static kernels
|
||||
bool enableStaticCPUKernels = false;
|
||||
|
||||
// whether to enable static memory planning
|
||||
bool enableStaticMemoryPlanning = false;
|
||||
|
||||
// whether to load each node's metadata, e.g. stack traces.
// This is only used for debugging purposes. For production, this is commonly
// set to false, as it would incur extra memory usage.
|
||||
bool loadNodeMetadata = false;
|
||||
|
||||
// whether to initialize the executor in the constructor
|
||||
// In some cases, Weights are not available when the constructor is called.
|
||||
// In this case, the executor should be initialized later.
|
||||
bool deferInitialization = false;
|
||||
|
||||
// platform arch for delegate, e.g "cpu", "sm80_x86" etc
|
||||
// See https://fburl.com/code/3pym9ipj for supported platforms
|
||||
std::string platformArch;
|
||||
|
||||
// allows up to this many concurrent threads.
|
||||
int64_t maxNumConcurrentThreads = 8;
|
||||
|
||||
// allows up to this many parallel ops.
|
||||
int64_t maxParallelOps = 1;
|
||||
|
||||
// whether to enable runtime const folding
|
||||
bool enableRuntimeConstFolding = false;
|
||||
};
|
||||
|
||||
struct RunConfigs {};
|
||||
|
||||
class TORCH_API ModelRunnerBase {
|
||||
public:
|
||||
ModelRunnerBase(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const std::string& modelName,
|
||||
ExecutorType executorType,
|
||||
const BaseRuntimeConfigs& runtimeConfigs,
|
||||
// functor to build the placement after the graph is loaded, but before
|
||||
// loading the weights.
|
||||
const std::function<Placement(const torch::nativert::Graph& graph)>&
|
||||
buildPlacementFn);
|
||||
|
||||
ModelRunnerBase(ModelRunnerBase&&) = default;
|
||||
ModelRunnerBase& operator=(ModelRunnerBase&&) = default;
|
||||
|
||||
ModelRunnerBase(const ModelRunnerBase&) = delete;
|
||||
ModelRunnerBase& operator=(const ModelRunnerBase&) = delete;
|
||||
|
||||
virtual ~ModelRunnerBase() = default;
|
||||
|
||||
const std::string& getModelName() const;
|
||||
|
||||
std::shared_ptr<Weights> getWeights();
|
||||
|
||||
// loadNewWeights() loads the weights from the specified model into
|
||||
// the newWeights_ buffer. The new weights remain shadowed (inactive) and
// must be made active by the commitNewWeights() function.
|
||||
std::shared_ptr<Weights> loadNewWeights(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
packageStreamReader,
|
||||
std::function<bool(const std::string&)> skipSizeCheck = {},
|
||||
std::function<bool(const std::string&)> skipDtypeCheck = {});
|
||||
|
||||
void commitNewWeights();
|
||||
|
||||
c10::IValue run(
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const RunConfigs& runConfigs = RunConfigs());
|
||||
|
||||
c10::IValue run(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
||||
const RunConfigs& runConfigs = RunConfigs());
|
||||
|
||||
/**
|
||||
* A low level API which expects user to always pass in flattened inputs.
|
||||
* The ownership of the entire input list must be transferred to the
|
||||
* executor via std::move or in-place construction.
|
||||
*/
|
||||
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
|
||||
std::vector<c10::IValue>&& flatInputs,
|
||||
const RunConfigs& runConfigs = RunConfigs());
|
||||
|
||||
void benchmarkIndividualNodes(
|
||||
const std::vector<std::vector<c10::IValue>>& argsList,
|
||||
const std::vector<std::unordered_map<std::string, c10::IValue>>&
|
||||
kwargsList,
|
||||
const uint32_t warmupRuns,
|
||||
const uint32_t mainRuns,
|
||||
const bool printPerNodeTime,
|
||||
const RunConfigs& runConfigs);
|
||||
|
||||
std::vector<std::optional<c10::Device>> getUserInputTargetDevices() const;
|
||||
|
||||
std::pair<
|
||||
std::vector<c10::IValue>,
|
||||
std::unordered_map<std::string, c10::IValue>>
|
||||
loadSampleInputs(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const Placement& placement = Placement());
|
||||
|
||||
const std::vector<std::string>& getArgumentNames() const;
|
||||
c10::ArrayRef<const Value*> getArguments() const;
|
||||
|
||||
/*
|
||||
* Load extra files indicated in extraFiles from the model package.
|
||||
* Return true if any extra files were loaded
|
||||
* and false otherwise
|
||||
*/
|
||||
bool loadExtraFiles(
|
||||
ExtraFilesMap& extraFiles,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader);
|
||||
|
||||
const TreeSpec& getOutputSpec() const;
|
||||
const TreeSpec& getInputSpec() const;
|
||||
|
||||
void setExecutorType(ExecutorType type, const std::string& platformArch = "");
|
||||
|
||||
ExecutorType getExecutorType() const {
|
||||
return executorType_;
|
||||
}
|
||||
|
||||
void setEnableStaticDispatchKernels(bool enabled) {
|
||||
runtimeConfigs_.enableStaticCPUKernels = enabled;
|
||||
}
|
||||
|
||||
void setEnableStaticMemoryPlanning(bool enabled) {
|
||||
runtimeConfigs_.enableStaticMemoryPlanning = enabled;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T*> getDelegates() {
|
||||
std::vector<T*> delegates;
|
||||
for (const auto& delegate : executor_->getDelegates()) {
|
||||
if (auto* d = dynamic_cast<T*>(delegate)) {
|
||||
delegates.push_back(d);
|
||||
}
|
||||
}
|
||||
return delegates;
|
||||
}
|
||||
|
||||
// Manually initialize the executor when config.deferInitialization is true.
//
// initialize() must be called after
// - weights are fully loaded
// - an executor is selected via setExecutorType()
// ModelRunner is not ready to serve with run() until it is initialized.
//
// Note that pytorchStreamReader is required to load lowered modules.
//
// If initialization fails at any point, an exception is thrown. The caller
// should catch the exception and call initialize() again with valid weights
// and executor type.
//
// ModelRunner can only be initialized once; calling this function when the
// ModelRunner is already ready to serve is a no-op.
|
||||
void initialize(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader = nullptr);
|
||||
|
||||
protected:
|
||||
#ifdef ModelRunnerTest_TEST_FRIENDS
|
||||
ModelRunnerTest_TEST_FRIENDS;
|
||||
#endif
|
||||
|
||||
virtual std::unique_ptr<Graph> deserializeDelegateGraph() const = 0;
|
||||
void initializeExecutor(
|
||||
std::shared_ptr<Weights> weights,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader);
|
||||
|
||||
std::string loadSerializedModel(
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) const;
|
||||
|
||||
std::string modelName_;
|
||||
ExecutorType executorType_;
|
||||
BaseRuntimeConfigs runtimeConfigs_;
|
||||
|
||||
Placement placement_;
|
||||
|
||||
// original non-delegated graph from torch.export()
|
||||
std::shared_ptr<Graph> graph_;
|
||||
|
||||
// graph with delegated nodes after lowering/compilation
|
||||
std::shared_ptr<Graph> delegateGraph_;
|
||||
|
||||
// key is weight name, value is archive path for weight
|
||||
std::unordered_map<std::string, std::string> stateDictPath_;
|
||||
|
||||
// contains both tensor constants and CustomClassHolder (aka. torchbind
|
||||
// object)
|
||||
std::unordered_map<std::string, std::string> constantPaths_;
|
||||
|
||||
std::unique_ptr<Executor> executor_;
|
||||
|
||||
// recently loaded and not yet committed weights.
|
||||
std::shared_ptr<Weights> newWeights_;
|
||||
|
||||
TreeSpec inputSpec_;
|
||||
TreeSpec outputSpec_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<torch::nativert::ExecutorType> {
|
||||
template <typename ParseContext>
|
||||
constexpr auto parse(ParseContext& ctx) {
|
||||
return ctx.begin();
|
||||
}
|
||||
|
||||
template <typename FormatContext>
|
||||
auto format(torch::nativert::ExecutorType et, FormatContext& ctx) const {
|
||||
using namespace torch::nativert;
|
||||
switch (et) {
|
||||
case ExecutorType::INTERPRETER:
|
||||
return format_to(ctx.out(), "INTERPRETER");
|
||||
case ExecutorType::AOTINDUCTOR:
|
||||
return format_to(ctx.out(), "AOTINDUCTOR");
|
||||
case ExecutorType::MTIA:
|
||||
return format_to(ctx.out(), "MTIA");
|
||||
default:
|
||||
return format_to(ctx.out(), "UNKNOWN");
|
||||
}
|
||||
}
|
||||
};
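
A caller-side sketch of the deferred-initialization flow described above (illustration only); `runner` is assumed to be a concrete ModelRunnerBase subclass constructed with deferInitialization = true:

#include "torch/csrc/nativert/executor/ModelRunnerBase.h"

using namespace torch::nativert;

c10::IValue serveOnce(
    ModelRunnerBase& runner,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader) {
  runner.setExecutorType(ExecutorType::INTERPRETER);
  runner.loadNewWeights(reader);  // stays shadowed until initialize/commit
  runner.initialize(reader);      // builds the Executor from the new weights

  // Pull the packaged sample inputs and run one pytree-structured call.
  auto [args, kwargs] = runner.loadSampleInputs(reader);
  return runner.run(args, kwargs);
}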
|
||||
torch/csrc/nativert/executor/OpKernel.cpp (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/ostream.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/ConfigUtils.h"
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
|
||||
#ifdef __SIGRID_USE_GPU__
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#endif
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
c10::OperatorHandle getOperatorForTarget(
|
||||
std::string_view target,
|
||||
const Node* node) {
|
||||
// target could come as either "torch.ops.aten.add.default" or
|
||||
// "aten.add.default"
|
||||
std::vector<std::string_view> atoms = split(target, '.');
|
||||
|
||||
size_t numAtoms = atoms.size();
|
||||
if (numAtoms < 3) {
|
||||
TORCH_CHECK(false, "Invalid target: ", target);
|
||||
}
|
||||
|
||||
const std::string_view ns = atoms[numAtoms - 3];
|
||||
const std::string_view opName = atoms[numAtoms - 2];
|
||||
const std::string_view overloadName = atoms[numAtoms - 1];
|
||||
|
||||
const auto operatorName = fmt::format("{}::{}", ns, opName);
|
||||
std::string normalizedOverloadName;
|
||||
if (overloadName == "default") {
|
||||
normalizedOverloadName = "";
|
||||
} else {
|
||||
normalizedOverloadName = overloadName;
|
||||
}
|
||||
|
||||
auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
|
||||
operatorName.c_str(), normalizedOverloadName.c_str());
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
std::string readableArgs(
|
||||
const c10::FunctionSchema& schema,
|
||||
const std::vector<c10::IValue>& stack) {
|
||||
auto schemaArgs = schema.arguments();
|
||||
std::stringstream ss;
|
||||
for (const auto& [i, arg] : enumerate(stack)) {
|
||||
ss << "arg" << i << " " << schemaArgs[i].name() << ": " << arg.tagKind()
|
||||
<< " ";
|
||||
if (arg.isTensor()) {
|
||||
auto t = arg.toTensor();
|
||||
ss << t.dtype() << t.sizes() << t.device();
|
||||
} else if (arg.isTensorList()) {
|
||||
auto tl = arg.toTensorVector();
|
||||
ss << "[";
|
||||
for (const auto& t : tl) {
|
||||
ss << t.dtype() << t.sizes() << t.device() << ", ";
|
||||
}
|
||||
ss << "]";
|
||||
} else if (arg.isNone()) {
|
||||
// pass
|
||||
} else {
|
||||
ss << arg;
|
||||
}
|
||||
ss << "\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
const bool OpKernel::blockingEnabled_ =
|
||||
maybeGetEnv("TORCH_NATIVE_RUNTIME_CUDA_LAUNCH_BLOCKING").value_or("0") == "1";
|
||||
|
||||
void OpKernel::compute(ExecutionFrame& executionFrame) const {
|
||||
VLOG(2) << "Executing: " << *node_;
|
||||
|
||||
computeInternal(executionFrame);
|
||||
|
||||
#ifdef __SIGRID_USE_GPU__
|
||||
if (device_.has_value() && device_->is_cuda() && blockingEnabled_) {
|
||||
AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
#endif
|
||||
|
||||
VLOG(2) << "Completed: " << *node_;
|
||||
}
|
||||
|
||||
Arguments prefillStackWithStaticArgs(
|
||||
const Node* node,
|
||||
const c10::FunctionSchema& schema) {
|
||||
std::vector<c10::IValue> stackWithStaticArgs;
|
||||
std::vector<Value*> dynamicArgs;
|
||||
const auto& schemaArgs = schema.arguments();
|
||||
stackWithStaticArgs.resize(schemaArgs.size());
|
||||
dynamicArgs.resize(schemaArgs.size());
|
||||
|
||||
// initialize stackWithStaticArgs with the statically known inputs
|
||||
for (const auto& [idx, schemaArg] : enumerate(schemaArgs)) {
|
||||
const auto& argName = schemaArg.name();
|
||||
|
||||
// Check if this is a dynamic input to the op.
|
||||
const auto input = node->tryGetInput(argName);
|
||||
if (input != nullptr) {
|
||||
stackWithStaticArgs.at(idx) = c10::IValue();
|
||||
dynamicArgs.at(idx) = input->value;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this is a statically known input to the op.
|
||||
const auto attribute = node->tryGetAttribute(argName);
|
||||
if (attribute != nullptr) {
|
||||
stackWithStaticArgs.at(idx) = constantToIValue(attribute->value);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, this must have a default value
|
||||
if (schemaArg.default_value().has_value()) {
|
||||
stackWithStaticArgs.at(idx) = schemaArg.default_value().value();
|
||||
continue;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Cannot initialize argument ",
|
||||
argName,
|
||||
" for node ",
|
||||
*node,
|
||||
" with schema ",
|
||||
schema);
|
||||
}
|
||||
return Arguments{std::move(stackWithStaticArgs), std::move(dynamicArgs)};
|
||||
}
|
||||
|
||||
void fillDynamicInputs(
|
||||
const ExecutionFrame& executionFrame,
|
||||
const Arguments& arguments,
|
||||
std::vector<c10::IValue>& stack) {
|
||||
// fill the stack with dynamic values from execution frame,
|
||||
// including tensor, tensors, symint, symints
|
||||
|
||||
for (auto [idx, value] : arguments.getDynamicArgs()) {
|
||||
CHECK(idx < stack.size()) << "invalid idx";
|
||||
CHECK(stack.at(idx).isNone()) << "stack[idx] shouldn't have been populated";
|
||||
if (value->type() == Type::TensorList) {
|
||||
// TODO: This is for passing List<Tensor> as an input to an op that takes a
|
||||
// List<Optional<Tensor>>.
|
||||
// this is awful, but if I don't cast it to a vector and back to a
|
||||
// list, I get list covariance problems where List<Tensor> is not a
|
||||
// subtype of List<Optional<Tensor>>, which pops up when trying to
|
||||
// execute aten.index.Tensor. See the code in:
|
||||
// https://fburl.com/code/t1poq3z3. Our lists should be covariant
|
||||
// because they are static, but IValues don't know that :(
|
||||
stack[idx] = executionFrame.getIValue(value->id()).toTensorList().vec();
|
||||
} else if (value->type() == Type::None) {
|
||||
stack[idx] = c10::IValue();
|
||||
} else {
|
||||
stack[idx] = executionFrame.getIValue(value->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
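
A sketch of how the helpers above compose into a dispatcher-backed kernel body (illustration only; the actual C10Kernel added by this change may differ):

#include "torch/csrc/nativert/executor/OpKernel.h"

namespace torch::nativert {

void runViaDispatcher(const Node* node, ExecutionFrame& frame) {
  const auto op = getOperatorForTarget(node->target(), node);
  const auto& schema = op.schema();

  // Static args (attributes and schema defaults) are resolved once per node...
  Arguments args = prefillStackWithStaticArgs(node, schema);

  // ...and the per-call tensor/symint inputs are filled in from the frame.
  std::vector<c10::IValue> stack = args.getStackWithStaticArgs();
  fillDynamicInputs(frame, args, stack);

  // Boxed call: inputs are consumed and the outputs are left on the stack.
  op.callBoxed(&stack);

  for (size_t i = 0; i < node->outputs().size(); ++i) {
    frame.setIValue(node->outputs()[i]->id(), std::move(stack[i]));
  }
}

} // namespace torch::nativert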
|
||||
torch/csrc/nativert/executor/OpKernel.h (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
#include "c10/core/Device.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
c10::OperatorHandle getOperatorForTarget(
|
||||
std::string_view target,
|
||||
const Node* node = nullptr);
|
||||
|
||||
class Arguments {
|
||||
public:
|
||||
Arguments(
|
||||
std::vector<c10::IValue> stackWithStaticArgs,
|
||||
std::vector<Value*> dynamicArgs)
|
||||
: stackWithStaticArgs_(std::move(stackWithStaticArgs)),
|
||||
dynamicArgs_(std::move(dynamicArgs)) {
|
||||
for (size_t i = 0; i < dynamicArgs_.size(); i++) {
|
||||
if (dynamicArgs_[i]) {
|
||||
indices_.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic arguments are the inputs that were not baked into the graph
* during graph capture, i.e. all the tensor inputs to operators.
*
* This API returns the pairs of argument index and corresponding Value
* pointer from the graph.
|
||||
*/
|
||||
auto getDynamicArgs() const {
|
||||
std::vector<std::pair<size_t, Value*>> ret;
|
||||
ret.reserve(indices_.size());
|
||||
for (auto i : indices_) {
|
||||
ret.emplace_back(i, dynamicArgs_[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Argument i means the i-th input to the operator in the argument list.
|
||||
// Will return nullptr if the argument is not dynamic.
|
||||
Value* findDynamic(size_t i) const {
|
||||
DCHECK(i < dynamicArgs_.size()) << "Invalid input index: " << i;
|
||||
return dynamicArgs_[i];
|
||||
}
|
||||
|
||||
// Argument i means the i-th input to the operator in the argument list.
|
||||
// Will return None as IValue if the argument is not static.
|
||||
const c10::IValue& getStatic(size_t i) const {
|
||||
DCHECK(i < stackWithStaticArgs_.size()) << "Invalid input index: " << i;
|
||||
return stackWithStaticArgs_[i];
|
||||
}
|
||||
|
||||
/**
|
||||
* Static arguments are the inputs that were specialized to a fixed value
* during the graph capture phase. For example, scalar inputs and devices
* are considered static arguments.
|
||||
*/
|
||||
const std::vector<c10::IValue>& getStackWithStaticArgs() const {
|
||||
return stackWithStaticArgs_;
|
||||
}
|
||||
|
||||
private:
|
||||
// stack pre-populated with attributes, aka static arguments
|
||||
const std::vector<c10::IValue> stackWithStaticArgs_;
|
||||
|
||||
// Argument can only be asTensor, asTensors, asSymInt, asSymInts
|
||||
const std::vector<Value*> dynamicArgs_;
|
||||
std::vector<size_t> indices_;
|
||||
};
|
||||
|
||||
void fillDynamicInputs(
|
||||
const ExecutionFrame& executionFrame,
|
||||
const Arguments& arguments,
|
||||
std::vector<c10::IValue>& stack);
|
||||
|
||||
Arguments prefillStackWithStaticArgs(
|
||||
const Node* node,
|
||||
const c10::FunctionSchema& schema);
|
||||
|
||||
std::string readableArgs(
|
||||
const c10::FunctionSchema& schema,
|
||||
const std::vector<c10::IValue>& stack);
|
||||
|
||||
// Abstract interface representing a kernel, which is responsible for executing
|
||||
// a single Node.
|
||||
class OpKernel {
|
||||
public:
|
||||
explicit OpKernel(
|
||||
const Node* node,
|
||||
std::optional<c10::Device> device = std::nullopt)
|
||||
: node_(node), device_(device) {
|
||||
VLOG(1) << "Initializing kernel for node: " << *node_;
|
||||
}
|
||||
|
||||
enum class Kind : uint8_t {
|
||||
kPrimKernel,
|
||||
kStaticDispatchKernel,
|
||||
kInterpreterFallbackKernel,
|
||||
};
|
||||
|
||||
const Node* node() const {
|
||||
return node_;
|
||||
}
|
||||
void compute(ExecutionFrame& executionFrame) const;
|
||||
|
||||
Kind kind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
bool hasPrimKernel() const {
|
||||
return kind() == Kind::kPrimKernel;
|
||||
}
|
||||
|
||||
bool hasStaticDispatch() const {
|
||||
return kind() == Kind::kStaticDispatchKernel;
|
||||
}
|
||||
|
||||
size_t numInputs() const {
|
||||
return node_->inputs().size();
|
||||
}
|
||||
|
||||
size_t numOutputs() const {
|
||||
return node_->outputs().size();
|
||||
}
|
||||
|
||||
// Input is readonly
|
||||
[[nodiscard]] virtual const c10::IValue& input(
|
||||
uint32_t i,
|
||||
ExecutionFrame& executionFrame) const {
|
||||
CHECK(i < numInputs()) << "Invalid input index: " << i;
|
||||
return executionFrame.getIValue(node_->inputs()[i].value->id());
|
||||
}
|
||||
|
||||
// Output is readwrite
|
||||
c10::IValue& output(uint32_t i, ExecutionFrame& executionFrame) const {
|
||||
CHECK(i < numOutputs()) << "Invalid output index: " << i;
|
||||
return executionFrame.getIValue(node_->outputs()[i]->id(), true);
|
||||
}
|
||||
|
||||
virtual ~OpKernel() = default;
|
||||
|
||||
protected:
|
||||
virtual void computeInternal(ExecutionFrame& executionFrame) const = 0;
|
||||
|
||||
const Node* node_;
|
||||
std::optional<c10::Device> device_;
|
||||
const static bool blockingEnabled_;
|
||||
Kind kind_ = Kind::kInterpreterFallbackKernel;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
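
A hypothetical example kernel (not part of the change) showing the intended subclassing pattern; it clones its single tensor input into its single output slot:

#include "torch/csrc/nativert/executor/OpKernel.h"

namespace torch::nativert {

class CloneKernel : public OpKernel {
 public:
  explicit CloneKernel(const Node* node) : OpKernel(node) {
    kind_ = Kind::kPrimKernel;  // counted as a prim kernel in benchmarks
  }

 protected:
  void computeInternal(ExecutionFrame& frame) const override {
    // input() is read-only; output() hands back a writable IValue slot.
    output(0, frame) = input(0, frame).toTensor().clone();
  }
};

} // namespace torch::nativert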
|
||||
torch/csrc/nativert/executor/ParallelGraphExecutor.cpp (new file, 245 lines)
@@ -0,0 +1,245 @@
|
||||
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
|
||||
|
||||
#include "torch/csrc/nativert/common/concurrentqueue.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
|
||||
namespace {
|
||||
|
||||
#define LIKELY(x) (__builtin_expect((x), 1))
|
||||
#define UNLIKELY(x) (__builtin_expect((x), 0))
|
||||
|
||||
#define WITH_LOCK(m, block) \
|
||||
{ \
|
||||
std::unique_lock<decltype(m)> lk_(m); \
|
||||
block \
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
ThreadPoolExecutor::ThreadPoolExecutor()
|
||||
: work_(std::make_unique<moodycamel::ConcurrentQueue<Work>>()) {}
|
||||
|
||||
ThreadPoolExecutor::~ThreadPoolExecutor() {
|
||||
stop();
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
|
||||
thread_local moodycamel::ProducerToken ptok(*work_);
|
||||
return ptok;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
|
||||
thread_local moodycamel::ConsumerToken ctok(*work_);
|
||||
return ctok;
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
|
||||
session->addWork();
|
||||
unit->run(this, session);
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::start(int32_t numThreads) {
|
||||
stopped_ = false;
|
||||
for (int32_t i = 0; i < numThreads; ++i) {
|
||||
threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this));
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::loop() {
|
||||
while (true) {
|
||||
Work unit;
|
||||
|
||||
sem_->acquire();
|
||||
|
||||
if (stopped_) {
|
||||
return;
|
||||
}
|
||||
|
||||
while (!work_->try_dequeue(ctok(), unit)) {
|
||||
};
|
||||
|
||||
unit();
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
|
||||
session->addWork();
|
||||
work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session));
|
||||
sem_->release();
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::add(
|
||||
SessionState* session,
|
||||
std::vector<WorkUnit*>::const_iterator&& begin,
|
||||
const std::vector<WorkUnit*>::const_iterator&& end) {
|
||||
const auto count = end - begin;
|
||||
|
||||
switch (count) {
|
||||
case 0: {
|
||||
return;
|
||||
}
|
||||
case 1: {
|
||||
return add(session, *begin);
|
||||
}
|
||||
}
|
||||
|
||||
session->addWork(count);
|
||||
|
||||
std::vector<Work> runnables;
|
||||
runnables.reserve(count);
|
||||
for (; begin != end; ++begin) {
|
||||
runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session));
|
||||
}
|
||||
|
||||
work_->enqueue_bulk(ptok(), runnables.begin(), count);
|
||||
sem_->release(count);
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::stop() {
|
||||
stopped_ = true;
|
||||
sem_->release(threads_.size());
|
||||
|
||||
std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
|
||||
threads_.clear();
|
||||
|
||||
{
|
||||
// reset sem
|
||||
auto tmp = std::make_unique<Semaphore>();
|
||||
sem_.swap(tmp);
|
||||
}
|
||||
|
||||
{
|
||||
// flush queue
|
||||
auto tmp = moodycamel::ConcurrentQueue<Work>();
|
||||
work_->swap(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPoolExecutor::run(
|
||||
SessionState& session,
|
||||
const std::vector<WorkUnit*>& roots) {
|
||||
// case where thread ptok exists but work_ was swapped
|
||||
if (auto& tok = ptok(); UNLIKELY(!tok.valid())) {
|
||||
moodycamel::ProducerToken tmp(*work_);
|
||||
tok.swap(tmp);
|
||||
}
|
||||
|
||||
const auto rootCount = roots.size();
|
||||
|
||||
if (UNLIKELY(rootCount == 0)) {
|
||||
return;
|
||||
} else if (LIKELY(rootCount > 1)) {
|
||||
add(&session, roots.begin() + 1, roots.end());
|
||||
}
|
||||
|
||||
execute_inline(&session, roots[0]);
|
||||
|
||||
session.wait();
|
||||
}
|
||||
|
||||
void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
|
||||
thread_local std::vector<WorkUnit*> newWorkUnits;
|
||||
thread_local c10::InferenceMode mode;
|
||||
|
||||
WorkUnit* unit = this;
|
||||
|
||||
while (true) {
|
||||
unit->kernel->compute(session->frame());
|
||||
|
||||
for (auto* user : unit->users) {
|
||||
if (session->decrementProducers(user->node)) {
|
||||
newWorkUnits.push_back(user);
|
||||
}
|
||||
}
|
||||
|
||||
switch (newWorkUnits.size()) {
|
||||
case 0: {
|
||||
return session->removeWork();
|
||||
}
|
||||
case 1: {
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
executor->add(session, newWorkUnits[1]);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
executor->add(session, newWorkUnits.begin() + 1, newWorkUnits.end());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
unit = newWorkUnits[0];
|
||||
newWorkUnits.clear();
|
||||
}
|
||||
}
|
||||
|
||||
ParallelGraphExecutor::ParallelGraphExecutor(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig)
|
||||
: GraphExecutorBase(graph, std::move(nodeKernels), executorConfig),
|
||||
workUnits_(
|
||||
graph.nodes().size() - 2 /* no need for prim.Input or Prim.Output */),
|
||||
graph_(graph) {
|
||||
auto& nodes = graph_.nodes();
|
||||
|
||||
auto input = &*nodes.begin();
|
||||
auto output = &*nodes.rbegin();
|
||||
|
||||
{
|
||||
// get rid of prim.Input and prim.Output kernels
|
||||
// since we won't be needing them
|
||||
nodeKernels_.erase(nodeKernels_.begin());
|
||||
nodeKernels_.pop_back();
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
for (const auto& node : nodes) {
|
||||
if (&node == input || &node == output) {
|
||||
continue;
|
||||
}
|
||||
auto& workUnit =
|
||||
nodeToWorkUnit_.insert_or_assign(&node, &workUnits_[idx]).first->second;
|
||||
workUnit->node = &node;
|
||||
workUnit->kernel = nodeKernels_[idx++].get();
|
||||
producers_.insert({&node, 0});
|
||||
}
|
||||
|
||||
for (auto& unit : workUnits_) {
|
||||
for (const auto* dep : unit.node->users()) {
|
||||
if (dep != output) {
|
||||
unit.users.push_back(nodeToWorkUnit_[dep]);
|
||||
producers_[dep] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& [node, p] : producers_) {
|
||||
if (p == 0) {
|
||||
inputWorkUnits_.push_back(nodeToWorkUnit_[node]);
|
||||
}
|
||||
}
|
||||
|
||||
executor_.start(executorConfig.maxParallelOps);
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> ParallelGraphExecutor::execute(
|
||||
ExecutionFrame& executionFrame,
|
||||
std::vector<c10::IValue> inputs) {
|
||||
fillUserInputs(executionFrame, std::move(inputs));
|
||||
return executeWithPrefilledFrame(executionFrame);
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> ParallelGraphExecutor::executeWithPrefilledFrame(
|
||||
ExecutionFrame& executionFrame) {
|
||||
auto session = SessionState(executionFrame, producers_);
|
||||
executor_.run(session, inputWorkUnits_);
|
||||
|
||||
return executionFrame.getUserOutputs();
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
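The executor above schedules a node only once all of its producers have finished: each WorkUnit decrements its users' producer counters, runs the first newly ready user inline, and offloads the rest to the pool. The standalone sketch below illustrates that counting scheme on a toy DAG; the Task and runReady names are illustrative and not part of the PyTorch sources.

// Minimal, single-threaded sketch of producer-count scheduling (illustrative only).
#include <cstdio>
#include <string>
#include <vector>

struct Task {
  std::string name;
  int remainingProducers = 0;   // how many inputs are still being computed
  std::vector<Task*> users;     // tasks that consume this task's output
};

void runReady(Task* task, std::vector<Task*>& readyStack) {
  while (task != nullptr) {
    std::printf("running %s\n", task->name.c_str());
    Task* next = nullptr;
    for (Task* user : task->users) {
      if (--user->remainingProducers == 0) {
        // The first newly ready user continues "inline"; the rest are queued,
        // mirroring the continuation-style loop in WorkUnit::run.
        if (next == nullptr) {
          next = user;
        } else {
          readyStack.push_back(user);
        }
      }
    }
    task = next;
  }
}

int main() {
  // a -> b -> d and a -> c -> d  (d waits for both b and c)
  Task a{"a"}, b{"b"}, c{"c"}, d{"d"};
  a.users = {&b, &c};
  b.users = {&d};
  c.users = {&d};
  b.remainingProducers = c.remainingProducers = 1;
  d.remainingProducers = 2;

  std::vector<Task*> readyStack{&a};
  while (!readyStack.empty()) {
    Task* t = readyStack.back();
    readyStack.pop_back();
    runReady(t, readyStack);
  }
  return 0;
}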
94
torch/csrc/nativert/executor/ParallelGraphExecutor.h
Normal file
@ -0,0 +1,94 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/common/Semaphore.h"
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
#include "torch/csrc/nativert/executor/SessionState.h"
|
||||
|
||||
namespace moodycamel {
|
||||
struct ProducerToken;
|
||||
struct ConsumerToken;
|
||||
struct ConcurrentQueueDefaultTraits;
|
||||
template <typename T, typename Traits>
|
||||
class ConcurrentQueue;
|
||||
} // namespace moodycamel
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class ThreadPoolExecutor;
|
||||
|
||||
typedef std::function<void()> Work;
|
||||
|
||||
struct WorkUnit {
|
||||
const Node* node;
|
||||
OpKernel* kernel;
|
||||
std::vector<WorkUnit*> users;
|
||||
void run(ThreadPoolExecutor* executor, SessionState* sessionState);
|
||||
};
|
||||
|
||||
class ThreadPoolExecutor {
|
||||
public:
|
||||
explicit ThreadPoolExecutor();
|
||||
~ThreadPoolExecutor();
|
||||
ThreadPoolExecutor(const ThreadPoolExecutor&) = delete;
|
||||
ThreadPoolExecutor& operator=(ThreadPoolExecutor const&) = delete;
|
||||
ThreadPoolExecutor(ThreadPoolExecutor&&) = delete;
|
||||
ThreadPoolExecutor& operator=(ThreadPoolExecutor&&) = delete;
|
||||
|
||||
void run(SessionState& session, const std::vector<WorkUnit*>& roots);
|
||||
|
||||
void start(int32_t numThreads);
|
||||
void stop();
|
||||
|
||||
// execute unit on the current thread
|
||||
// NOTE: children can still be offloaded to other threads
|
||||
C10_ALWAYS_INLINE void execute_inline(SessionState* session, WorkUnit* unit);
|
||||
|
||||
void add(SessionState* session, WorkUnit* unit);
|
||||
void add(
|
||||
SessionState* session,
|
||||
std::vector<WorkUnit*>::const_iterator&& begin,
|
||||
const std::vector<WorkUnit*>::const_iterator&& end);
|
||||
|
||||
C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok();
|
||||
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok();
|
||||
|
||||
private:
|
||||
void loop();
|
||||
|
||||
std::atomic_bool stopped_{false};
|
||||
|
||||
std::unique_ptr<Semaphore> sem_{std::make_unique<Semaphore>()};
|
||||
|
||||
std::unique_ptr<moodycamel::ConcurrentQueue<
|
||||
Work,
|
||||
moodycamel::ConcurrentQueueDefaultTraits>>
|
||||
work_;
|
||||
std::vector<std::thread> threads_;
|
||||
};
|
||||
|
||||
class ParallelGraphExecutor : public GraphExecutorBase {
|
||||
public:
|
||||
ParallelGraphExecutor(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig);
|
||||
|
||||
std::vector<c10::IValue> execute(
|
||||
ExecutionFrame& frame,
|
||||
std::vector<c10::IValue> inputs) override;
|
||||
|
||||
std::vector<c10::IValue> executeWithPrefilledFrame(
|
||||
ExecutionFrame& frame) override;
|
||||
|
||||
private:
|
||||
ThreadPoolExecutor executor_;
|
||||
|
||||
std::vector<WorkUnit*> inputWorkUnits_;
|
||||
c10::FastMap<const Node*, WorkUnit*> nodeToWorkUnit_;
|
||||
std::vector<WorkUnit> workUnits_;
|
||||
|
||||
const Graph& graph_;
|
||||
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>> producers_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
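ThreadPoolExecutor gates its workers on a semaphore and, on stop(), sets the stop flag, releases one permit per thread, and joins. Below is a much simplified standard-library analogue (std::counting_semaphore plus a mutex-guarded queue in place of moodycamel's lock-free queue with producer/consumer tokens); it is a sketch of the same lifecycle, not the actual implementation.

#include <atomic>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <semaphore>
#include <thread>
#include <vector>

class SimplePool {
 public:
  void start(int numThreads) {
    for (int i = 0; i < numThreads; ++i) {
      threads_.emplace_back([this] { loop(); });
    }
  }

  void add(std::function<void()> work) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(std::move(work));
    }
    sem_.release();  // wake one worker, like sem_->release() in the executor
  }

  void stop() {
    stopped_ = true;
    sem_.release(static_cast<std::ptrdiff_t>(threads_.size()));
    for (auto& t : threads_) {
      t.join();
    }
    threads_.clear();
  }

 private:
  void loop() {
    while (true) {
      sem_.acquire();  // sleep until there is work (or stop)
      if (stopped_) {
        return;
      }
      std::function<void()> work;
      {
        std::lock_guard<std::mutex> lock(mutex_);
        if (queue_.empty()) {
          continue;
        }
        work = std::move(queue_.front());
        queue_.pop();
      }
      work();
    }
  }

  std::atomic_bool stopped_{false};
  std::counting_semaphore<> sem_{0};
  std::mutex mutex_;
  std::queue<std::function<void()>> queue_;
  std::vector<std::thread> threads_;
};

int main() {
  SimplePool pool;
  pool.start(2);
  for (int i = 0; i < 4; ++i) {
    pool.add([i] { std::printf("task %d\n", i); });
  }
  pool.stop();  // NOTE: unlike the real executor, pending tasks may be dropped here
  return 0;
}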
91
torch/csrc/nativert/executor/Placement.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
#include "torch/csrc/nativert/executor/Placement.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Placement& placement) {
|
||||
std::map<std::string, c10::Device> keys;
|
||||
for (const auto& pair : placement.deviceMap_) {
|
||||
keys.insert({pair.first.str(), pair.first});
|
||||
}
|
||||
|
||||
bool first = true;
|
||||
auto checkComma = [&]() {
|
||||
if (!first) {
|
||||
os << ",";
|
||||
}
|
||||
first = false;
|
||||
};
|
||||
|
||||
os << "";
|
||||
for (const auto& pair : keys) {
|
||||
checkComma();
|
||||
const auto& key = pair.second;
|
||||
const auto& value = placement.deviceMap_.at(key);
|
||||
os << pair.first << "|" << value.str();
|
||||
}
|
||||
|
||||
if (placement.defaultDevice_.has_value()) {
|
||||
checkComma();
|
||||
os << "|" << placement.defaultDevice_.value().str();
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
c10::Device normalizeDevice(const c10::Device& device) {
|
||||
// cpu device doesn't have an index
|
||||
// cuda device must have an index
|
||||
if (device.is_cpu()) {
|
||||
return c10::Device(c10::DeviceType::CPU);
|
||||
} else if (device.is_cuda()) {
|
||||
return c10::Device(
|
||||
c10::DeviceType::CUDA, device.has_index() ? device.index() : 0);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported device type", device);
|
||||
}
|
||||
}
|
||||
|
||||
bool isSameDevice(const c10::Device& a, const c10::Device& b) {
|
||||
if (a.is_cpu()) {
|
||||
return b.is_cpu();
|
||||
}
|
||||
if (a.is_cuda()) {
|
||||
if (b.is_cuda()) {
|
||||
auto aIndex = a.has_index() ? a.index() : 0;
|
||||
auto bIndex = b.has_index() ? b.index() : 0;
|
||||
return aIndex == bIndex;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(false, "Unsupported device type", a, " and ", b);
|
||||
return false;
|
||||
}
|
||||
|
||||
Placement::Placement(std::optional<c10::Device> defaultDevice)
|
||||
: Placement({}, defaultDevice) {}
|
||||
|
||||
Placement::Placement(
|
||||
const std::unordered_map<c10::Device, c10::Device>& deviceMap,
|
||||
std::optional<c10::Device> defaultDevice) {
|
||||
for (const auto& [srcDevice, dstDevice] : deviceMap) {
|
||||
deviceMap_.emplace(normalizeDevice(srcDevice), normalizeDevice(dstDevice));
|
||||
}
|
||||
if (defaultDevice.has_value()) {
|
||||
defaultDevice_ = normalizeDevice(defaultDevice.value());
|
||||
}
|
||||
}
|
||||
|
||||
c10::Device Placement::getMappedDevice(const c10::Device& srcDevice) const {
|
||||
auto it = deviceMap_.find(normalizeDevice(srcDevice));
|
||||
if (it != deviceMap_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
if (defaultDevice_.has_value()) {
|
||||
return defaultDevice_.value();
|
||||
}
|
||||
return srcDevice;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
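Placement normalizes both sides of the map (CPU loses its index, CUDA gains index 0 when missing) before lookup, then falls back to the default device, and finally to the original device. A minimal usage sketch, assuming libtorch is available and the new torch/csrc/nativert headers are on the include path:

#include <iostream>
#include <optional>
#include <unordered_map>

#include <c10/core/Device.h>

#include "torch/csrc/nativert/executor/Placement.h"

using torch::nativert::Placement;

int main() {
  // Map everything that was saved on cuda:0 to cuda:1; leave the rest alone.
  std::unordered_map<c10::Device, c10::Device> deviceMap{
      {c10::Device("cuda:0"), c10::Device("cuda:1")}};
  Placement placement(deviceMap, /*defaultDevice=*/std::nullopt);

  // "cuda" normalizes to cuda:0, so it gets remapped to cuda:1.
  std::cout << placement.getMappedDevice(c10::Device("cuda")) << "\n";
  // cpu is not in the map and there is no default, so it passes through.
  std::cout << placement.getMappedDevice(c10::Device("cpu")) << "\n";
  return 0;
}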
29
torch/csrc/nativert/executor/Placement.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
c10::Device normalizeDevice(const c10::Device& device);
|
||||
bool isSameDevice(const c10::Device& device1, const c10::Device& device2);
|
||||
|
||||
struct TORCH_API Placement {
|
||||
Placement() = default;
|
||||
explicit Placement(std::optional<c10::Device> defaultDevice);
|
||||
explicit Placement(
|
||||
const std::unordered_map<c10::Device, c10::Device>& deviceMap,
|
||||
std::optional<c10::Device> defaultDevice = std::nullopt);
|
||||
c10::Device getMappedDevice(const c10::Device& srcDevice) const;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const Placement& obj);
|
||||
|
||||
protected:
|
||||
std::unordered_map<c10::Device, c10::Device> deviceMap_;
|
||||
std::optional<c10::Device> defaultDevice_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
35
torch/csrc/nativert/executor/SerialGraphExecutor.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
std::vector<c10::IValue> SerialGraphExecutor::execute(
|
||||
ExecutionFrame& executionFrame,
|
||||
std::vector<c10::IValue> inputs) {
|
||||
fillUserInputs(executionFrame, std::move(inputs));
|
||||
|
||||
return executeWithPrefilledFrame(executionFrame);
|
||||
}
|
||||
|
||||
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
|
||||
ExecutionFrame& executionFrame) {
|
||||
// Execute kernels for all nodes except prim.Input and prim.Output
|
||||
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
|
||||
nodeKernels_[nodeIdx]->compute(executionFrame);
|
||||
|
||||
// don't free intermediate values when static memory planning is enabled
|
||||
if (!executorConfig_.enableStaticMemoryPlanning) {
|
||||
// Free the intermediate values that are not used anymore
|
||||
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
|
||||
executionFrame.releaseValue(valueKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return executionFrame.getUserOutputs();
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
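SerialGraphExecutor walks the kernels in topological order and, unless static memory planning is enabled, releases each intermediate value right after the last kernel that needs it, using the per-node free lists produced by the ExecutionPlanner. A standalone sketch of that pattern (the kernel and value names here are made up for illustration):

#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, int> frame;  // stand-in for ExecutionFrame values

  // Three "kernels" writing named intermediates.
  std::vector<std::function<void()>> kernels{
      [&] { frame["a"] = 1; },
      [&] { frame["b"] = frame["a"] + 1; },
      [&] { frame["out"] = frame["b"] * 2; },
  };

  // valuesToFree[i]: values whose last use is kernel i (a precomputed plan).
  std::vector<std::vector<std::string>> valuesToFree{{}, {"a"}, {"b"}};

  for (size_t i = 0; i < kernels.size(); ++i) {
    kernels[i]();
    for (const auto& name : valuesToFree[i]) {
      frame.erase(name);  // analogous to executionFrame.releaseValue(valueKey)
    }
  }

  std::printf("out = %d, live values = %zu\n", frame["out"], frame.size());
  return 0;
}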
23
torch/csrc/nativert/executor/SerialGraphExecutor.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class SerialGraphExecutor : public GraphExecutorBase {
|
||||
public:
|
||||
SerialGraphExecutor(
|
||||
const Graph& graph,
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
|
||||
const ExecutorConfig& executorConfig)
|
||||
: GraphExecutorBase(graph, std::move(nodeKernels), executorConfig) {}
|
||||
|
||||
std::vector<c10::IValue> execute(
|
||||
ExecutionFrame& frame,
|
||||
std::vector<c10::IValue> inputs) override;
|
||||
|
||||
std::vector<c10::IValue> executeWithPrefilledFrame(
|
||||
ExecutionFrame& frame) override;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
77
torch/csrc/nativert/executor/SessionState.h
Normal file
@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
template <typename T, typename __atomic_base = std::atomic<T>>
|
||||
struct __copyable_atomic : public __atomic_base {
|
||||
public:
|
||||
__copyable_atomic() = default;
|
||||
__copyable_atomic(const T& t) noexcept(__atomic_base::is_always_lock_free)
|
||||
: __atomic_base(t) {}
|
||||
__copyable_atomic(const __copyable_atomic& other) noexcept(
|
||||
__atomic_base::is_always_lock_free)
|
||||
: __atomic_base(other.load()) {}
|
||||
__copyable_atomic& operator=(const __copyable_atomic& other) noexcept(
|
||||
__atomic_base::is_always_lock_free) {
|
||||
this->store(other.load());
|
||||
return this;
|
||||
}
|
||||
};
|
||||
|
||||
class SessionState {
|
||||
public:
|
||||
explicit SessionState(
|
||||
ExecutionFrame& frame,
|
||||
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>>
|
||||
producers = {})
|
||||
: producers_(std::move(producers)), frame_(frame) {}
|
||||
|
||||
C10_ALWAYS_INLINE void wait() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [&]() {
|
||||
return workOutstanding_.load(std::memory_order_seq_cst) == 0;
|
||||
});
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE void addWork(uint32_t ct = 1) {
|
||||
workOutstanding_.fetch_add(ct, std::memory_order_seq_cst);
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE void removeWork() {
|
||||
if (workOutstanding_.fetch_sub(1, std::memory_order_seq_cst) == 1) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE ExecutionFrame& frame() {
|
||||
return frame_;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE /* producersRemaining == 0 */ bool decrementProducers(
|
||||
const Node* node) {
|
||||
return producers_.at(node).fetch_sub(1, std::memory_order_seq_cst) == 1;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE void setProducers(const Node* node, uint32_t v = 1) {
|
||||
producers_[node] += v;
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic_uint_fast32_t workOutstanding_{0};
|
||||
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>> producers_;
|
||||
|
||||
std::condition_variable cv_;
|
||||
std::mutex mutex_;
|
||||
|
||||
ExecutionFrame& frame_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
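SessionState's addWork/removeWork/wait trio is a plain outstanding-work counter: the scheduler bumps the counter before handing work to another thread, every completed unit decrements it, and the last decrement wakes the waiter. A standalone sketch of the same pattern with standard threads (the names are illustrative):

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

struct WorkCounter {
  std::atomic<uint32_t> outstanding{0};
  std::mutex mutex;
  std::condition_variable cv;

  void addWork(uint32_t n = 1) { outstanding.fetch_add(n); }

  void removeWork() {
    if (outstanding.fetch_sub(1) == 1) {  // last unit finished
      std::lock_guard<std::mutex> lock(mutex);
      cv.notify_one();
    }
  }

  void wait() {
    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock, [&] { return outstanding.load() == 0; });
  }
};

int main() {
  WorkCounter counter;
  std::vector<std::thread> workers;
  counter.addWork(4);  // count before spawning, like SessionState::addWork
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&, i] {
      std::printf("unit %d done\n", i);
      counter.removeWork();
    });
  }
  counter.wait();  // blocks until all four units have finished
  for (auto& t : workers) {
    t.join();
  }
  return 0;
}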
394
torch/csrc/nativert/executor/Weights.cpp
Normal file
@ -0,0 +1,394 @@
|
||||
#include "torch/csrc/nativert/executor/Weights.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "c10/util/Logging.h"
|
||||
|
||||
#include <torch/csrc/jit/serialization/import.h> // @manual=//caffe2:torch-cpp-cpu
|
||||
#include <torch/csrc/jit/serialization/import_read.h> // @manual=//caffe2:torch-cpp-cpu
|
||||
#include "caffe2/serialize/inline_container.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
WeightVersion Weights::globalVersion_ = 0;
|
||||
|
||||
Weights::Weights(
|
||||
const Graph* graph,
|
||||
const std::optional<std::unordered_map<std::string, c10::IValue>>&
|
||||
stateDict,
|
||||
const Placement& placement)
|
||||
: graph_(graph),
|
||||
weightsMeta_(graph->weightsMeta()),
|
||||
placement_(placement),
|
||||
version_(globalVersion_++) {
|
||||
if (stateDict.has_value()) {
|
||||
loadStateDict(stateDict.value());
|
||||
}
|
||||
}
|
||||
|
||||
Weights::Weights(
|
||||
const Graph* graph,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
|
||||
const std::unordered_map<std::string, std::string>& stateDictPaths,
|
||||
std::string_view stateDictPathPrefix,
|
||||
const std::unordered_map<std::string, std::string>& constantPaths,
|
||||
std::string_view constantPathPrefix,
|
||||
const Placement& placement,
|
||||
std::function<bool(const std::string&)> skipSizeCheck,
|
||||
std::function<bool(const std::string&)> skipDtypeCheck)
|
||||
: graph_(graph),
|
||||
weightsMeta_(graph->weightsMeta()),
|
||||
placement_(placement),
|
||||
version_(globalVersion_++),
|
||||
skipSizeCheck_(std::move(skipSizeCheck)),
|
||||
skipDtypeCheck_(std::move(skipDtypeCheck)) {
|
||||
auto loadAndInsert =
|
||||
[&](const std::string& tensorName,
|
||||
std::string_view pathPrefix,
|
||||
const std::unordered_map<std::string, std::string>& tensorPaths,
|
||||
bool isUsed) {
|
||||
auto pathIt = tensorPaths.find(tensorName);
|
||||
CHECK(pathIt != tensorPaths.end())
|
||||
<< "Couldn't find " << tensorName << " in tensorPaths";
|
||||
|
||||
const std::string tensorPath = std::string{pathPrefix} + pathIt->second;
|
||||
VLOG(1) << "Loading weight from: " << tensorPath;
|
||||
CHECK(pytorchStreamReader->hasRecord(tensorPath))
|
||||
<< tensorPath << " not found";
|
||||
|
||||
auto [tensorData, tensorDataSize] =
|
||||
pytorchStreamReader->getRecord(tensorPath);
|
||||
|
||||
// TODO: We now have two copies of metadata for weights, one in
|
||||
// model definition /models/<model_name>.json, another in
|
||||
// /extra/xl_weights/<model_name>_model_param_config.json
|
||||
// Currently, we only use the metadata from model definition.
|
||||
std::optional<TensorMeta> tensorMeta;
|
||||
if (weightsMeta_.find(tensorName) != weightsMeta_.end()) {
|
||||
tensorMeta = weightsMeta_.at(tensorName);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Tensor meta not found for: ", tensorName);
|
||||
}
|
||||
|
||||
if (tensorDataSize == 0 && tensorMeta->numel() > 0) {
|
||||
VLOG(1) << "Tensor " << tensorName
|
||||
<< " does not have data and create on Meta device";
|
||||
allValues_[tensorName] = torch::empty_strided(
|
||||
tensorMeta->sizes(),
|
||||
tensorMeta->strides(),
|
||||
tensorMeta->asTensorOptions().device(torch::kMeta));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isUsed) {
|
||||
VLOG(1) << "Tensor " << tensorName << " is not used during inference";
|
||||
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
|
||||
allValues_[tensorName] =
|
||||
at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
|
||||
return;
|
||||
}
|
||||
|
||||
size_t bytesPerEntry =
|
||||
c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize();
|
||||
auto device = tensorData.device();
|
||||
auto storage = c10::Storage(
|
||||
c10::Storage::use_byte_size_t(),
|
||||
at::detail::computeStorageNbytes(
|
||||
tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry),
|
||||
std::move(tensorData), // ownership is transferred
|
||||
nullptr,
|
||||
false);
|
||||
const auto tensorOptions = at::TensorOptions(device)
|
||||
.dtype(tensorMeta->dtype())
|
||||
.requires_grad(false);
|
||||
auto tensor =
|
||||
at::empty({0}, tensorOptions)
|
||||
.set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());
|
||||
|
||||
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
|
||||
VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
|
||||
if (!isSameDevice(targetDevice, tensor.device())) {
|
||||
tensor = tensor.to(targetDevice);
|
||||
}
|
||||
|
||||
allValues_[tensorName] = tensor;
|
||||
};
|
||||
|
||||
auto loadAndInsertParamsBuffers = [&](const std::string& tensorName,
|
||||
bool isUsed) {
|
||||
return loadAndInsert(
|
||||
tensorName, stateDictPathPrefix, stateDictPaths, isUsed);
|
||||
};
|
||||
|
||||
size_t weightIndex = 0;
|
||||
bool isUsed;
|
||||
const auto& weightValues = graph->weightValues();
|
||||
|
||||
for (const auto& tensorName : graph->signature().parameters()) {
|
||||
isUsed = (weightValues[weightIndex]->users().size() > 0);
|
||||
if (!isUsed) {
|
||||
unusedWeights_.insert(tensorName);
|
||||
}
|
||||
loadAndInsertParamsBuffers(tensorName, isUsed);
|
||||
weightIndex++;
|
||||
}
|
||||
for (const auto& tensorName : graph->signature().buffers()) {
|
||||
isUsed = (weightValues[weightIndex]->users().size() > 0);
|
||||
if (!isUsed) {
|
||||
unusedWeights_.insert(tensorName);
|
||||
}
|
||||
loadAndInsertParamsBuffers(tensorName, isUsed);
|
||||
weightIndex++;
|
||||
}
|
||||
|
||||
// Load tensor constants and custom object constants, they are both stored
|
||||
// in the same directory in the archive, i.e. "extra/constants/" tensor
|
||||
// constants are prefixed with "tensor_" custom objects are prefixed with
|
||||
// "custom_obj_"
|
||||
auto loadConstants = [&](const std::vector<std::string>& constants) {
|
||||
for (const auto& constantName : constants) {
|
||||
auto pathIt = constantPaths.find(constantName);
|
||||
CHECK(pathIt != constantPaths.end())
|
||||
<< "Couldn't find " << constantName << " in constantPaths";
|
||||
auto& fileName = pathIt->second;
|
||||
|
||||
if (starts_with(
|
||||
fileName, archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX)) {
|
||||
// tensor constants
|
||||
isUsed = (weightValues[weightIndex]->users().size() > 0);
|
||||
if (!isUsed) {
|
||||
unusedWeights_.insert(constantName);
|
||||
}
|
||||
loadAndInsert(constantName, constantPathPrefix, constantPaths, isUsed);
|
||||
weightIndex++;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unknown constant path: ", fileName);
|
||||
}
|
||||
}
|
||||
};
|
||||
loadConstants(graph->signature().nonPersistentBuffers());
|
||||
loadConstants(graph->signature().tensorConstants());
|
||||
|
||||
// custom object constants
|
||||
for (const auto& customObjName : graph->signature().customObjs()) {
|
||||
auto pathIt = constantPaths.find(customObjName);
|
||||
CHECK(pathIt != constantPaths.end())
|
||||
<< "Couldn't find " << customObjName << " in constantPaths";
|
||||
auto& fileName = pathIt->second;
|
||||
|
||||
if (!starts_with(fileName, archive_spec::CUSTOM_OBJ_FILENAME_PREFIX)) {
|
||||
TORCH_CHECK(false, "Unknown constant path: ", fileName);
|
||||
}
|
||||
std::string customObjPath = std::string{constantPathPrefix} + fileName;
|
||||
LOG(INFO) << "Loading custom object from: " << customObjPath;
|
||||
|
||||
CHECK(pytorchStreamReader->hasRecord(customObjPath))
|
||||
<< customObjPath << " not found";
|
||||
|
||||
const auto& [customObjData, customObjDataSize] =
|
||||
pytorchStreamReader->getRecord(customObjPath);
|
||||
|
||||
const char* customObjDataPtr =
|
||||
reinterpret_cast<const char*>(customObjData.get());
|
||||
std::string customObjBytes(
|
||||
customObjDataPtr, customObjDataPtr + customObjDataSize);
|
||||
|
||||
c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes);
|
||||
CHECK(customObj.isCustomClass());
|
||||
CHECK(!customObj.isNone());
|
||||
customObjs_[customObjName] = std::move(customObj);
|
||||
customObjsPaths_[customObjPath] = customObjName;
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::parameters() const {
|
||||
std::unordered_map<std::string, at::Tensor> result;
|
||||
for (const auto& name : graph_->signature().parameters()) {
|
||||
result.emplace(name, allValues_.at(name));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::buffers() const {
|
||||
std::unordered_map<std::string, at::Tensor> result;
|
||||
for (const auto& name : graph_->signature().buffers()) {
|
||||
result.emplace(name, allValues_.at(name));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> Weights::attributes() const {
|
||||
return allValues_;
|
||||
}
|
||||
|
||||
at::Tensor Weights::at(const std::string& name) const {
|
||||
auto it = allValues_.find(name);
|
||||
if (it != allValues_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, name, " not found in Weights ", toString());
|
||||
}
|
||||
|
||||
at::Tensor& Weights::at(const std::string& name) {
|
||||
auto it = allValues_.find(name);
|
||||
if (it != allValues_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, name, " not found in Weights ", toString());
|
||||
}
|
||||
|
||||
bool Weights::contains(const std::string& name) const {
|
||||
return allValues_.find(name) != allValues_.end();
|
||||
}
|
||||
|
||||
c10::IValue Weights::getCustomObj(const std::string& name) const {
|
||||
auto it = customObjs_.find(name);
|
||||
if (it != customObjs_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Custom objects ", name, " not found in Weights");
|
||||
}
|
||||
|
||||
c10::IValue Weights::getCustomObjByFileName(const std::string& name) const {
|
||||
auto it = customObjsPaths_.find(name);
|
||||
TORCH_CHECK(
|
||||
it != customObjsPaths_.end(),
|
||||
"Custom objects with file name ",
|
||||
name,
|
||||
" not found in Weights");
|
||||
const std::string obj_name = it->second;
|
||||
return getCustomObj(obj_name);
|
||||
}
|
||||
|
||||
void Weights::loadStateDict(
|
||||
const std::unordered_map<std::string, c10::IValue>& stateDict) {
|
||||
auto validateAndInsert = [&](const std::string& name) {
|
||||
auto stateDictIt = stateDict.find(name);
|
||||
CHECK(stateDictIt != stateDict.end())
|
||||
<< "Couldn't find " << name << " in stateDict";
|
||||
|
||||
// Verify that the tensor matches the tensorMeta
|
||||
auto it = weightsMeta_.find(name);
|
||||
CHECK(it != weightsMeta_.end())
|
||||
<< "Couldn't find " << name << " in weightsMeta";
|
||||
|
||||
auto targetDevice = placement_.getMappedDevice(it->second.device());
|
||||
auto tensor = stateDictIt->second.toTensor().to(targetDevice);
|
||||
|
||||
CHECK(tensor.sizes() == it->second.sizes());
|
||||
CHECK(tensor.dtype() == it->second.dtype());
|
||||
|
||||
allValues_.emplace(name, tensor);
|
||||
};
|
||||
|
||||
for (const auto& name : graph_->signature().parameters()) {
|
||||
validateAndInsert(name);
|
||||
}
|
||||
for (const auto& name : graph_->signature().buffers()) {
|
||||
validateAndInsert(name);
|
||||
}
|
||||
// TensorConstants_ not filled !!
|
||||
}
|
||||
|
||||
void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
|
||||
const {
|
||||
auto& weightMeta = weightsMeta_.at(name);
|
||||
|
||||
CHECK(
|
||||
weightMeta.sizes() == newValue.sizes() ||
|
||||
(skipSizeCheck_ && skipSizeCheck_(name)) ||
|
||||
unusedWeights_.find(name) != unusedWeights_.end())
|
||||
<< "Mismatched sizes for " << name << ": " << weightMeta.sizes() << " vs "
|
||||
<< newValue.sizes();
|
||||
CHECK(
|
||||
weightMeta.dtype() == newValue.dtype() ||
|
||||
(skipDtypeCheck_ && skipDtypeCheck_(name)) ||
|
||||
unusedWeights_.find(name) != unusedWeights_.end())
|
||||
<< "Mismatched dtype for " << name << ": " << weightMeta.dtype() << " vs "
|
||||
<< newValue.dtype();
|
||||
|
||||
auto targetDevice = placement_.getMappedDevice(weightMeta.device());
|
||||
if (targetDevice.is_cpu() && targetDevice.has_index()) {
|
||||
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
|
||||
}
|
||||
CHECK(isSameDevice(targetDevice, newValue.device()))
|
||||
<< "Mismatched device for " << name << ": " << targetDevice << " vs "
|
||||
<< newValue.device();
|
||||
}
|
||||
|
||||
void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
|
||||
if (allValues_.find(name) != allValues_.end()) {
|
||||
validateValue(name, newValue);
|
||||
} else {
|
||||
LOG(WARNING) << name << " is not found in the registered weights";
|
||||
}
|
||||
|
||||
allValues_[name] = newValue;
|
||||
}
|
||||
|
||||
void Weights::updateValue(const std::string& name, const at::Tensor& newValue) {
|
||||
auto it = allValues_.find(name);
|
||||
CHECK(it != allValues_.end())
|
||||
<< name << " not found in Weights " << toString();
|
||||
validateValue(name, newValue);
|
||||
|
||||
it->second.copy_(newValue);
|
||||
}
|
||||
|
||||
void Weights::updateValues(
|
||||
const std::unordered_map<std::string, at::Tensor>& newValues) {
|
||||
for (auto& [name, newValue] : newValues) {
|
||||
updateValue(name, newValue);
|
||||
}
|
||||
}
|
||||
|
||||
std::string Weights::toString() const {
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (const auto& [name, _] : allValues_) {
|
||||
ss << name << ", ";
|
||||
}
|
||||
ss << "]";
|
||||
ss << "[";
|
||||
for (const auto& [name, _] : customObjs_) {
|
||||
ss << name << ", ";
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void Weights::validateAllWeightsLoaded() {
|
||||
auto checkNames = [&](const std::vector<std::string>& names) {
|
||||
for (const auto& name : names) {
|
||||
if (unusedWeights_.find(name) != unusedWeights_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto it = allValues_.find(name);
|
||||
CHECK(it != allValues_.end()) << "Missing weight: " << name;
|
||||
CHECK(it->second.defined()) << "Weight not defined: " << name;
|
||||
if (it->second.device().is_meta()) {
|
||||
LOG(WARNING) << "Weight is on meta device: " << name;
|
||||
}
|
||||
}
|
||||
};
|
||||
checkNames(graph_->signature().parameters());
|
||||
checkNames(graph_->signature().buffers());
|
||||
checkNames(graph_->signature().nonPersistentBuffers());
|
||||
checkNames(graph_->signature().tensorConstants());
|
||||
}
|
||||
|
||||
void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) {
|
||||
foldedConstsMap_[std::string{name}] = std::move(tensor);
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, c10::IValue>& Weights::getFoldedConsts()
|
||||
const {
|
||||
return foldedConstsMap_;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
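The loader above takes the raw record returned by PyTorchStreamReader, moves it into a c10::Storage, and views it with the sizes, strides, and dtype recorded in the TensorMeta, so nothing is copied on the load path. A simpler, non-owning sketch of the same idea using at::from_blob (the buffer and shape below are made up; in the real code the bytes come from the archive and their ownership is transferred into the Storage):

#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // Pretend these bytes came out of pytorchStreamReader->getRecord(...).
  std::vector<float> rawBytes{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};

  // View the buffer with the metadata's sizes/strides; no copy happens here.
  at::Tensor view = at::from_blob(
      rawBytes.data(),
      /*sizes=*/{2, 3},
      /*strides=*/{3, 1},
      at::TensorOptions().dtype(at::kFloat));

  // If the weight must live on another device (per Placement), copy it over.
  at::Tensor weight = view.clone();  // or view.to(targetDevice)

  std::cout << weight << "\n";
  return 0;
}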
136
torch/csrc/nativert/executor/Weights.h
Normal file
@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/Placement.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
class Graph;
|
||||
class Value;
|
||||
using WeightVersion = int;
|
||||
|
||||
class Weights {
|
||||
public:
|
||||
explicit Weights(
|
||||
const Graph* graph,
|
||||
const std::optional<std::unordered_map<std::string, c10::IValue>>&
|
||||
stateDict = std::nullopt,
|
||||
const Placement& placement = Placement());
|
||||
|
||||
// Arguments
|
||||
// - pytorchStreamReader: the reader for the model archive
|
||||
// - stateDictPaths: a map from parameter/buffer/constant name to file path in
|
||||
// the archive
|
||||
// - stateDictPathPrefix: a prefix that will be prepended to paths in
|
||||
// stateDictPaths
|
||||
// - constantPaths: a map from constant name to file path in the archive
|
||||
// - constantPathPrefix: a prefix that will be prepended to paths in
|
||||
// constantPaths
|
||||
// - placement: the device placement of the weights, default to follow the
|
||||
// original device in the weight's metadata
|
||||
explicit Weights(
|
||||
const Graph* graph,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader,
|
||||
const std::unordered_map<std::string, std::string>& stateDictPaths,
|
||||
std::string_view stateDictPathPrefix,
|
||||
const std::unordered_map<std::string, std::string>& constantPaths,
|
||||
std::string_view constantPathPrefix,
|
||||
const Placement& placement = Placement(),
|
||||
std::function<bool(const std::string&)> skipSizeCheck = {},
|
||||
std::function<bool(const std::string&)> skipDtypeCheck = {});
|
||||
|
||||
at::Tensor at(const std::string& name) const;
|
||||
at::Tensor& at(const std::string& name);
|
||||
bool contains(const std::string& name) const;
|
||||
c10::IValue getCustomObj(const std::string& name) const;
|
||||
c10::IValue getCustomObjByFileName(const std::string& name) const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> parameters() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> buffers() const;
|
||||
|
||||
std::unordered_map<std::string, at::Tensor> attributes() const;
|
||||
|
||||
void loadStateDict(
|
||||
const std::unordered_map<std::string, c10::IValue>& stateDict);
|
||||
|
||||
/*
|
||||
* Replace the value stored at the weight with name "name".
|
||||
*/
|
||||
void setValue(const std::string& name, const at::Tensor& newValue);
|
||||
|
||||
/*
|
||||
* Update the value stored at the weight with name "name".
|
||||
* This is done in-place.
|
||||
*/
|
||||
void updateValue(const std::string& name, const at::Tensor& newValue);
|
||||
|
||||
void updateValues(
|
||||
const std::unordered_map<std::string, at::Tensor>& newValues);
|
||||
|
||||
void validateValue(const std::string& name, const at::Tensor& newValue) const;
|
||||
|
||||
void validateAllWeightsLoaded();
|
||||
|
||||
const std::unordered_map<std::string, c10::IValue>& getFoldedConsts() const;
|
||||
|
||||
C10_ALWAYS_INLINE const c10::FastMap<ValueId, c10::IValue>&
|
||||
getConstFoldedValues() const {
|
||||
return constFoldedValues_;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE void setConstFoldedValue(const ValueId v, c10::IValue iv) {
|
||||
constFoldedValues_.insert_or_assign(v, std::move(iv));
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
WeightVersion version() const {
|
||||
return version_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class Executor;
|
||||
friend class ExecutionFrame;
|
||||
void updateFoldedConst(std::string_view name, c10::IValue tensor);
|
||||
|
||||
const Graph* graph_;
|
||||
const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
|
||||
Placement placement_;
|
||||
|
||||
// keys are parameter/buffer/constant names, not graph input names!
|
||||
std::unordered_map<std::string, at::Tensor> allValues_;
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> customObjs_;
|
||||
|
||||
// contains a CustomClassHolder map from a file name to an arbitrary
|
||||
// key in customObjs_ that holds the loaded content of the file.
|
||||
// This is used in AOTIDelegateExecutor.
|
||||
std::unordered_map<std::string, std::string> customObjsPaths_;
|
||||
|
||||
// The lifecycle of folded consts should be tied to the weights from which
|
||||
// they were derived. The ordering of the constants should be consistent with
|
||||
// the output order of const graph.
|
||||
std::vector<c10::IValue> foldedConsts_;
|
||||
std::unordered_map<std::string, c10::IValue> foldedConstsMap_;
|
||||
|
||||
c10::FastMap<ValueId, c10::IValue> constFoldedValues_;
|
||||
|
||||
// unique version number for this instance of weight
|
||||
const WeightVersion version_;
|
||||
|
||||
// every instance of Weight has a unique version number
|
||||
static WeightVersion globalVersion_;
|
||||
|
||||
std::function<bool(const std::string&)> skipSizeCheck_ = {};
|
||||
std::function<bool(const std::string&)> skipDtypeCheck_ = {};
|
||||
|
||||
// save the names of unused weights
|
||||
std::unordered_set<std::string> unusedWeights_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
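The distinction in Weights between setValue (replace the stored tensor) and updateValue (copy into the existing tensor in place, after validating sizes and dtype) matters because kernels may already hold references to the old storage. A small standalone analogue of the in-place update path, assuming libtorch (the map and names are illustrative, not the real Weights container):

#include <torch/torch.h>

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

void updateValue(
    std::unordered_map<std::string, at::Tensor>& weights,
    const std::string& name,
    const at::Tensor& newValue) {
  auto it = weights.find(name);
  if (it == weights.end()) {
    throw std::runtime_error(name + " not found");
  }
  // Mirror validateValue(): shape and dtype must match the stored tensor.
  TORCH_CHECK(it->second.sizes() == newValue.sizes(), "size mismatch for ", name);
  TORCH_CHECK(it->second.dtype() == newValue.dtype(), "dtype mismatch for ", name);
  it->second.copy_(newValue);  // in place; the existing storage keeps its identity
}

int main() {
  std::unordered_map<std::string, at::Tensor> weights{
      {"linear.weight", torch::zeros({2, 2})}};
  updateValue(weights, "linear.weight", torch::ones({2, 2}));
  std::cout << weights.at("linear.weight") << "\n";
  return 0;
}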
1479
torch/csrc/nativert/graph/Graph.cpp
Normal file
File diff suppressed because it is too large
695
torch/csrc/nativert/graph/Graph.h
Normal file
@ -0,0 +1,695 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/IntrusiveList.h"
|
||||
#include "torch/csrc/nativert/executor/Placement.h"
|
||||
#include "torch/csrc/nativert/graph/GraphSignature.h"
|
||||
#include "torch/csrc/nativert/graph/TensorMeta.h"
|
||||
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using NodeIndex = size_t;
|
||||
|
||||
class Value;
|
||||
|
||||
class Type {
|
||||
public:
|
||||
enum Kind {
|
||||
None,
|
||||
Tensor,
|
||||
TensorList,
|
||||
OptionalTensorList,
|
||||
SymInt,
|
||||
SymIntList,
|
||||
SymBool,
|
||||
SymFloat,
|
||||
CustomObj,
|
||||
};
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const Type& ty);
|
||||
|
||||
/* implicit */ Type(Kind kind) : kind_(kind) {}
|
||||
|
||||
explicit Type(Kind kind, const std::string& classFqn)
|
||||
: kind_(kind), classFqn_(classFqn) {
|
||||
CHECK(kind == Kind::CustomObj);
|
||||
CHECK(!classFqn.empty());
|
||||
}
|
||||
|
||||
Kind kind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
std::string classFqn() const {
|
||||
CHECK(kind_ == Kind::CustomObj)
|
||||
<< "Only CustomObj type can have classFqn, got " << kind_;
|
||||
return classFqn_;
|
||||
}
|
||||
|
||||
private:
|
||||
const Kind kind_;
|
||||
const std::string classFqn_;
|
||||
};
|
||||
|
||||
bool operator==(const Type& left, const Type& right);
|
||||
|
||||
// These are all the constant types that are allowed as attributes on Nodes.
|
||||
struct None {};
|
||||
// None always equals itself
|
||||
inline bool operator==(const None&, const None&) {
|
||||
return true;
|
||||
}
|
||||
using Constant = std::variant<
|
||||
None,
|
||||
int64_t,
|
||||
std::vector<int64_t>,
|
||||
double,
|
||||
std::vector<double>,
|
||||
std::string,
|
||||
c10::ScalarType,
|
||||
c10::MemoryFormat,
|
||||
c10::Layout,
|
||||
c10::Device,
|
||||
bool,
|
||||
std::vector<bool>,
|
||||
std::vector<std::string>,
|
||||
std::unique_ptr<Graph>>;
|
||||
|
||||
c10::IValue constantToIValue(const Constant& constant);
|
||||
|
||||
class Node;
|
||||
|
||||
/**
|
||||
* Represents a single symbolic value (tensor/symint/list of them). Values are
|
||||
* inputs and outputs of Nodes.
|
||||
*/
|
||||
using ValueId = int;
|
||||
class Value {
|
||||
public:
|
||||
explicit Value(ValueId id, std::string name, Type t, Node* producer)
|
||||
: name_(std::move(name)), id_(id), type_(t), producer_(producer) {
|
||||
TORCH_CHECK_EQ(name_, this->name());
|
||||
}
|
||||
|
||||
// Each Value should be uniquely created and managed by a Graph. It's an anti
|
||||
// pattern to copy/move Value instances anyway.
|
||||
Value(Value&&) = delete;
|
||||
Value& operator=(Value&&) = delete;
|
||||
Value(const Value&) = delete;
|
||||
Value& operator=(Value&) = delete;
|
||||
|
||||
Type type() const {
|
||||
return type_;
|
||||
}
|
||||
|
||||
ValueId id() const {
|
||||
return id_;
|
||||
}
|
||||
|
||||
std::string_view name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
const Node* producer(bool resolve_folded = false) const {
|
||||
return !resolve_folded && isFolded() ? nullptr : producer_;
|
||||
}
|
||||
|
||||
Node* producer() {
|
||||
return producer_;
|
||||
}
|
||||
|
||||
void addUser(Node* node);
|
||||
void eraseUser(Node* node);
|
||||
void eraseAllUsers() {
|
||||
users_.clear();
|
||||
}
|
||||
|
||||
// Fatals if the value is not a TensorList
|
||||
std::vector<const Value*> getListElements() const;
|
||||
|
||||
const auto& users() const {
|
||||
return users_;
|
||||
}
|
||||
|
||||
auto& users() {
|
||||
return users_;
|
||||
}
|
||||
|
||||
void setId(ValueId newId) {
|
||||
// This should only be used inside the renumberValues pass
|
||||
id_ = newId;
|
||||
}
|
||||
|
||||
void setIsFolded() {
|
||||
isFolded_ = true;
|
||||
}
|
||||
|
||||
bool isFolded() const {
|
||||
return isFolded_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<<(std::ostream& out, const Value& v);
|
||||
std::string name_;
|
||||
bool isFolded_{false};
|
||||
ValueId id_;
|
||||
Type type_;
|
||||
Node* producer_;
|
||||
// All nodes which have this value as input.
|
||||
// Note that this is a vector to avoid nondeterminism in iteration, but
|
||||
// probably should be an unordered set given usage patterns. If this becomes a
|
||||
// perf problem we should revise.
|
||||
std::vector<Node*> users_;
|
||||
};
|
||||
|
||||
struct NamedArgument {
|
||||
std::string name;
|
||||
Value* value;
|
||||
};
|
||||
|
||||
struct Attribute {
|
||||
std::string name;
|
||||
Constant value;
|
||||
};
|
||||
|
||||
class Graph;
|
||||
|
||||
/**
|
||||
* Node represents a single unit of execution, typically a PyTorch operator.
|
||||
*/
|
||||
class Node : public IntrusiveListHook {
|
||||
public:
|
||||
Node(
|
||||
Graph* owningGraph,
|
||||
std::string target,
|
||||
std::vector<NamedArgument> inputs,
|
||||
std::unordered_map<std::string, std::string> metadata);
|
||||
|
||||
std::string_view target() const {
|
||||
return target_;
|
||||
}
|
||||
|
||||
void setTarget(std::string_view target) {
|
||||
target_ = target;
|
||||
}
|
||||
|
||||
const auto& inputs() const {
|
||||
return inputs_;
|
||||
}
|
||||
|
||||
auto& inputs() {
|
||||
return inputs_;
|
||||
}
|
||||
|
||||
// NOTE: this invalidates spans given out by inputs()
|
||||
Value* addInput(NamedArgument input);
|
||||
void addInputs(const std::vector<NamedArgument>& inputs);
|
||||
|
||||
// NOTE: this invalidates spans given out by attributes()
|
||||
void addAttribute(Attribute attr);
|
||||
|
||||
// NOTE: this is ONLY for graph's constant inputs and NOT the common case
|
||||
void addOutput(void);
|
||||
|
||||
Value* addOutput(Type type);
|
||||
|
||||
// NOTE: this invalidates spans given out by outputs()
|
||||
Value* addOutput(std::string_view name, Type type);
|
||||
|
||||
size_t numInputs() const {
|
||||
return inputs_.size();
|
||||
}
|
||||
|
||||
size_t numOutputs() const {
|
||||
return outputs_.size();
|
||||
}
|
||||
|
||||
// Return the next node in the Graph's node ordering.
|
||||
// NOTE: Calling next on the last node (prim.Output) returns nullptr.
|
||||
Node* next();
|
||||
const Node* next() const;
|
||||
|
||||
// Return the previous node in the Graph's node ordering.
|
||||
// NOTE: Calling prev on the first node (prim.Input) returns nullptr.
|
||||
Node* prev();
|
||||
const Node* prev() const;
|
||||
|
||||
bool isBefore(const Node* n) const;
|
||||
|
||||
std::vector<Node*> producers() const;
|
||||
std::vector<Node*> users() const;
|
||||
|
||||
// Returns nullptr if `name` is not an input
|
||||
const NamedArgument* tryGetInput(std::string_view name) const;
|
||||
// Fatals if `name` is not an input
|
||||
const NamedArgument& getInput(std::string_view name) const;
|
||||
|
||||
const auto& attributes() const {
|
||||
return attributes_;
|
||||
}
|
||||
|
||||
// Returns nullptr if `name` is not an attribute
|
||||
const Attribute* tryGetAttribute(std::string_view name) const;
|
||||
// Fatals if `name` is not an attribute
|
||||
const Attribute& getAttribute(std::string_view name) const;
|
||||
|
||||
const auto& outputs() const {
|
||||
return outputs_;
|
||||
}
|
||||
|
||||
void applyDevicePlacement(const Placement& placement);
|
||||
|
||||
std::optional<std::string_view> getMetadata(std::string_view key) const {
|
||||
return metadata_.find(std::string{key}) != metadata_.end()
|
||||
? std::optional(std::string_view{metadata_.at(std::string{key})})
|
||||
: std::nullopt;
|
||||
}
|
||||
|
||||
Graph* owningGraph() {
|
||||
return owningGraph_;
|
||||
}
|
||||
|
||||
const Graph* owningGraph() const {
|
||||
return owningGraph_;
|
||||
}
|
||||
|
||||
void destroy();
|
||||
|
||||
const std::unordered_map<std::string, std::string>& metadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
std::string toString() const {
|
||||
std::stringstream ss;
|
||||
ss << *this;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void updateInputName(std::string_view oldName, std::string_view newName) {
|
||||
for (auto& input : inputs_) {
|
||||
if (input.name == oldName) {
|
||||
input.name = newName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void updateAttributeName(std::string_view oldName, std::string_view newName) {
|
||||
for (auto& attr : attributes_) {
|
||||
if (attr.name == oldName) {
|
||||
attr.name = newName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<<(std::ostream& out, const Node& n);
|
||||
Graph* owningGraph_;
|
||||
|
||||
// Target used to retrieve the actual thing to execute.
|
||||
// If an aten operator, we expect this to be fully qualified, including an
|
||||
// overload name, e.g. "aten.unsqueeze.default"
|
||||
std::string target_;
|
||||
// *Symbolic* inputs to this node. NOTE: this does not match the ATen operator
|
||||
// schema inputs directly. It only represents things that actually participate
|
||||
// in dataflow, like tensors/symints and lists thereof.
|
||||
//
|
||||
// The "name" of the NamedArgument refers to the name of the parameter.
|
||||
std::vector<NamedArgument> inputs_;
|
||||
// Constant inputs to the node. The "name" of the Attribute refers to the
|
||||
// name of the parameter.
|
||||
std::vector<Attribute> attributes_;
|
||||
std::vector<Value*> outputs_;
|
||||
|
||||
// Extra bits of info added to the node. Contents that are guaranteed will be
|
||||
// eventually moved to a first-class field on the thrift struct.
|
||||
// Current contents:
|
||||
// "stack_trace" => original Python source traceback
|
||||
std::unordered_map<std::string, std::string> metadata_;
|
||||
};
|
||||
|
||||
/**
|
||||
* A representation of a model's computation graph. This representation is
|
||||
* designed to facilitate transformation/analysis.
|
||||
*
|
||||
* Ownership semantics:
|
||||
* - Graph owns Nodes and Values
|
||||
* - Nodes own their constant attributes (which we treat as value types)
|
||||
* - Nodes have non-owning pointers back to the graph.
|
||||
*
|
||||
* NOTE: this class is marked noncopyable/nonmovable and only can be
|
||||
* heap-allocated via `createGraph()`. This is to ensure stability of
|
||||
* back-pointers held by Nodes/Values.
|
||||
*/
|
||||
class Graph {
|
||||
public:
|
||||
static std::unique_ptr<Graph> createGraph() {
|
||||
return std::unique_ptr<Graph>(new Graph());
|
||||
}
|
||||
|
||||
Graph(const Graph&) = delete;
|
||||
Graph& operator=(const Graph&) = delete;
|
||||
Graph(Graph&&) = delete;
|
||||
Graph& operator=(Graph&&) = delete;
|
||||
~Graph() = default;
|
||||
|
||||
// NOTE: this invalidates spans given out by inputs()
|
||||
Value* addInput(std::string_view name, Type type);
|
||||
|
||||
// NOTE: this is ONLY for graph's constant inputs and NOT the common case
|
||||
void addInput(void);
|
||||
|
||||
// NOTE: this invalidates spans given out by outputs()
|
||||
Value* addOutput(Value* v);
|
||||
|
||||
void addConstantOutput(Constant& c);
|
||||
|
||||
// Create and insert a node at insertionPoint_
|
||||
Node* insertNode(
|
||||
std::string target,
|
||||
std::vector<NamedArgument> inputs = {},
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
|
||||
// Returns the inserted node.
|
||||
Node* insertBefore(Node* toInsert, Node* insertionPoint);
|
||||
// Returns the inserted node.
|
||||
Node* insertAfter(Node* toInsert, Node* insertionPoint);
|
||||
// Insert at the insertionPoint. Returns the inserted node.
|
||||
Node* insert(Node* toInsert);
|
||||
|
||||
// Create a node without inserting it into the execution graph.
|
||||
Node* createNode(
|
||||
std::string target,
|
||||
std::vector<NamedArgument> inputs = {},
|
||||
std::unordered_map<std::string, std::string> metadata = {});
|
||||
|
||||
Value* createConstantSymIntValue(int value);
|
||||
|
||||
Node* createListPack(std::vector<Value*> inputs, Type inputType);
|
||||
|
||||
Node* createOptionalListPack(std::vector<Value*> inputs);
|
||||
|
||||
size_t numValues() const {
|
||||
return values_.size();
|
||||
}
|
||||
|
||||
// throws on missing name
|
||||
Value* getValue(std::string_view name) const;
|
||||
// returns nullptr on missing name
|
||||
Value* tryGetValue(std::string_view name) const;
|
||||
|
||||
const std::unordered_map<ValueId, int> getConstantSymIntValues() const {
|
||||
return constantSymIntValues_;
|
||||
}
|
||||
|
||||
Value*
|
||||
addValue(const std::optional<std::string>& name, Type type, Node* producer);
|
||||
void removeValue(Value* value);
|
||||
|
||||
void replaceAllUses(Value* old, Value* replacement);
|
||||
void replaceAllUsesAfterNode(Value* old, Value* replacement, Node* afterThis);
|
||||
void removeNode(Node* node);
|
||||
|
||||
void applyDevicePlacement(const Placement& placement);
|
||||
|
||||
std::string getUniqueValueName();
|
||||
|
||||
ValueId getNextValueId() {
|
||||
return uniqueValueId_++;
|
||||
}
|
||||
|
||||
// NOTE: this range can be invalidated by mutations to the graph.
|
||||
const auto& inputs() const {
|
||||
return inputNode_->outputs();
|
||||
}
|
||||
|
||||
c10::ArrayRef<const Value*> userInputs() const {
|
||||
size_t offset = signature().inputsToWeights().size() +
|
||||
signature().inputsToCustomObjs().size();
|
||||
return {inputs().data() + offset, inputs().data() + inputs().size()};
|
||||
}
|
||||
|
||||
c10::ArrayRef<const Value*> weightValues() const {
|
||||
return {
|
||||
inputs().data(),
|
||||
inputs().data() + signature().inputsToWeights().size()};
|
||||
}
|
||||
|
||||
// Return a bidirectional range over `const Value*`
|
||||
// NOTE: this range can be invalidated by mutations to the graph.
|
||||
auto outputs() const {
|
||||
std::vector<const Value*> ret;
|
||||
ret.reserve(outputNode_->inputs().size());
|
||||
for (const auto& namedArg : outputNode_->inputs()) {
|
||||
ret.push_back(namedArg.value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Return a bidirectional range over `Value*`
|
||||
// NOTE: this range can be invalidated by mutations to the graph.
|
||||
auto outputs() {
|
||||
std::vector<Value*> ret;
|
||||
ret.reserve(outputNode_->inputs().size());
|
||||
for (const auto& namedArg : outputNode_->inputs()) {
|
||||
ret.push_back(namedArg.value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
const auto& userOutputs() const {
|
||||
return userOutputs_;
|
||||
}
|
||||
|
||||
// Return a list over `const Node&`.
|
||||
// NOTE: this can be invalidated by mutations to the graph.
|
||||
const auto& nodes() const {
|
||||
return nodes_;
|
||||
}
|
||||
|
||||
auto& nodes() {
|
||||
return nodes_;
|
||||
}
|
||||
|
||||
// Return a forward range over `const Value*`.
|
||||
// NOTE: this range can be invalidated by mutations to the graph.
|
||||
auto values() const {
|
||||
std::vector<const Value*> ret;
|
||||
ret.reserve(values_.size());
|
||||
for (const auto& [_, value] : values_) {
|
||||
ret.push_back(value.get());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* inputNode() {
|
||||
return inputNode_;
|
||||
}
|
||||
|
||||
Node* outputNode() {
|
||||
return outputNode_;
|
||||
}
|
||||
|
||||
const Node* outputNode() const {
|
||||
return outputNode_;
|
||||
}
|
||||
|
||||
// Assert various graph invariants
|
||||
void lint() const;
|
||||
|
||||
void cleanupDeadNodes();
|
||||
|
||||
void finalize();
|
||||
|
||||
Node* insertionPoint() {
|
||||
// This should never happen, since the last-most insertion point is the
|
||||
// prim.Outputs node, not end().
|
||||
CHECK(insertBefore_ != nodes_.end());
|
||||
auto& node = *insertBefore_;
|
||||
return &node;
|
||||
}
|
||||
|
||||
void setInsertionPoint(Node* n) {
|
||||
CHECK(n != inputNode_) << "can't insert before prim.Input";
|
||||
insertBefore_ = nodes_.iterator_to(*n);
|
||||
}
|
||||
|
||||
void setInsertionPointAfter(Node* n) {
|
||||
CHECK(n != outputNode_) << "can't insert after prim.Output";
|
||||
auto it = nodes_.iterator_to(*n);
|
||||
++it;
|
||||
insertBefore_ = it;
|
||||
}
|
||||
|
||||
// Return the next node in the Graph's node ordering.
|
||||
// NOTE: Calling on the last node (prim.Output) returns nullptr.
|
||||
Node* nodeAfter(Node* n);
|
||||
const Node* nodeAfter(const Node* n) const;
|
||||
|
||||
// Return the previous node in the Graph's node ordering.
|
||||
// NOTE: Calling on the first node (prim.Input) returns nullptr.
|
||||
Node* nodeBefore(Node* n);
|
||||
const Node* nodeBefore(const Node* n) const;
|
||||
|
||||
// Clone each node from subgraph (except prim.Input/prim.Output) into current
|
||||
// graph.
|
||||
// @param subgraph: the subgraph to be cloned
|
||||
// @param inputs: values from the target graph that will serve as the
|
||||
// subgraph's inputs
|
||||
// @param valueMap: a map from the cloned subgraph's values to the target
|
||||
// graph's values
|
||||
std::vector<Value*> insertGraph(
|
||||
const Graph& subgraph,
|
||||
std::vector<Value*> inputs,
|
||||
std::unordered_map<const Value*, Value*>& valueMap);
|
||||
|
||||
const GraphSignature& signature() const {
|
||||
return signature_;
|
||||
}
|
||||
|
||||
void setSignature(GraphSignature signature) {
|
||||
signature_ = std::move(signature);
|
||||
}
|
||||
|
||||
void setWeightsMeta(
|
||||
const std::unordered_map<std::string, torch::_export::TensorMeta>&
|
||||
tensorsMeta) {
|
||||
for (auto [name, tensorMeta] : tensorsMeta) {
|
||||
weightsMeta_.emplace(name, TensorMeta{tensorMeta});
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, TensorMeta>& weightsMeta() const {
|
||||
return weightsMeta_;
|
||||
}
|
||||
|
||||
std::vector<TensorMeta> userInputsMeta() const {
|
||||
std::vector<TensorMeta> userInputsMeta;
|
||||
for (auto inputName : signature_.userInputs()) {
|
||||
userInputsMeta.push_back(tensorValuesMeta_.at(inputName));
|
||||
}
|
||||
return userInputsMeta;
|
||||
}
|
||||
|
||||
void setTensorValuesMeta(
|
||||
const std::unordered_map<std::string, torch::_export::TensorMeta>&
|
||||
tensorsMeta) {
|
||||
for (auto [name, tensorMeta] : tensorsMeta) {
|
||||
tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta});
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, TensorMeta>& tensorValuesMeta() const {
|
||||
return tensorValuesMeta_;
|
||||
}
|
||||
|
||||
std::string toString() const {
|
||||
std::stringstream ss;
|
||||
ss << *this;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* Reassigns IDs to every Value in this Graph so that they are contiguous from
|
||||
* 0..(numValues()-1). Should be used after values are removed
|
||||
*/
|
||||
void renumberValues();
|
||||
|
||||
private:
|
||||
Graph();
|
||||
friend std::ostream& operator<<(std::ostream& out, const Graph& g);
|
||||
GraphSignature signature_;
|
||||
|
||||
// keys are parameters, buffers, tensor_constants' names
|
||||
std::unordered_map<std::string, TensorMeta> weightsMeta_;
|
||||
|
||||
// keys are tensor_values' names
|
||||
std::unordered_map<std::string, TensorMeta> tensorValuesMeta_;
|
||||
|
||||
// Node lifetime is managed by nodesOwner_, but the actual ordering is
|
||||
// maintained intrusively using nodes_.
|
||||
// This is to facilitate quick insertion before/after a given Node*.
|
||||
std::vector<std::unique_ptr<Node>> nodesOwner_;
|
||||
IntrusiveList<Node> nodes_;
|
||||
// The current insertion point. New nodes are inserted before this node.
|
||||
// Defaults to prim.Output.
|
||||
IntrusiveList<Node>::iterator insertBefore_;
|
||||
|
||||
// Graphs always start with an input and output node.
|
||||
// "prim.input() -> Value[]" take no input, and produces some outputs. AKA
|
||||
// "source“ of a graph.
|
||||
Node* inputNode_; // target: prim.Input
|
||||
// "prim.output(Value[]) -> None", take some inputs, but produce no output.
|
||||
// AKA "sink" of a graph.
|
||||
Node* outputNode_; // target: prim.Output
|
||||
|
||||
std::unordered_map<std::string, std::unique_ptr<Value>> values_;
|
||||
// constantSymIntValues_ is a subset of values_
|
||||
std::unordered_map<ValueId, int> constantSymIntValues_;
|
||||
// Output values of the graph, which is a subset of values_.
|
||||
std::vector<std::variant<Value*, Constant>> userOutputs_;
|
||||
// Output constant values of the graph
|
||||
std::vector<Constant> constantOutputs_;
|
||||
|
||||
size_t uniqueValueName_ = 0;
|
||||
|
||||
ValueId uniqueValueId_ = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* Scoped utility class for setting temporary insertion points.
|
||||
*
|
||||
* Use like:
|
||||
* {
|
||||
* InsertingAfter guard(node)
|
||||
* graph.insertNode(...) // this will be inserted after `node`.
|
||||
* }
|
||||
*/
|
||||
class InsertingAfter {
|
||||
public:
|
||||
explicit InsertingAfter(Node* n)
|
||||
: insertAfter_(n), prev_(n->owningGraph()->insertionPoint()) {
|
||||
insertAfter_->owningGraph()->setInsertionPointAfter(insertAfter_);
|
||||
}
|
||||
~InsertingAfter() {
|
||||
insertAfter_->owningGraph()->setInsertionPoint(prev_);
|
||||
}
|
||||
|
||||
private:
|
||||
Node* insertAfter_;
|
||||
Node* prev_;
|
||||
};
|
||||
|
||||
inline constexpr std::string_view kMemoryFormatPrefix = "MemoryFormat::";
|
||||
inline constexpr std::string_view kLayoutPrefix = "Layout::";
|
||||
inline constexpr std::string_view kDevicePrefix = "Device";
|
||||
inline constexpr std::string_view kScalarTypePrefix = "ScalarType::";
|
||||
|
||||
/**
|
||||
* Debug format serialization. The format here is intended to be human readable
|
||||
* and easy to work with, and is intended for debugging and testing only.
|
||||
* If you want stable serialization, use the thrift conversion utils.
|
||||
*
|
||||
* NOTE: node metadata currently not serialized
|
||||
*/
|
||||
std::string graphToString(const Graph& g, bool include_signature = false);
|
||||
std::unique_ptr<Graph> stringToGraph(std::string_view source);
|
||||
|
||||
// Standalone functions to parse common constructs
|
||||
// Parse something that looks like `Device{cuda:1}` into a c10::Device.
|
||||
c10::Device convertDevice(std::string_view symbol);
|
||||
// We have separate functions for parsing atomic and list constants because
|
||||
// there are restrictive rules about which constants can go in lists (i.e.
|
||||
// it's not recursive).
|
||||
Constant convertAtomicConstant(std::string_view symbol);
|
||||
Constant convertListConstant(std::string_view symbol);
|
||||
|
||||
} // namespace torch::nativert
|
||||
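Graph keeps node ownership in nodesOwner_ but ordering in an intrusive list, so insertNode() always splices a new node in front of the current insertion point and the InsertingAfter guard temporarily moves that point. A standalone sketch of the same insertion-point pattern using std::list (the real class additionally tracks values, the signature, and metadata):

#include <iostream>
#include <list>
#include <string>

class TinyGraph {
 public:
  TinyGraph() : insertBefore_(nodes_.end()) {}

  // Insert a node before the current insertion point, like Graph::insertNode.
  std::list<std::string>::iterator insertNode(std::string target) {
    return nodes_.insert(insertBefore_, std::move(target));
  }

  // Move the insertion point, like Graph::setInsertionPointAfter.
  void setInsertionPointAfter(std::list<std::string>::iterator it) {
    insertBefore_ = std::next(it);
  }

  void print() const {
    for (const auto& n : nodes_) {
      std::cout << n << "\n";
    }
  }

 private:
  std::list<std::string> nodes_;
  std::list<std::string>::iterator insertBefore_;
};

int main() {
  TinyGraph g;
  auto add = g.insertNode("aten.add");
  g.insertNode("aten.mul");
  // Temporarily insert right after "aten.add", like the InsertingAfter guard.
  g.setInsertionPointAfter(add);
  g.insertNode("aten.relu");
  g.print();  // prints: aten.add, aten.relu, aten.mul
  return 0;
}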
162
torch/csrc/nativert/graph/GraphPasses.cpp
Normal file
@ -0,0 +1,162 @@
|
||||
#include "torch/csrc/nativert/graph/GraphPasses.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
bool isScalar(const Constant& c) {
|
||||
return std::holds_alternative<int64_t>(c) ||
|
||||
std::holds_alternative<double>(c);
|
||||
}
|
||||
|
||||
bool isScalar(const Value& v) {
|
||||
return v.type() == Type::SymInt || v.type() == Type::SymFloat;
|
||||
}
|
||||
|
||||
bool schemaTypeMatch(const c10::FunctionSchema& schema, const Node& node) {
|
||||
for (const auto& input : node.inputs()) {
|
||||
// The number of arguments is always O(10), so we can just do a linear scan.
|
||||
for (const auto& schemaArg : schema.arguments()) {
|
||||
if (schemaArg.name() == input.name) {
|
||||
if (schemaArg.type() == c10::TensorType::get() && input.value &&
|
||||
isScalar(*input.value)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& constant : node.attributes()) {
|
||||
for (const auto& schemaArg : schema.arguments()) {
|
||||
if (schemaArg.name() == constant.name) {
|
||||
if (schemaArg.type() == c10::TensorType::get() &&
|
||||
isScalar(constant.value)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// PT2 intentionally broadcasts things like aten.sub.Scalar
// to aten.sub.Tensor. https://github.com/pytorch/pytorch/issues/90923.
|
||||
std::string selectScalarOverloadName(const Node& node) {
|
||||
// Copied from torch/csrc/utils/python_arg_parser.cpp
// torch::should_allow_numbers_as_tensors() to work around
// some linking issues.
|
||||
static std::unordered_set<std::string> allowed = {
|
||||
"add",
|
||||
"add_",
|
||||
"add_out",
|
||||
"div",
|
||||
"div_",
|
||||
"div_out",
|
||||
"divide",
|
||||
"divide_",
|
||||
"divide_out", // alias of div
|
||||
"mul",
|
||||
"mul_",
|
||||
"mul_out",
|
||||
"multiply",
|
||||
"multiply_",
|
||||
"multiply_out", // alias of mul
|
||||
"sub",
|
||||
"sub_",
|
||||
"sub_out",
|
||||
"subtract",
|
||||
"subtract_",
|
||||
"subtract_out", // alias of sub
|
||||
"true_divide",
|
||||
"true_divide_",
|
||||
"true_divide_out",
|
||||
"to",
|
||||
"_to_copy",
|
||||
"copy_",
|
||||
"copy",
|
||||
"floor_divide",
|
||||
"floor_divide_",
|
||||
"floor_divide_out",
|
||||
"_conj"};
|
||||
std::vector<std::string_view> atoms = split(node.target(), '.');
|
||||
TORCH_CHECK_GE(atoms.size(), 3);
|
||||
|
||||
std::string ns = std::string{atoms[atoms.size() - 3]};
|
||||
std::string opName = std::string{atoms[atoms.size() - 2]};
|
||||
std::string overloadName = std::string{atoms[atoms.size() - 1]};
|
||||
if (overloadName != "Tensor" && overloadName != "Tensor_Tensor") {
|
||||
return overloadName;
|
||||
}
|
||||
if (allowed.find(std::string{opName}) == allowed.end()) {
|
||||
return overloadName;
|
||||
}
|
||||
auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
|
||||
fmt::format("{}::{}", ns, opName.c_str()).c_str(), overloadName.c_str());
|
||||
if (schemaTypeMatch(op.schema(), node)) {
|
||||
return overloadName;
|
||||
}
|
||||
for (const auto& variant : {"Scalar", "Scalar_Tensor", "Tensor_Scalar"}) {
|
||||
if (auto schema = c10::Dispatcher::singleton().findSchema(
|
||||
{fmt::format("{}::{}", ns, opName.c_str()).c_str(), variant})) {
|
||||
if (schemaTypeMatch(schema->schema(), node)) {
|
||||
return variant;
|
||||
}
|
||||
}
|
||||
}
|
||||
return overloadName;
|
||||
}
|
||||
|
||||
void selectScalarOverload(Graph* graph) {
|
||||
for (auto& node : graph->nodes()) {
|
||||
for (auto& attr : node.attributes()) {
|
||||
if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
|
||||
selectScalarOverload(
|
||||
std::get<std::unique_ptr<Graph>>(attr.value).get());
|
||||
}
|
||||
}
|
||||
|
||||
auto target = node.target();
|
||||
std::vector<std::string_view> atoms = split(target, '.');
|
||||
|
||||
size_t numAtoms = atoms.size();
|
||||
if (numAtoms != 5) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string_view ns = atoms[numAtoms - 3];
|
||||
const std::string_view opName = atoms[numAtoms - 2];
|
||||
if (atoms[0] != "torch" || atoms[1] != "ops" || ns != "aten") {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto overloadName = selectScalarOverloadName(node);
|
||||
if (overloadName != atoms[numAtoms - 1]) {
|
||||
node.setTarget(
|
||||
fmt::format("torch.ops.{}.{}.{}", ns, opName, overloadName));
|
||||
} else if (ns == "aten" && opName == "sub" && overloadName == "Tensor") {
|
||||
// Special case for aten.sub.Tensor.
|
||||
if (auto i = node.tryGetInput("self")) {
|
||||
if (isScalar(*i->value)) {
|
||||
node.updateInputName("self", "other");
|
||||
node.updateInputName("other", "self");
|
||||
node.setTarget("torch.ops.aten.rsub.Scalar");
|
||||
}
|
||||
}
|
||||
if (auto a = node.tryGetAttribute("self")) {
|
||||
if (isScalar(a->value)) {
|
||||
node.updateAttributeName("self", "other");
|
||||
node.updateInputName("other", "self");
|
||||
node.setTarget("torch.ops.aten.rsub.Scalar");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
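As an illustrative aside (not part of this diff): the pass above works by splitting a node's target string into atoms and swapping only the overload name, e.g. torch.ops.aten.add.Tensor -> torch.ops.aten.add.Scalar when an operand is a scalar. The standalone sketch below mimics that string manipulation with std-only code; it does not use the Graph APIs.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Split "torch.ops.aten.add.Tensor" into {"torch", "ops", "aten", "add", "Tensor"}.
static std::vector<std::string> splitAtoms(const std::string& target) {
  std::vector<std::string> atoms;
  std::stringstream ss(target);
  std::string atom;
  while (std::getline(ss, atom, '.')) {
    atoms.push_back(atom);
  }
  return atoms;
}

int main() {
  const std::string target = "torch.ops.aten.add.Tensor";
  const auto atoms = splitAtoms(target);
  // Keep the namespace and op name, substitute the overload name; this is
  // what selectScalarOverload ultimately does via Node::setTarget.
  const std::string rewritten =
      "torch.ops." + atoms[2] + "." + atoms[3] + ".Scalar";
  std::cout << target << " -> " << rewritten << "\n";
  return 0;
}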
11
torch/csrc/nativert/graph/GraphPasses.h
Normal file
@ -0,0 +1,11 @@
#pragma once

#include "torch/csrc/nativert/graph/Graph.h"

namespace torch::nativert {

void selectScalarOverload(Graph* graph);

std::string selectScalarOverloadName(const Node& node);

} // namespace torch::nativert
450
torch/csrc/nativert/graph/GraphSignature.cpp
Normal file
@ -0,0 +1,450 @@
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/format.h>
|
||||
#include <fmt/ranges.h>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "c10/util/Exception.h"
|
||||
#include "torch/csrc/nativert/graph/GraphSignature.h"
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
|
||||
bool isSymbolicOutput(torch::_export::Argument::Tag t) {
|
||||
switch (t) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR:
|
||||
case torch::_export::Argument::Tag::AS_TENSORS:
|
||||
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOL:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT:
|
||||
case torch::_export::Argument::Tag::AS_SYM_INTS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
|
||||
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string> getSpecDetails(
|
||||
const torch::_export::InputSpec& inputSpec) {
|
||||
// Retrieve the argument name and spec tag name
|
||||
std::string argName;
|
||||
std::string tagName;
|
||||
switch (inputSpec.tag()) {
|
||||
case torch::_export::InputSpec::Tag::PARAMETER:
|
||||
argName = inputSpec.get_parameter().get_arg().get_name();
|
||||
tagName = "PARAMETER";
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::BUFFER:
|
||||
argName = inputSpec.get_buffer().get_arg().get_name();
|
||||
tagName = "BUFFER";
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::TENSOR_CONSTANT:
|
||||
argName = inputSpec.get_tensor_constant().get_arg().get_name();
|
||||
tagName = "TENSOR_CONSTANT";
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::CUSTOM_OBJ:
|
||||
tagName = "CUSTOM_OBJ";
|
||||
argName = inputSpec.get_custom_obj().get_arg().get_name();
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::USER_INPUT:
|
||||
tagName = "USER_INPUT";
|
||||
if (inputSpec.get_user_input().get_arg().tag() ==
|
||||
torch::_export::Argument::Tag::AS_TENSOR) {
|
||||
argName =
|
||||
inputSpec.get_user_input().get_arg().get_as_tensor().get_name();
|
||||
} else if (
|
||||
inputSpec.get_user_input().get_arg().tag() ==
|
||||
torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
|
||||
argName =
|
||||
inputSpec.get_user_input().get_arg().get_as_custom_obj().get_name();
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported USER_INPUT argument type.");
|
||||
}
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::CONSTANT_INPUT:
|
||||
argName = inputSpec.get_constant_input().get_name();
|
||||
tagName = "CONSTANT_INPUT";
|
||||
break;
|
||||
case torch::_export::InputSpec::Tag::TOKEN:
|
||||
throw std::runtime_error("Token inputs not implemented yet");
|
||||
default:
|
||||
throw std::runtime_error("Unknown InputSpec tag encountered.");
|
||||
}
|
||||
return std::make_pair(argName, tagName);
|
||||
}
|
||||
|
||||
void checkInputOrders(
|
||||
const std::vector<torch::_export::InputSpec>& inputSpecs) {
|
||||
// Map each tag to its index in the expected order
|
||||
std::unordered_map<torch::_export::InputSpec::Tag, size_t> tagOrderMap = {
|
||||
{torch::_export::InputSpec::Tag::TOKEN, 0},
|
||||
{torch::_export::InputSpec::Tag::PARAMETER, 1},
|
||||
{torch::_export::InputSpec::Tag::BUFFER, 2},
|
||||
{torch::_export::InputSpec::Tag::TENSOR_CONSTANT, 3},
|
||||
{torch::_export::InputSpec::Tag::CUSTOM_OBJ, 4}};
|
||||
size_t currentOrderIndex = 0;
|
||||
bool seenNonPersistentBuffer = false;
|
||||
for (const auto& inputSpec : inputSpecs) {
|
||||
if (inputSpec.tag() == torch::_export::InputSpec::Tag::USER_INPUT ||
|
||||
inputSpec.tag() == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
|
||||
continue;
|
||||
}
|
||||
auto it = tagOrderMap.find(inputSpec.tag());
|
||||
if (it == tagOrderMap.end()) {
|
||||
throw std::runtime_error("Unknown InputSpec tag encountered.");
|
||||
}
|
||||
size_t tagIndex = it->second;
|
||||
|
||||
if (tagIndex < currentOrderIndex) {
|
||||
auto [argName, tagName] = getSpecDetails(inputSpec);
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Input arg {} with InputSpec {} is out of order!", argName, tagName));
|
||||
}
|
||||
currentOrderIndex = tagIndex;
|
||||
// Additional check for buffers
|
||||
if (inputSpec.tag() == torch::_export::InputSpec::Tag::BUFFER) {
|
||||
if (!inputSpec.get_buffer().get_persistent()) {
|
||||
seenNonPersistentBuffer = true;
|
||||
} else if (
|
||||
inputSpec.get_buffer().get_persistent() && seenNonPersistentBuffer) {
|
||||
throw std::runtime_error(
|
||||
"Persistent buffer found after a non-persistent buffer. "
|
||||
"Persistent buffers must come before non-persistent buffers.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void checkInputNames(
|
||||
const std::set<std::string>& sigNames,
|
||||
const std::set<std::string>& graphNames) {
|
||||
if (sigNames == graphNames) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string errorMsg =
|
||||
"Error: Value name difference detected between graph signature and graph nodes:\n";
|
||||
errorMsg += "Signature value names:\n";
|
||||
errorMsg += fmt::format("[{}]\n", fmt::join(sigNames, ", "));
|
||||
errorMsg += "Graph node names:\n";
|
||||
errorMsg += fmt::format("[{}]", fmt::join(graphNames, ", "));
|
||||
LOG(FATAL) << errorMsg;
|
||||
}
|
||||
|
||||
void checkOutputNames(
|
||||
const std::set<std::optional<std::string>>& sigNames,
|
||||
const std::set<std::string>& graphNames) {
|
||||
std::vector<std::string> validNames;
|
||||
for (const auto& nameOpt : sigNames) {
|
||||
if (nameOpt.has_value()) {
|
||||
validNames.push_back(*nameOpt);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& name : validNames) {
|
||||
if (graphNames.find(name) == graphNames.end()) {
|
||||
std::string errorMsg =
|
||||
"Error: Value name difference detected between graph signature and graph nodes:\n";
|
||||
errorMsg += "Signature value names:\n";
|
||||
errorMsg += fmt::format("[{}]\n", fmt::join(validNames, ", "));
|
||||
errorMsg += "Graph node names:\n";
|
||||
errorMsg += fmt::format("[{}]", fmt::join(graphNames, ", "));
|
||||
LOG(FATAL) << errorMsg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void replaceInMap(
|
||||
std::unordered_map<std::string, std::string>& map,
|
||||
std::string_view old,
|
||||
std::string_view replacement) {
|
||||
auto it = map.find(std::string{old});
|
||||
if (it == map.end()) {
|
||||
return;
|
||||
}
|
||||
std::string value = it->second;
|
||||
map.erase(it);
|
||||
map.emplace(replacement, value);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) {
|
||||
checkInputOrders(storage.get_input_specs());
|
||||
|
||||
for (const torch::_export::InputSpec& inputSpec : storage.get_input_specs()) {
|
||||
switch (inputSpec.tag()) {
|
||||
case torch::_export::InputSpec::Tag::USER_INPUT: {
|
||||
if (inputSpec.get_user_input().get_arg().tag() ==
|
||||
torch::_export::Argument::Tag::AS_TENSOR) {
|
||||
userInputs_.push_back(
|
||||
inputSpec.get_user_input().get_arg().get_as_tensor().get_name());
|
||||
} else if (
|
||||
inputSpec.get_user_input().get_arg().tag() ==
|
||||
torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
|
||||
userInputs_.push_back(inputSpec.get_user_input()
|
||||
.get_arg()
|
||||
.get_as_custom_obj()
|
||||
.get_name());
|
||||
} else {
|
||||
// TODO: handle other types
|
||||
LOG(FATAL) << "Non tensor inputs nyi";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::PARAMETER: {
|
||||
parameters_.push_back(inputSpec.get_parameter().get_parameter_name());
|
||||
const auto& inputName = inputSpec.get_parameter().get_arg().get_name();
|
||||
const auto& weightName = inputSpec.get_parameter().get_parameter_name();
|
||||
inputsToParameters_.emplace(inputName, weightName);
|
||||
inputsToWeights_.emplace_back(inputName, weightName);
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::BUFFER: {
|
||||
const bool isPersistent = inputSpec.get_buffer().get_persistent();
|
||||
const std::string& bufferName =
|
||||
inputSpec.get_buffer().get_buffer_name();
|
||||
if (isPersistent) {
|
||||
buffers_.push_back(bufferName);
|
||||
} else {
|
||||
nonPersistentBuffers_.push_back(bufferName);
|
||||
}
|
||||
const auto& inputName = inputSpec.get_buffer().get_arg().get_name();
|
||||
const auto& weightName = inputSpec.get_buffer().get_buffer_name();
|
||||
inputsToBuffers_.emplace(inputName, weightName);
|
||||
inputsToWeights_.emplace_back(inputName, weightName);
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::TENSOR_CONSTANT: {
|
||||
tensorConstants_.push_back(
|
||||
inputSpec.get_tensor_constant().get_tensor_constant_name());
|
||||
const auto& inputName =
|
||||
inputSpec.get_tensor_constant().get_arg().get_name();
|
||||
const auto& weightName =
|
||||
inputSpec.get_tensor_constant().get_tensor_constant_name();
|
||||
|
||||
inputsToTensorConstants_.emplace(inputName, weightName);
|
||||
inputsToWeights_.emplace_back(inputName, weightName);
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::CUSTOM_OBJ: {
|
||||
customObjs_.push_back(inputSpec.get_custom_obj().get_custom_obj_name());
|
||||
inputsToCustomObjs_.insert(
|
||||
{inputSpec.get_custom_obj().get_arg().get_name(),
|
||||
inputSpec.get_custom_obj().get_custom_obj_name()});
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::CONSTANT_INPUT: {
|
||||
constantInputs_.push_back(inputSpec.get_constant_input().get_name());
|
||||
break;
|
||||
}
|
||||
case torch::_export::InputSpec::Tag::TOKEN: {
|
||||
throw std::runtime_error("Token inputs not implemented yet");
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Got empty thrift argument";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string lossOutput;
|
||||
for (const torch::_export::OutputSpec& outputSpec :
|
||||
storage.get_output_specs()) {
|
||||
switch (outputSpec.tag()) {
|
||||
case torch::_export::OutputSpec::Tag::LOSS_OUTPUT:
|
||||
lossOutput_ = outputSpec.get_loss_output().get_arg().get_name();
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::USER_OUTPUT:
|
||||
if (isSymbolicOutput(outputSpec.get_user_output().get_arg().tag())) {
|
||||
switch (outputSpec.get_user_output().get_arg().tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR: {
|
||||
userOutputs_.emplace_back(outputSpec.get_user_output()
|
||||
.get_arg()
|
||||
.get_as_tensor()
|
||||
.get_name());
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT: {
|
||||
userOutputs_.emplace_back(outputSpec.get_user_output()
|
||||
.get_arg()
|
||||
.get_as_sym_int()
|
||||
.get_as_name());
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unsupported symbolic user output type ";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// for constant outputs, we don't have a name
|
||||
userOutputs_.emplace_back(std::nullopt);
|
||||
}
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::BUFFER_MUTATION:
|
||||
buffersToMutate_.insert(
|
||||
{outputSpec.get_buffer_mutation().get_arg().get_name(),
|
||||
outputSpec.get_buffer_mutation().get_buffer_name()});
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::GRADIENT_TO_PARAMETER:
|
||||
gradientsToParameters_.insert(
|
||||
{outputSpec.get_gradient_to_parameter().get_arg().get_name(),
|
||||
outputSpec.get_gradient_to_parameter().get_parameter_name()});
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::GRADIENT_TO_USER_INPUT:
|
||||
gradientsToUserInputs_.insert(
|
||||
{outputSpec.get_gradient_to_user_input().get_arg().get_name(),
|
||||
outputSpec.get_gradient_to_user_input().get_user_input_name()});
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::USER_INPUT_MUTATION:
|
||||
userInputsToMutate_.insert(
|
||||
{outputSpec.get_user_input_mutation().get_arg().get_name(),
|
||||
outputSpec.get_user_input_mutation().get_user_input_name()});
|
||||
break;
|
||||
case torch::_export::OutputSpec::Tag::TOKEN: {
|
||||
throw std::runtime_error("Token outputs not implemented yet");
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Got empty thrift argument";
|
||||
}
|
||||
}
|
||||
|
||||
auto keys_of = [&](const std::unordered_map<std::string, std::string>& dict) {
|
||||
std::vector<std::string_view> keys;
|
||||
keys.reserve(dict.size());
|
||||
for (const auto& [key, _] : dict) {
|
||||
keys.emplace_back(key);
|
||||
}
|
||||
return keys;
|
||||
};
|
||||
|
||||
auto get_valid = [&](const std::vector<std::optional<std::string>>& outputs) {
|
||||
std::vector<std::string> validOutputs;
|
||||
for (const auto& output : outputs) {
|
||||
if (output.has_value()) {
|
||||
validOutputs.push_back(*output);
|
||||
} else {
|
||||
validOutputs.emplace_back("Constant");
|
||||
}
|
||||
}
|
||||
return validOutputs;
|
||||
};
|
||||
|
||||
VLOG(1) << fmt::format("[{}]", fmt::join(userInputs_, ", "));
|
||||
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(inputsToParameters_), ", "));
|
||||
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(inputsToBuffers_), ", "));
|
||||
VLOG(1) << fmt::format(
|
||||
"[{}]", fmt::join(keys_of(inputsToTensorConstants_), ", "));
|
||||
VLOG(1) << fmt::format("[{}]", fmt::join(get_valid(userOutputs_), ", "));
|
||||
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(buffersToMutate_), ", "));
|
||||
VLOG(1) << fmt::format(
|
||||
"[{}]", fmt::join(keys_of(gradientsToParameters_), ", "));
|
||||
VLOG(1) << fmt::format(
|
||||
"[{}]", fmt::join(keys_of(gradientsToUserInputs_), ", "));
|
||||
}
|
||||
|
||||
std::set<std::string> GraphSignature::inputNames() const {
|
||||
std::set<std::string> ret;
|
||||
for (const auto& name : userInputs()) {
|
||||
ret.insert(name);
|
||||
}
|
||||
for (const auto& [inputName, _] : inputsToParameters()) {
|
||||
ret.insert(inputName);
|
||||
}
|
||||
for (const auto& [inputName, _] : inputsToBuffers()) {
|
||||
ret.insert(inputName);
|
||||
}
|
||||
for (const auto& [inputName, _] : inputsToTensorConstants()) {
|
||||
ret.insert(inputName);
|
||||
}
|
||||
for (const auto& [inputName, _] : inputsToCustomObjs()) {
|
||||
ret.insert(inputName);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::set<std::optional<std::string>> GraphSignature::outputNames() const {
|
||||
std::set<std::optional<std::string>> ret;
|
||||
for (const auto& name : userOutputs()) {
|
||||
ret.insert(name);
|
||||
}
|
||||
for (const auto& [outputName, _] : buffersToMutate()) {
|
||||
ret.insert(outputName);
|
||||
}
|
||||
for (const auto& [outputName, _] : userInputsToMutate()) {
|
||||
ret.insert(outputName);
|
||||
}
|
||||
if (hasBackward()) {
|
||||
if (!gradientsToParameters().empty()) {
|
||||
for (const auto& [outputName, _] : gradientsToParameters()) {
|
||||
ret.insert(outputName);
|
||||
}
|
||||
}
|
||||
if (!gradientsToUserInputs().empty()) {
|
||||
for (const auto& [outputName, _] : gradientsToUserInputs()) {
|
||||
ret.insert(outputName);
|
||||
}
|
||||
}
|
||||
if (!lossOutput().empty()) {
|
||||
ret.insert(lossOutput());
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void GraphSignature::lint(
|
||||
const std::set<std::string>& graphInputs,
|
||||
const std::set<std::string>& graphOutputs) const {
|
||||
checkInputNames(inputNames(), graphInputs);
|
||||
checkOutputNames(outputNames(), graphOutputs);
|
||||
}
|
||||
|
||||
void GraphSignature::replaceAllUses(
|
||||
std::string_view old,
|
||||
std::string_view replacement) {
|
||||
if (old == replacement) {
|
||||
return;
|
||||
}
|
||||
for (auto& name : userOutputs_) {
|
||||
if (name == old) {
|
||||
name = replacement;
|
||||
}
|
||||
}
|
||||
replaceInMap(buffersToMutate_, old, replacement);
|
||||
if (hasBackward()) {
|
||||
replaceInMap(gradientsToParameters_, old, replacement);
|
||||
replaceInMap(gradientsToUserInputs_, old, replacement);
|
||||
if (old == lossOutput_) {
|
||||
lossOutput_ = replacement;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig) {
|
||||
out << "inputsToParameters: {\n";
|
||||
for (const auto& [inputName, paramName] : sig.inputsToParameters()) {
|
||||
out << "\t" << inputName << " : " << paramName << std::endl;
|
||||
}
|
||||
out << "}\n";
|
||||
|
||||
out << "inputsToBuffers: {\n";
|
||||
for (const auto& [inputName, bufferName] : sig.inputsToBuffers()) {
|
||||
out << "\t" << inputName << " : " << bufferName << std::endl;
|
||||
}
|
||||
out << "}\n";
|
||||
|
||||
out << "inputsToTensorConstants: {\n";
|
||||
for (const auto& [inputName, tensorConstantName] :
|
||||
sig.inputsToTensorConstants()) {
|
||||
out << "\t" << inputName << " : " << tensorConstantName << std::endl;
|
||||
}
|
||||
out << "}\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
141
torch/csrc/nativert/graph/GraphSignature.h
Normal file
@ -0,0 +1,141 @@
|
||||
#pragma once
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Graph;
|
||||
|
||||
// In-memory representation of the graph signature.
|
||||
class GraphSignature {
|
||||
public:
|
||||
GraphSignature() = default;
|
||||
explicit GraphSignature(const torch::_export::GraphSignature& storage);
|
||||
|
||||
const auto& lossOutput() const {
|
||||
return lossOutput_;
|
||||
}
|
||||
|
||||
const auto& gradientsToParameters() const {
|
||||
return gradientsToParameters_;
|
||||
}
|
||||
|
||||
const auto& gradientsToUserInputs() const {
|
||||
return gradientsToUserInputs_;
|
||||
}
|
||||
|
||||
const auto& inputsToParameters() const {
|
||||
return inputsToParameters_;
|
||||
}
|
||||
|
||||
const auto& inputsToBuffers() const {
|
||||
return inputsToBuffers_;
|
||||
}
|
||||
|
||||
const auto& inputsToTensorConstants() const {
|
||||
return inputsToTensorConstants_;
|
||||
}
|
||||
|
||||
const auto& inputsToCustomObjs() const {
|
||||
return inputsToCustomObjs_;
|
||||
}
|
||||
|
||||
const auto& userInputsVec() const {
|
||||
return userInputs_;
|
||||
}
|
||||
|
||||
const auto& parameters() const {
|
||||
return parameters_;
|
||||
}
|
||||
|
||||
const auto& buffers() const {
|
||||
return buffers_;
|
||||
}
|
||||
|
||||
const auto& nonPersistentBuffers() const {
|
||||
return nonPersistentBuffers_;
|
||||
}
|
||||
|
||||
const auto& tensorConstants() const {
|
||||
return tensorConstants_;
|
||||
}
|
||||
|
||||
const auto& customObjs() const {
|
||||
return customObjs_;
|
||||
}
|
||||
|
||||
const auto& userInputs() const {
|
||||
return userInputs_;
|
||||
}
|
||||
|
||||
const auto& constantInputs() const {
|
||||
return constantInputs_;
|
||||
}
|
||||
|
||||
const auto& userOutputs() const {
|
||||
return userOutputs_;
|
||||
}
|
||||
|
||||
const auto& buffersToMutate() const {
|
||||
return buffersToMutate_;
|
||||
}
|
||||
|
||||
const auto& userInputsToMutate() const {
|
||||
return userInputsToMutate_;
|
||||
}
|
||||
|
||||
bool hasBackward() const {
|
||||
return !(
|
||||
lossOutput_.empty() && gradientsToParameters_.empty() &&
|
||||
gradientsToUserInputs_.empty() && buffersToMutate_.empty());
|
||||
}
|
||||
|
||||
// Mapping of FQNs to weights with stable iteration order.
|
||||
const auto& inputsToWeights() const {
|
||||
return inputsToWeights_;
|
||||
}
|
||||
|
||||
void lint(
|
||||
const std::set<std::string>& graphInputs,
|
||||
const std::set<std::string>& graphOutputs) const;
|
||||
void replaceAllUses(std::string_view old, std::string_view replacement);
|
||||
|
||||
torch::_export::GraphSignature serialize() const;
|
||||
|
||||
private:
|
||||
std::set<std::string> inputNames() const;
|
||||
std::set<std::optional<std::string>> outputNames() const;
|
||||
|
||||
std::unordered_map<std::string, std::string> gradientsToParameters_;
|
||||
std::unordered_map<std::string, std::string> gradientsToUserInputs_;
|
||||
std::unordered_map<std::string, std::string> inputsToParameters_;
|
||||
std::unordered_map<std::string, std::string> inputsToBuffers_;
|
||||
std::unordered_map<std::string, std::string> buffersToMutate_;
|
||||
std::unordered_map<std::string, std::string> inputsToTensorConstants_;
|
||||
std::unordered_map<std::string, std::string> inputsToCustomObjs_;
|
||||
std::unordered_map<std::string, std::string> userInputsToMutate_;
|
||||
|
||||
// map union of inputsToParameters_, inputsToBuffers_ and
|
||||
// inputsToTensorConstants_
|
||||
std::vector<std::pair<std::string, std::string>> inputsToWeights_;
|
||||
|
||||
std::vector<std::string> parameters_;
|
||||
std::vector<std::string> buffers_;
|
||||
std::vector<std::string> tensorConstants_;
|
||||
std::vector<std::string> customObjs_;
|
||||
std::vector<std::string> nonPersistentBuffers_;
|
||||
|
||||
std::vector<std::string> userInputs_;
|
||||
std::vector<std::string> constantInputs_;
|
||||
std::vector<std::optional<std::string>> userOutputs_;
|
||||
std::string lossOutput_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig);
|
||||
|
||||
} // namespace torch::nativert
|
||||
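As an illustrative aside (not part of this diff): a consumer of GraphSignature can walk the stable-ordered input-to-weight mapping exposed above. The sketch below only prints the pairs; a real executor would instead bind each weight FQN to the corresponding graph input.

#include <iostream>
#include "torch/csrc/nativert/graph/GraphSignature.h"

void dumpWeightBindings(const torch::nativert::GraphSignature& sig) {
  // inputsToWeights() is the ordered union of parameters, buffers and
  // tensor constants built in the constructor above.
  for (const auto& [inputName, weightFqn] : sig.inputsToWeights()) {
    std::cout << inputName << " <- " << weightFqn << "\n";
  }
}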
525
torch/csrc/nativert/graph/Serialization.cpp
Normal file
@ -0,0 +1,525 @@
|
||||
#include "torch/csrc/nativert/graph/Serialization.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <fmt/ostream.h>
|
||||
#include <fmt/ranges.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
|
||||
std::unique_ptr<Graph> jsonToSubgraph(
|
||||
const torch::_export::Graph& thriftGraph,
|
||||
const torch::_export::GraphSignature* signature,
|
||||
bool loadNodeMetadata);
|
||||
|
||||
Value* symbolicToValue(
|
||||
const torch::_export::Argument& arg,
|
||||
Graph& graph,
|
||||
Node* insertBefore) {
|
||||
switch (arg.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR:
|
||||
return graph.getValue(arg.get_as_tensor().get_name());
|
||||
case torch::_export::Argument::Tag::AS_TENSORS: {
|
||||
// Need to insert a list pack node
|
||||
std::vector<Value*> listValue;
|
||||
for (const auto& listEl : arg.get_as_tensors()) {
|
||||
listValue.push_back(graph.getValue(listEl.get_name()));
|
||||
}
|
||||
auto listPack = graph.createListPack(std::move(listValue), Type::Tensor);
|
||||
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
|
||||
// Need to insert a list pack node
|
||||
std::vector<Value*> listValue;
|
||||
for (const auto& listEl : arg.get_as_optional_tensors()) {
|
||||
switch (listEl.tag()) {
|
||||
case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
|
||||
listValue.push_back(
|
||||
graph.getValue(listEl.get_as_tensor().get_name()));
|
||||
break;
|
||||
}
|
||||
case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
|
||||
listValue.push_back(
|
||||
graph.addValue(std::nullopt, Type::None, nullptr));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unknown OptionalTensorArgument type: {}",
|
||||
torch::_export::printEnum(listEl.tag())));
|
||||
}
|
||||
}
|
||||
auto listPack = graph.createOptionalListPack(std::move(listValue));
|
||||
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT: {
|
||||
return graph.getValue(arg.get_as_sym_int().get_as_name());
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INTS: {
|
||||
// Need to insert a list pack node
|
||||
std::vector<Value*> listValue;
|
||||
for (const auto& listEl : arg.get_as_sym_ints()) {
|
||||
switch (listEl.tag()) {
|
||||
case torch::_export::SymIntArgument::Tag::AS_NAME: {
|
||||
listValue.push_back(graph.getValue(listEl.get_as_name()));
|
||||
break;
|
||||
}
|
||||
case torch::_export::SymIntArgument::Tag::AS_INT: {
|
||||
// These are concrete int values in the SymIntList, e.g. [s0, 8].
// We convert them into constant Values in the graph. These values
// don't have a producer node.
|
||||
int value = listEl.get_as_int();
|
||||
Value* symintValue = graph.createConstantSymIntValue(value);
|
||||
listValue.push_back(symintValue);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unknown SymIntArgument type: {}",
|
||||
torch::_export::printEnum(listEl.tag())));
|
||||
}
|
||||
}
|
||||
auto listPack = graph.createListPack(std::move(listValue), Type::SymInt);
|
||||
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
|
||||
return graph.getValue(arg.get_as_custom_obj().get_name());
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
|
||||
return graph.getValue(arg.get_as_sym_bool().get_as_name());
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
|
||||
return graph.getValue(arg.get_as_sym_float().get_as_name());
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"This function should only be called for symbolics, got \'{}\' instead",
|
||||
torch::_export::printEnum(arg.tag())));
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<
|
||||
std::vector<torch::_export::InputSpec>,
|
||||
std::vector<torch::_export::Argument>>
|
||||
enforceInputOrder(
|
||||
const std::vector<torch::_export::InputSpec>& inputSpecs,
|
||||
const std::vector<torch::_export::Argument>& graphInputs) {
|
||||
// Enforce the order of inputSpecs and graphInputs to be the following:
|
||||
// 1. token
|
||||
// 2. parameter
|
||||
// 3. persistent buffer, non-persistent buffer
|
||||
// 4. tensor_constant
|
||||
// 5. custom_obj
|
||||
// 6. user_input/constant_input
|
||||
std::vector<torch::_export::InputSpec> reorderedInputSpecs;
|
||||
std::vector<torch::_export::Argument> reorderedGraphInputs;
|
||||
std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
|
||||
torch::_export::InputSpec::Tag::TOKEN,
|
||||
torch::_export::InputSpec::Tag::PARAMETER,
|
||||
torch::_export::InputSpec::Tag::BUFFER,
|
||||
torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
|
||||
torch::_export::InputSpec::Tag::CUSTOM_OBJ};
|
||||
|
||||
auto reorder = [&](auto condition) {
|
||||
for (size_t i = 0; i < inputSpecs.size(); ++i) {
|
||||
if (condition(inputSpecs[i])) {
|
||||
reorderedInputSpecs.push_back(inputSpecs[i]);
|
||||
reorderedGraphInputs.push_back(graphInputs[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (const auto& tag : desiredOrder) {
|
||||
if (tag == torch::_export::InputSpec::Tag::BUFFER) {
|
||||
// Add persistent buffers first, then non-persistent
|
||||
reorder([&](const auto& spec) {
|
||||
return spec.tag() == tag && spec.get_buffer().get_persistent();
|
||||
});
|
||||
reorder([&](const auto& spec) {
|
||||
return spec.tag() == tag && !spec.get_buffer().get_persistent();
|
||||
});
|
||||
} else {
|
||||
reorder([&](const auto& spec) { return spec.tag() == tag; });
|
||||
}
|
||||
}
|
||||
|
||||
// Append USER_INPUT and CONSTANT_INPUT without reordering
|
||||
for (size_t i = 0; i < inputSpecs.size(); ++i) {
|
||||
auto tag = inputSpecs[i].tag();
|
||||
if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
|
||||
tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
|
||||
reorderedInputSpecs.push_back(inputSpecs[i]);
|
||||
reorderedGraphInputs.push_back(graphInputs[i]);
|
||||
}
|
||||
}
|
||||
return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> jsonToSubgraph(
|
||||
const torch::_export::Graph& jsonGraph,
|
||||
const torch::_export::GraphSignature* signature,
|
||||
bool loadNodeMetadata) {
|
||||
auto graphInputs = jsonGraph.get_inputs();
|
||||
auto graph = Graph::createGraph();
|
||||
|
||||
if (signature) {
|
||||
// Enforce the ordering of the signature input specs and graph inputs.
|
||||
auto inputSpecs = signature->get_input_specs();
|
||||
|
||||
auto [reorderedInputSpecs, reorderedGraphInputs] =
|
||||
enforceInputOrder(inputSpecs, graphInputs);
|
||||
|
||||
graphInputs = std::move(reorderedGraphInputs);
|
||||
auto reorderedSignature = *signature;
|
||||
reorderedSignature.set_input_specs(reorderedInputSpecs);
|
||||
graph->setSignature(GraphSignature{reorderedSignature});
|
||||
}
|
||||
|
||||
for (const auto& input : graphInputs) {
|
||||
if (isSymbolic(input)) {
|
||||
switch (input.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR: {
|
||||
const auto& asTensor = input.get_as_tensor();
|
||||
const auto& name = asTensor.get_name();
|
||||
graph->addInput(name, Type::Tensor);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
|
||||
const auto& asCustomObj = input.get_as_custom_obj();
|
||||
const std::string& name = asCustomObj.get_name();
|
||||
const std::string& classFqn = asCustomObj.get_class_fqn();
|
||||
graph->addInput(name, Type(Type::CustomObj, classFqn));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported graph input type {}",
|
||||
static_cast<int>(input.tag())));
|
||||
}
|
||||
} else {
|
||||
switch (input.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_INT:
|
||||
case torch::_export::Argument::Tag::AS_FLOAT:
|
||||
case torch::_export::Argument::Tag::AS_STRING:
|
||||
case torch::_export::Argument::Tag::AS_BOOL:
|
||||
case torch::_export::Argument::Tag::AS_NONE: {
|
||||
// Constant graph inputs are specialized in the graph; here we simply
// add a nullptr Value to the graph input node.
|
||||
graph->addInput();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported graph input type {}",
|
||||
static_cast<int>(input.tag())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& jsonNode : jsonGraph.get_nodes()) {
|
||||
auto node = graph->insertNode(
|
||||
jsonNode.get_target(),
|
||||
{},
|
||||
loadNodeMetadata ? jsonNode.get_metadata()
|
||||
: std::unordered_map<std::string, std::string>());
|
||||
|
||||
std::vector<NamedArgument> args;
|
||||
std::vector<Attribute> attributes;
|
||||
for (const auto& input : jsonNode.get_inputs()) {
|
||||
// We handle constants and symbolic inputs differently.
|
||||
const auto& arg = input.get_arg();
|
||||
if (isSymbolic(arg)) {
|
||||
// Symbolic values are made part of the inputs to the node
|
||||
node->addInput(NamedArgument{
|
||||
input.get_name(), symbolicToValue(input.get_arg(), *graph, node)});
|
||||
} else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
|
||||
node->addInput(NamedArgument{
|
||||
input.get_name(), graph->addValue(std::nullopt, Type::None, node)});
|
||||
} else {
|
||||
node->addAttribute(Attribute{
|
||||
input.get_name(),
|
||||
constantToValue(input.get_arg(), loadNodeMetadata)});
|
||||
// Constant values are added as "attributes" to the node.
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Value*> outputs;
|
||||
std::vector<Value*> listUnpacksToCreate;
|
||||
for (const auto& output : jsonNode.get_outputs()) {
|
||||
switch (output.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_NONE: {
|
||||
node->addOutput(Type::None);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_TENSOR: {
|
||||
const auto name = output.get_as_tensor().get_name();
|
||||
node->addOutput(name, Type::Tensor);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_TENSORS: {
|
||||
auto outputValue =
|
||||
node->addOutput(graph->getUniqueValueName(), Type::TensorList);
|
||||
|
||||
Node* listUnpack =
|
||||
graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
|
||||
for (const auto& arg : output.get_as_tensors()) {
|
||||
listUnpack->addOutput(arg.get_name(), Type::Tensor);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT: {
|
||||
const auto name = output.get_as_sym_int().get_as_name();
|
||||
node->addOutput(name, Type::SymInt);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INTS: {
|
||||
throw std::runtime_error(
|
||||
"SymInts NYI. We currently doesn't have op that produces SymInts as output");
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
|
||||
const auto name = output.get_as_sym_bool().get_as_name();
|
||||
node->addOutput(name, Type::SymBool);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
|
||||
throw std::runtime_error(
|
||||
"SymBools NYI. We currently doesn't have op that produces SymBools as output");
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
|
||||
const auto name = output.get_as_sym_float().get_as_name();
|
||||
node->addOutput(name, Type::SymFloat);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
|
||||
throw std::runtime_error(
|
||||
"SymFloats NYI. We currently doesn't have op that produces SymFloats as output");
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"disallowed output type {}",
|
||||
torch::_export::printEnum(output.tag())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& output : jsonGraph.get_outputs()) {
|
||||
// handle symbolic outputs and constant outputs differently
|
||||
if (isSymbolic(output)) {
|
||||
switch (output.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR: {
|
||||
const auto& asTensor = output.get_as_tensor();
|
||||
const auto& name = asTensor.get_name();
|
||||
Value* outputValue = graph->getValue(name);
|
||||
graph->addOutput(outputValue);
|
||||
break;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT: {
|
||||
const auto& asSymInt = output.get_as_sym_int();
|
||||
TORCH_CHECK(
|
||||
asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
|
||||
const auto& name = asSymInt.get_as_name();
|
||||
Value* outputValue = graph->getValue(name);
|
||||
graph->addOutput(outputValue);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported graph output type: {}",
|
||||
static_cast<size_t>(output.tag())));
|
||||
}
|
||||
} else {
|
||||
Constant constValue = constantToValue(output, loadNodeMetadata);
|
||||
graph->addConstantOutput(constValue);
|
||||
}
|
||||
}
|
||||
|
||||
auto jsonTensorValue = jsonGraph.get_tensor_values();
|
||||
|
||||
if (!signature) {
|
||||
// For subgraphs we just need to derive a graph signature that only
|
||||
// contains user inputs and outputs, because we don't need to handle any
|
||||
// special semantics for them, e.g. mutation or gradients.
|
||||
torch::_export::GraphSignature sig;
|
||||
std::vector<torch::_export::InputSpec> inputSpecs;
|
||||
for (const auto& input : graph->inputs()) {
|
||||
torch::_export::Argument arg;
|
||||
if (input->type().kind() == Type::Tensor) {
|
||||
torch::_export::TensorArgument targ;
|
||||
targ.set_name(std::string{input->name()});
|
||||
arg.set_as_tensor(std::move(targ));
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported subgraph input type {}",
|
||||
fmt::streamed(input->type())));
|
||||
}
|
||||
torch::_export::UserInputSpec userInputSpec;
|
||||
userInputSpec.set_arg(std::move(arg));
|
||||
torch::_export::InputSpec inputSpec;
|
||||
inputSpec.set_user_input(std::move(userInputSpec));
|
||||
inputSpecs.push_back(std::move(inputSpec));
|
||||
}
|
||||
sig.set_input_specs(std::move(inputSpecs));
|
||||
|
||||
std::vector<torch::_export::OutputSpec> outputSpecs;
|
||||
for (const auto& output : graph->outputs()) {
|
||||
torch::_export::Argument arg;
|
||||
if (output->type().kind() == Type::Tensor) {
|
||||
torch::_export::TensorArgument targ;
|
||||
targ.set_name(std::string{output->name()});
|
||||
arg.set_as_tensor(std::move(targ));
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported subgraph input type {}",
|
||||
fmt::streamed(output->type())));
|
||||
}
|
||||
torch::_export::UserOutputSpec userOutputSpec;
|
||||
userOutputSpec.set_arg(std::move(arg));
|
||||
torch::_export::OutputSpec outputSpec;
|
||||
outputSpec.set_user_output(std::move(userOutputSpec));
|
||||
outputSpecs.push_back(std::move(outputSpec));
|
||||
}
|
||||
sig.set_output_specs(std::move(outputSpecs));
|
||||
|
||||
graph->setSignature(GraphSignature{sig});
|
||||
}
|
||||
|
||||
// weightsTensorMeta is indexed by the weight's name, not the graph input's name
|
||||
std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
|
||||
for (const auto& [inputName, weightName] :
|
||||
graph->signature().inputsToWeights()) {
|
||||
auto value = graph->getValue(inputName);
|
||||
if (value->type().kind() == Type::CustomObj) {
|
||||
// skip setting meta for non-tensor inputs
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = jsonTensorValue.find(inputName);
|
||||
CHECK(it != jsonTensorValue.end())
|
||||
<< "Missing tensor metadata for " << inputName
|
||||
<< "in thriftGraph.tensorValue";
|
||||
weightsTensorMeta[weightName] = it->second;
|
||||
}
|
||||
graph->setWeightsMeta(weightsTensorMeta);
|
||||
|
||||
graph->setTensorValuesMeta(jsonTensorValue);
|
||||
|
||||
graph->finalize();
|
||||
|
||||
graph->lint();
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isSymbolic(const torch::_export::Argument& arg) {
|
||||
switch (arg.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR:
|
||||
case torch::_export::Argument::Tag::AS_TENSORS:
|
||||
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT:
|
||||
case torch::_export::Argument::Tag::AS_SYM_INTS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOL:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
|
||||
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Constant constantToValue(
|
||||
const torch::_export::Argument& jsonArg,
|
||||
bool loadNodeMetadata) {
|
||||
switch (jsonArg.tag()) {
|
||||
case torch::_export::Argument::Tag::AS_TENSOR:
|
||||
throw std::runtime_error("Tensor is symbolic, shouldn't reach here");
|
||||
case torch::_export::Argument::Tag::AS_TENSORS:
|
||||
throw std::runtime_error("Tensor[] is symbolic, shouldn't reach here");
|
||||
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
|
||||
throw std::runtime_error("Tensor?[] is symbolic, shouldn't reach here");
|
||||
case torch::_export::Argument::Tag::AS_NONE:
|
||||
return None();
|
||||
case torch::_export::Argument::Tag::AS_INT:
|
||||
return jsonArg.get_as_int();
|
||||
case torch::_export::Argument::Tag::AS_INTS: {
|
||||
std::vector<int64_t> ret;
|
||||
for (const auto& arg : jsonArg.get_as_ints()) {
|
||||
ret.push_back(arg);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_FLOAT:
|
||||
return jsonArg.get_as_float().get();
|
||||
case torch::_export::Argument::Tag::AS_FLOATS: {
|
||||
std::vector<double> ret;
|
||||
for (const auto& arg : jsonArg.get_as_floats()) {
|
||||
ret.push_back(arg.get());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_STRING:
|
||||
return jsonArg.get_as_string();
|
||||
case torch::_export::Argument::Tag::AS_STRINGS: {
|
||||
std::vector<std::string> ret;
|
||||
for (const auto& arg : jsonArg.get_as_strings()) {
|
||||
ret.push_back(arg);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_INT:
|
||||
case torch::_export::Argument::Tag::AS_SYM_INTS:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOL:
|
||||
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
|
||||
throw std::runtime_error(
|
||||
"Symint/Symbool Values are symbolic, shouldn't reach here");
|
||||
|
||||
case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
|
||||
return convertJsonScalarType(jsonArg.get_as_scalar_type());
|
||||
case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
|
||||
return convertJsonMemoryFormat(jsonArg.get_as_memory_format());
|
||||
case torch::_export::Argument::Tag::AS_LAYOUT:
|
||||
return convertJsonLayout(jsonArg.get_as_layout());
|
||||
case torch::_export::Argument::Tag::AS_DEVICE:
|
||||
return convertJsonDevice(jsonArg.get_as_device());
|
||||
case torch::_export::Argument::Tag::AS_BOOL:
|
||||
return jsonArg.get_as_bool();
|
||||
case torch::_export::Argument::Tag::AS_BOOLS: {
|
||||
std::vector<bool> ret;
|
||||
for (const auto& arg : jsonArg.get_as_bools()) {
|
||||
ret.push_back(arg);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_GRAPH: {
|
||||
return jsonToSubgraph(
|
||||
*jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
|
||||
throw std::runtime_error("custom obj is dynamic, shouldn't reach here");
|
||||
case torch::_export::Argument::Tag::AS_OPERATOR:
|
||||
return jsonArg.get_as_operator();
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
|
||||
throw std::runtime_error("sym float is not yet implemented");
|
||||
}
|
||||
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
|
||||
throw std::runtime_error("sym floats is not yet implemented");
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Got unknown thrift argument");
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> jsonToGraph(
|
||||
const torch::_export::GraphModule& jsonGraphModule,
|
||||
bool loadNodeMetadata) {
|
||||
auto graph = jsonToSubgraph(
|
||||
jsonGraphModule.get_graph(),
|
||||
&jsonGraphModule.get_signature(),
|
||||
loadNodeMetadata);
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
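As an illustrative aside (not part of this diff): the sketch below exercises isSymbolic and constantToValue on a plain integer argument. The set_as_int setter is assumed to exist by analogy with the set_as_tensor/set_name setters used in this file; treat it as hypothetical.

#include <iostream>
#include <variant>
#include "torch/csrc/nativert/graph/Serialization.h"

void constantExample() {
  torch::_export::Argument arg;
  arg.set_as_int(42); // hypothetical setter, assumed by analogy

  // Plain ints are constants, not symbolic values, so they end up as node
  // attributes rather than graph inputs.
  std::cout << std::boolalpha << torch::nativert::isSymbolic(arg) << "\n"; // false
  auto c = torch::nativert::constantToValue(arg, /*loadNodeMetadata=*/false);
  std::cout << std::get<int64_t>(c) << "\n"; // 42
}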
34
torch/csrc/nativert/graph/Serialization.h
Normal file
@ -0,0 +1,34 @@
#pragma once

#include "torch/csrc/nativert/graph/Graph.h"

#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu

namespace torch::nativert {
/**
 * This file contains serialization utilities for Graph.
 *
 * There are two serialized representations we care about:
 * - Json: stable but hard to work with, not really human readable
 * - Debug format: human-readable, not stable.
 *
 * All formats should be logically equivalent, so we should be able to go from
 * in-memory graph <> json <> debug format interchangeably.
 */

// Json -> Graph
std::unique_ptr<Graph> jsonToGraph(
    const torch::_export::GraphModule& jsonGraph,
    bool loadNodeMetadata = true);

// Graph -> Json
std::pair<torch::_export::Graph, torch::_export::GraphSignature> graphToJson(
    const Graph& graph);

bool isSymbolic(const torch::_export::Argument& arg);

Constant constantToValue(
    const torch::_export::Argument& jsonArg,
    bool loadNodeMetadata);

} // namespace torch::nativert
136
torch/csrc/nativert/graph/TensorMeta.cpp
Normal file
@ -0,0 +1,136 @@
|
||||
#include "torch/csrc/nativert/graph/TensorMeta.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
c10::ScalarType convertJsonScalarType(
|
||||
const torch::_export::ScalarType& scalarType) {
|
||||
switch (scalarType) {
|
||||
case torch::_export::ScalarType::UNKNOWN:
|
||||
CHECK(false) << "scalar type is not properly set";
|
||||
case torch::_export::ScalarType::BYTE:
|
||||
return c10::ScalarType::Byte;
|
||||
case torch::_export::ScalarType::CHAR:
|
||||
return c10::ScalarType::Char;
|
||||
case torch::_export::ScalarType::SHORT:
|
||||
return c10::ScalarType::Short;
|
||||
case torch::_export::ScalarType::INT:
|
||||
return c10::ScalarType::Int;
|
||||
case torch::_export::ScalarType::LONG:
|
||||
return c10::ScalarType::Long;
|
||||
case torch::_export::ScalarType::HALF:
|
||||
return c10::ScalarType::Half;
|
||||
case torch::_export::ScalarType::FLOAT:
|
||||
return c10::ScalarType::Float;
|
||||
case torch::_export::ScalarType::DOUBLE:
|
||||
return c10::ScalarType::Double;
|
||||
case torch::_export::ScalarType::COMPLEXHALF:
|
||||
return c10::ScalarType::ComplexHalf;
|
||||
case torch::_export::ScalarType::COMPLEXFLOAT:
|
||||
return c10::ScalarType::ComplexFloat;
|
||||
case torch::_export::ScalarType::COMPLEXDOUBLE:
|
||||
return c10::ScalarType::ComplexDouble;
|
||||
case torch::_export::ScalarType::BOOL:
|
||||
return c10::ScalarType::Bool;
|
||||
case torch::_export::ScalarType::BFLOAT16:
|
||||
return c10::ScalarType::BFloat16;
|
||||
default:
|
||||
TORCH_CHECK(false, "unknown scalar type", static_cast<int>(scalarType));
|
||||
}
|
||||
}
|
||||
|
||||
c10::MemoryFormat convertJsonMemoryFormat(
|
||||
const torch::_export::MemoryFormat& memoryFormat) {
|
||||
switch (memoryFormat) {
|
||||
case torch::_export::MemoryFormat::Unknown:
|
||||
TORCH_CHECK(false, "got unknown scalar type");
|
||||
case torch::_export::MemoryFormat::ContiguousFormat:
|
||||
return c10::MemoryFormat::Contiguous;
|
||||
case torch::_export::MemoryFormat::ChannelsLast:
|
||||
return c10::MemoryFormat::ChannelsLast;
|
||||
case torch::_export::MemoryFormat::ChannelsLast3d:
|
||||
return c10::MemoryFormat::ChannelsLast3d;
|
||||
case torch::_export::MemoryFormat::PreserveFormat:
|
||||
return c10::MemoryFormat::Preserve;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "unknown memory format", static_cast<int>(memoryFormat));
|
||||
}
|
||||
}
|
||||
|
||||
c10::Layout convertJsonLayout(const torch::_export::Layout& layout) {
|
||||
switch (layout) {
|
||||
case torch::_export::Layout::Unknown:
|
||||
TORCH_CHECK(false, "got unknown layout");
|
||||
case torch::_export::Layout::SparseCoo:
|
||||
// TODO is this the right translation
|
||||
return c10::Layout::Sparse;
|
||||
case torch::_export::Layout::SparseCsr:
|
||||
return c10::Layout::SparseCsr;
|
||||
case torch::_export::Layout::SparseCsc:
|
||||
return c10::Layout::SparseCsc;
|
||||
case torch::_export::Layout::SparseBsr:
|
||||
return c10::Layout::SparseBsr;
|
||||
case torch::_export::Layout::SparseBsc:
|
||||
return c10::Layout::SparseBsc;
|
||||
case torch::_export::Layout::_mkldnn:
|
||||
return c10::Layout::Mkldnn;
|
||||
case torch::_export::Layout::Strided:
|
||||
return c10::Layout::Strided;
|
||||
default:
|
||||
TORCH_CHECK(false, "unknown layout", static_cast<int>(layout));
|
||||
}
|
||||
}
|
||||
|
||||
c10::Device convertJsonDevice(const torch::_export::Device& device) {
|
||||
c10::Device d(device.get_type());
|
||||
if (auto index = device.get_index()) {
|
||||
d.set_index(*index);
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta)
|
||||
: device_(convertJsonDevice(tensorMeta.get_device())) {
|
||||
dtype_ = convertJsonScalarType(tensorMeta.get_dtype());
|
||||
layout_ = convertJsonLayout(tensorMeta.get_layout());
|
||||
requiresGrad_ = tensorMeta.get_requires_grad();
|
||||
|
||||
if (tensorMeta.get_storage_offset().tag() ==
|
||||
torch::_export::SymInt::Tag::AS_INT) {
|
||||
storage_offset_ = tensorMeta.get_storage_offset().get_as_int();
|
||||
} else {
|
||||
CHECK(false) << "SymInt not supported yet";
|
||||
}
|
||||
|
||||
for (const auto& size : tensorMeta.get_sizes()) {
|
||||
if (size.tag() == torch::_export::SymInt::Tag::AS_INT) {
|
||||
int64_t val = size.get_as_int();
|
||||
sizes_.emplace_back(val);
|
||||
numel_ *= val;
|
||||
} else if (size.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
|
||||
// TODO: it's still unclear how SymInt shapes should be used at runtime.
// One potential use case is verifying that input shapes match the
// constraints. This would require unpacking the serialized constraints,
// which is NYI.
//
// For the time being, we just set the symbolic dim to -1.
|
||||
hasSymbolicShape_ = true;
|
||||
sizes_.emplace_back(-1);
|
||||
numel_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& stride : tensorMeta.get_strides()) {
|
||||
if (stride.tag() == torch::_export::SymInt::Tag::AS_INT) {
|
||||
strides_.emplace_back(stride.get_as_int());
|
||||
} else if (stride.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
|
||||
// TODO: it's still unclear how SymInt shapes should be used at runtime.
// Setting the symbolic stride to -1 for now.
|
||||
hasSymbolicShape_ = true;
|
||||
strides_.emplace_back(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
92
torch/csrc/nativert/graph/TensorMeta.h
Normal file
@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include "c10/core/Layout.h"
|
||||
|
||||
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
c10::ScalarType convertJsonScalarType(
|
||||
const torch::_export::ScalarType& scalarType);
|
||||
c10::MemoryFormat convertJsonMemoryFormat(
|
||||
const torch::_export::MemoryFormat& memoryFormat);
|
||||
c10::Layout convertJsonLayout(const torch::_export::Layout& layout);
|
||||
c10::Device convertJsonDevice(const torch::_export::Device& device);
|
||||
|
||||
class TensorMeta {
|
||||
public:
|
||||
explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta);
|
||||
|
||||
c10::IntArrayRef sizes() const {
|
||||
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
|
||||
return sizes_;
|
||||
}
|
||||
|
||||
c10::IntArrayRef strides() const {
|
||||
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
|
||||
return strides_;
|
||||
}
|
||||
|
||||
c10::Layout layout() const {
|
||||
return layout_;
|
||||
}
|
||||
|
||||
c10::ScalarType dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
bool requires_grad() const {
|
||||
return requiresGrad_;
|
||||
}
|
||||
|
||||
int64_t storage_offset() const {
|
||||
return storage_offset_;
|
||||
}
|
||||
|
||||
int64_t dim() const {
|
||||
return sizes_.size();
|
||||
}
|
||||
|
||||
int64_t numel() const {
|
||||
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
|
||||
return numel_;
|
||||
}
|
||||
|
||||
c10::Device device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
c10::TensorOptions asTensorOptions() const {
|
||||
return c10::TensorOptions().dtype(dtype_).layout(layout_).requires_grad(
|
||||
requiresGrad_);
|
||||
}
|
||||
|
||||
// NYI
|
||||
// c10::SymIntArrayRef sym_sizes() const {}
|
||||
// c10::SymIntArrayRef sym_strides() const {}
|
||||
// c10::SymInt sym_storage_offset() const {}
|
||||
// c10::SymInt sym_numel() const {}
|
||||
|
||||
private:
|
||||
bool hasSymbolicShape_ = false;
|
||||
|
||||
std::vector<int64_t> sizes_;
|
||||
std::vector<int64_t> strides_;
|
||||
int64_t storage_offset_ = 0;
|
||||
int64_t numel_ = 1;
|
||||
|
||||
c10::ScalarType dtype_;
|
||||
c10::Layout layout_;
|
||||
bool requiresGrad_;
|
||||
|
||||
c10::Device device_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
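A minimal usage sketch (not part of this diff) for the TensorMeta class above: it assumes a deserialized torch::_export::TensorMeta named `serialized` is already in hand, and it only covers the static-shape path, since sizes() and numel() CHECK-fail when the shape is symbolic.

// Hypothetical consumer code; `serialized` is an assumed torch::_export::TensorMeta.
torch::nativert::TensorMeta meta(serialized);
at::Tensor buffer = at::empty(
    meta.sizes(), // would trip the CHECK above if the shape were symbolic
    meta.asTensorOptions().device(meta.device()));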
50
torch/csrc/nativert/kernels/AOTICallDelegateKernel.cpp
Normal file
@ -0,0 +1,50 @@
|
||||
#include "torch/csrc/nativert/kernels/AOTICallDelegateKernel.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
AOTICallDelegateKernel::AOTICallDelegateKernel(
|
||||
const Node* node,
|
||||
AOTIDelegateExecutor& delegateExecutor)
|
||||
: OpKernel(node), delegateExecutor_(delegateExecutor) {
|
||||
// torch.higher_order_ops.aoti_call_delegate(lowered_module, original_gm,
|
||||
// weight_args, input_args). However, the second
|
||||
// input is currently serialized as None, so we only have 3 inputs and 1
|
||||
// output.
|
||||
// TODO(T213681594): Fix this, None input should still be included in
|
||||
// numInputs().
|
||||
TORCH_CHECK_EQ(node->numInputs(), 3);
|
||||
TORCH_CHECK_EQ(node->numOutputs(), 1);
|
||||
|
||||
// Weights are in node->inputs()[1], but they are not used in the forward call.
// Instead, weights are bound to AOTI via loadWeights().
const Value* input = node->inputs()[2].value;
|
||||
const Value* output = node->outputs()[0];
|
||||
|
||||
CHECK(input->type() == Type::TensorList)
|
||||
<< "torch.higher_order_ops.aoti_call_delegate input should be a TensorList, but got "
|
||||
<< input->type();
|
||||
CHECK(output->type() == Type::TensorList)
|
||||
<< "torch.higher_order_ops.aoti_call_delegate output should be a TensorList, but got "
|
||||
<< output->type();
|
||||
|
||||
inputValueId_ = input->id();
|
||||
outputValueId_ = output->id();
|
||||
}
|
||||
|
||||
void AOTICallDelegateKernel::computeInternal(
|
||||
ExecutionFrame& executionFrame) const {
|
||||
std::vector<at::Tensor> inputs =
|
||||
executionFrame.getTensorVector(inputValueId_);
|
||||
|
||||
auto outputs = delegateExecutor_.run(inputs);
|
||||
|
||||
executionFrame.setIValue(outputValueId_, std::move(outputs));
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
27
torch/csrc/nativert/kernels/AOTICallDelegateKernel.h
Normal file
@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class AOTIDelegateExecutor;
|
||||
|
||||
// Kernel for torch.higher_order_ops.aoti_call_delegate
|
||||
class AOTICallDelegateKernel : public OpKernel {
|
||||
public:
|
||||
explicit AOTICallDelegateKernel(
|
||||
const Node* node,
|
||||
AOTIDelegateExecutor& delegateExecutor);
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
AOTIDelegateExecutor& delegateExecutor_;
|
||||
|
||||
ValueId inputValueId_;
|
||||
ValueId outputValueId_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
41
torch/csrc/nativert/kernels/AOTIKernel.cpp
Normal file
@ -0,0 +1,41 @@
|
||||
#include "torch/csrc/nativert/kernels/AOTIKernel.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
AOTIKernel::AOTIKernel(const Node* node, AOTIDelegateExecutor& delegateExecutor)
|
||||
: OpKernel(node), delegateExecutor_(delegateExecutor) {
|
||||
// The schema is "call_aotinductor(str path, Tensor[] weights, Tensor[]
|
||||
// inputs) -> Tensor[] outputs", expects 2 inputs and 1 output
|
||||
TORCH_CHECK_EQ(node->numInputs(), 2);
|
||||
TORCH_CHECK_EQ(node->numOutputs(), 1);
|
||||
|
||||
// Weights are in node->inputs()[0], but they are not used in the forward call.
// Instead, weights are bound to AOTI via loadWeights().
const Value* input = node->inputs()[1].value;
|
||||
const Value* output = node->outputs()[0];
|
||||
|
||||
CHECK(input->type() == Type::TensorList)
|
||||
<< "delegate.call_aotinductor input should be a TensorList";
|
||||
CHECK(output->type() == Type::TensorList)
|
||||
<< "delegate.call_aotinductor output should be a TensorList";
|
||||
|
||||
inputValueId_ = input->id();
|
||||
outputValueId_ = output->id();
|
||||
}
|
||||
|
||||
void AOTIKernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
std::vector<at::Tensor> inputs =
|
||||
executionFrame.getTensorVector(inputValueId_);
|
||||
|
||||
auto outputs = delegateExecutor_.run(inputs);
|
||||
|
||||
executionFrame.setIValue(outputValueId_, std::move(outputs));
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
26
torch/csrc/nativert/kernels/AOTIKernel.h
Normal file
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class AOTIDelegateExecutor;
|
||||
|
||||
// Kernel for torch.ops.delegate.call_aotinductor
|
||||
// TODO: Deprecate this when we move to aoti_call_delegate HOP
|
||||
class AOTIKernel : public OpKernel {
|
||||
public:
|
||||
explicit AOTIKernel(const Node* node, AOTIDelegateExecutor& delegateExecutor);
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
AOTIDelegateExecutor& delegateExecutor_;
|
||||
|
||||
ValueId inputValueId_;
|
||||
ValueId outputValueId_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
64
torch/csrc/nativert/kernels/AutoFunctionalizeKernel.cpp
Normal file
@ -0,0 +1,64 @@
|
||||
#include "torch/csrc/nativert/kernels/AutoFunctionalizeKernel.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
|
||||
: OpKernel(node),
|
||||
op_(getOperatorForTarget(
|
||||
std::get<std::string>(node->attributes()[0].value))),
|
||||
schema_(op_.schema()),
|
||||
arguments_(prefillStackWithStaticArgs(node, schema_)) {
|
||||
for (const auto& [idx, schemaArg] : enumerate(schema_.arguments())) {
|
||||
if (schemaArg.alias_info() != nullptr &&
|
||||
schemaArg.alias_info()->isWrite()) {
|
||||
mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
|
||||
}
|
||||
}
|
||||
|
||||
numOutputs_ = schema_.returns().size();
|
||||
}
|
||||
|
||||
void UnsafeAutoFunctionalizeKernel::computeInternal(
|
||||
ExecutionFrame& executionFrame) const {
|
||||
// Make a copy of the stack
|
||||
std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
|
||||
|
||||
fillDynamicInputs(executionFrame, arguments_, stack);
|
||||
|
||||
// Call the op with the prepared stack.
|
||||
try {
|
||||
op_.callBoxed(stack);
|
||||
} catch (const std::exception& ex) {
|
||||
// TODO: this eats the original exception type. ATen returns different
|
||||
// exception types that correspond to different Python errors (e.g.
|
||||
// IndexError, ValueError). If retaining this information is important
|
||||
// to us, we'll have to change this up a little.
|
||||
auto stackTrace = node_->getMetadata("stack_trace");
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Original Python stacktrace:\n{}\n{}",
|
||||
stackTrace ? *stackTrace : "<no stack trace>",
|
||||
ex.what()));
|
||||
}
|
||||
|
||||
const auto& outputValues = node_->outputs();
|
||||
|
||||
for (int i = 0; i < numOutputs_; ++i) {
|
||||
executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
|
||||
}
|
||||
|
||||
// Copy over mutating inputs to outputs
|
||||
int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_;
|
||||
for (int i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
|
||||
executionFrame.setIValue(
|
||||
outputValues[i]->id(),
|
||||
executionFrame.getIValue(
|
||||
mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
|
||||
true /* allowNone */));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
27
torch/csrc/nativert/kernels/AutoFunctionalizeKernel.h
Normal file
@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
#include "c10/core/Device.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class UnsafeAutoFunctionalizeKernel : public OpKernel {
|
||||
public:
|
||||
UnsafeAutoFunctionalizeKernel() = delete; // deleted default constructor
|
||||
UnsafeAutoFunctionalizeKernel(const Node* node);
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
c10::OperatorHandle op_;
|
||||
c10::FunctionSchema schema_;
|
||||
|
||||
Arguments arguments_;
|
||||
|
||||
std::vector<Value*> mutatingInputArgs_;
|
||||
int numOutputs_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
242
torch/csrc/nativert/kernels/C10Kernel.cpp
Normal file
@ -0,0 +1,242 @@
|
||||
#include "torch/csrc/nativert/kernels/C10Kernel.h"
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/ostream.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
|
||||
#ifdef __SIGRID_USE_GPU__
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#endif
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
C10Kernel::C10Kernel(const Node* node, c10::Device device)
|
||||
: OpKernel(node, device),
|
||||
op_(getOperatorForTarget(node->target(), node)),
|
||||
schema_(op_.schema()),
|
||||
arguments_(prefillStackWithStaticArgs(node, schema_)) {}
|
||||
|
||||
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
// Make a copy of the stack
|
||||
std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
|
||||
|
||||
fillDynamicInputs(executionFrame, arguments_, stack);
|
||||
|
||||
// Call the op with the prepared stack.
|
||||
try {
|
||||
op_.callBoxed(stack);
|
||||
} catch (const std::exception& ex) {
|
||||
auto stackTrace = node_->getMetadata("stack_trace");
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Exception while executing node: {}\n"
|
||||
"with args:\n{}\n"
|
||||
"{}\n"
|
||||
"Original Python stacktrace:\n{}",
|
||||
fmt::streamed(*node_),
|
||||
readableArgs(schema_, stack),
|
||||
ex.what(),
|
||||
stackTrace ? *stackTrace : "<no stack trace>"));
|
||||
}
|
||||
|
||||
// Write out results
|
||||
// TODO: we store intermediates in a single table (symint and tensor alike).
|
||||
// This can theoretically lead to name collisions, although based on how
|
||||
// these are named I don't think it will ever happen in practice. We need to
|
||||
// enforce it though.
|
||||
const auto& outputValues = node_->outputs();
|
||||
TORCH_CHECK_EQ(outputValues.size(), stack.size())
|
||||
<< "Output size mismatch for " << node_->toString();
|
||||
for (auto&& [i, actualOutput] : enumerate(stack)) {
|
||||
executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::unordered_map<std::string, c10::IValue> getSymInputs(
|
||||
const ExecutionFrame& executionFrame,
|
||||
const Node& node) {
|
||||
std::unordered_map<std::string, c10::IValue> inputs;
|
||||
for (const auto& input : node.inputs()) {
|
||||
const auto& val = executionFrame.getIValue(input.value->id());
|
||||
if (val.isInt() || val.isDouble() || val.isBool()) {
|
||||
inputs[input.name] = val;
|
||||
} else {
|
||||
throw std::runtime_error("unsupported type for symbolic input");
|
||||
}
|
||||
}
|
||||
for (const auto& attribute : node.attributes()) {
|
||||
if (std::holds_alternative<int64_t>(attribute.value)) {
|
||||
inputs[attribute.name] = std::get<int64_t>(attribute.value);
|
||||
} else if (std::holds_alternative<double>(attribute.value)) {
|
||||
inputs[attribute.name] = std::get<double>(attribute.value);
|
||||
} else if (std::holds_alternative<bool>(attribute.value)) {
|
||||
inputs[attribute.name] = std::get<bool>(attribute.value);
|
||||
} else {
|
||||
throw std::runtime_error("unsupported type for symbolic input");
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void computeScalarBinaryOp(
|
||||
ExecutionFrame& executionFrame,
|
||||
const Node& node,
|
||||
std::enable_if_t<true, T> a,
|
||||
std::enable_if_t<true, T> b) {
|
||||
std::string_view target = node.target();
|
||||
T out;
|
||||
|
||||
if (target == "_operator.add") {
|
||||
out = a + b;
|
||||
} else if (target == "_operator.sub") {
|
||||
out = a - b;
|
||||
} else if (target == "_operator.mul") {
|
||||
out = a * b;
|
||||
} else if (target == "_operator.pow") {
|
||||
out = std::pow(a, b);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
fmt::format("unsupported operator for symbolic values: {}", target));
|
||||
}
|
||||
|
||||
executionFrame.setIValue(node.outputs()[0]->id(), out);
|
||||
VLOG(2) << fmt::format(
|
||||
"Completed executing node: {} with a={}, b={}, out={}",
|
||||
fmt::streamed(node),
|
||||
a,
|
||||
b,
|
||||
out);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ScalarBinaryOpKernel::computeInternal(
|
||||
ExecutionFrame& executionFrame) const {
|
||||
auto inputs = getSymInputs(executionFrame, *node_);
|
||||
|
||||
const auto& a = inputs.at("a");
|
||||
const auto& b = inputs.at("b");
|
||||
|
||||
auto coerceToDouble = [](const c10::IValue& x) -> double {
|
||||
if (x.isInt()) {
|
||||
return static_cast<double>(x.toInt());
|
||||
} else if (x.isDouble()) {
|
||||
return x.toDouble();
|
||||
} else {
|
||||
throw std::runtime_error("unsupported type for symbolic input");
|
||||
}
|
||||
};
|
||||
|
||||
if (a.isInt() && b.isInt()) {
|
||||
computeScalarBinaryOp<int64_t>(
|
||||
executionFrame, *node_, a.toInt(), b.toInt());
|
||||
} else {
|
||||
computeScalarBinaryOp<double>(
|
||||
executionFrame, *node_, coerceToDouble(a), coerceToDouble(b));
|
||||
}
|
||||
}
|
||||
|
||||
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
auto inputs = getSymInputs(executionFrame, *node_);
|
||||
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
std::string_view target = node_->target();
|
||||
if (target == "torch.sym_float") {
|
||||
double out = static_cast<double>(a);
|
||||
executionFrame.setIValue(node_->outputs()[0]->id(), out);
|
||||
VLOG(2) << fmt::format(
|
||||
"Completed executing node: {} with a={}, out={}",
|
||||
fmt::streamed(*node_),
|
||||
a,
|
||||
out);
|
||||
return;
|
||||
}
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
int64_t out;
|
||||
|
||||
if (target == "_operator.floordiv") {
|
||||
out = a / b;
|
||||
} else if (target == "_operator.mod") {
|
||||
out = a % b;
|
||||
} else if (target == "torch.sym_max") {
|
||||
out = std::max(a, b);
|
||||
} else if (target == "torch.sym_min") {
|
||||
out = std::min(a, b);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
fmt::format("unsupported operator for SymInt: {}", node_->target()));
|
||||
}
|
||||
|
||||
executionFrame.setIValue(node_->outputs()[0]->id(), out);
|
||||
VLOG(2) << fmt::format(
|
||||
"Completed executing node: {} with a={}, b={}, out={}",
|
||||
fmt::streamed(*node_),
|
||||
a,
|
||||
b,
|
||||
out);
|
||||
}
|
||||
|
||||
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
auto inputs = getSymInputs(executionFrame, *node_);
|
||||
|
||||
bool out;
|
||||
|
||||
const std::string_view target = node_->target();
|
||||
if (target == "torch.sym_not") {
|
||||
bool a = inputs.at("a").toBool();
|
||||
out = !a;
|
||||
} else if (target == "_operator.ge") {
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
out = a >= b;
|
||||
} else if (target == "_operator.le") {
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
out = a <= b;
|
||||
} else if (target == "_operator.eq") {
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
out = a == b;
|
||||
} else if (target == "_operator.gt") {
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
out = a > b;
|
||||
} else if (target == "_operator.lt") {
|
||||
int64_t a = inputs.at("a").toInt();
|
||||
int64_t b = inputs.at("b").toInt();
|
||||
out = a < b;
|
||||
} else if (target == "_operator.and_") {
|
||||
bool a = inputs.at("a").toBool();
|
||||
bool b = inputs.at("b").toBool();
|
||||
out = a && b;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
fmt::format("unsupported operator for SymBool: {}", node_->target()));
|
||||
}
|
||||
|
||||
executionFrame.setIValue(node_->outputs()[0]->id(), out);
|
||||
}
|
||||
|
||||
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
auto inputs = getSymInputs(executionFrame, *node_);
|
||||
|
||||
const std::string_view target = node_->target();
|
||||
if (target == "math.trunc") {
|
||||
double x = inputs.at("x").toDouble();
|
||||
int64_t out = trunc(x);
|
||||
executionFrame.setIValue(node_->outputs()[0]->id(), out);
|
||||
} else if (target == "torch._sym_sqrt") {
|
||||
double a = inputs.at("a").toDouble();
|
||||
double out = std::sqrt(a);
|
||||
executionFrame.setIValue(node_->outputs()[0]->id(), out);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
fmt::format("unsupported operator for SymFloat: {}", node_->target()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
73
torch/csrc/nativert/kernels/C10Kernel.h
Normal file
@ -0,0 +1,73 @@
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
#include "c10/core/Device.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
// Implementation of Kernel for ATen operators
|
||||
//
|
||||
// This class exists to amortize per-kernel overhead by computing things during
|
||||
// initialization instead of on every execution. Right now we are only
|
||||
// amortizing schema resolution, and static arguments parsing,
|
||||
// but in the future this could be extended to avoid operator dispatch and
|
||||
// do better "Register" allocation (e.g. convert input/outputs to directly
|
||||
// array accesses onto a set of registers, in concert with memory planning)
|
||||
class C10Kernel : public OpKernel {
|
||||
public:
|
||||
C10Kernel() = delete; // deleted default constructor
|
||||
C10Kernel(const Node* node, c10::Device device);
|
||||
virtual ~C10Kernel() = default;
|
||||
|
||||
[[nodiscard]] const c10::IValue& input(
|
||||
uint32_t i,
|
||||
ExecutionFrame& executionFrame) const override {
|
||||
if (Value* dynamicArg = arguments_.findDynamic(i)) {
|
||||
return executionFrame.getIValue(dynamicArg->id());
|
||||
} else {
|
||||
return arguments_.getStatic(i);
|
||||
}
|
||||
}
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override;
|
||||
|
||||
private:
|
||||
c10::OperatorHandle op_;
|
||||
c10::FunctionSchema schema_;
|
||||
|
||||
Arguments arguments_;
|
||||
};
|
||||
|
||||
class SymIntOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit SymIntOpKernel(const Node* node) : OpKernel(node) {}
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
};
|
||||
|
||||
class SymBoolOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit SymBoolOpKernel(const Node* node) : OpKernel(node) {}
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
};
|
||||
|
||||
class SymFloatOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit SymFloatOpKernel(const Node* node) : OpKernel(node) {}
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
};
|
||||
|
||||
// ScalarBinaryOpKernel does binary arithmetic operations on scalar values.
|
||||
// Integers and floats are supported as input types. The output will be
|
||||
// promoted to float if and only if there's at least one float input.
|
||||
class ScalarBinaryOpKernel : public OpKernel {
|
||||
public:
|
||||
explicit ScalarBinaryOpKernel(const Node* node) : OpKernel(node) {}
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
48
torch/csrc/nativert/kernels/CallTorchBindKernel.cpp
Normal file
@ -0,0 +1,48 @@
|
||||
#include "torch/csrc/nativert/kernels/CallTorchBindKernel.h"
|
||||
#include "torch/csrc/nativert/common/Enumerate.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) {
|
||||
const Value* customObjValue = node_->inputs()[0].value;
|
||||
CHECK(customObjValue->type() == Type::CustomObj);
|
||||
|
||||
customClassName_ = customObjValue->type().classFqn();
|
||||
customClassType_ = torch::jit::getCustomClass(customClassName_);
|
||||
|
||||
// sample schema
|
||||
// torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1);
|
||||
|
||||
CHECK(node->attributes().size() == 1)
|
||||
<< "Expects higher_order.call_torchbind to only have a single attribute, methodName";
|
||||
const auto& attr = node->attributes()[0];
|
||||
|
||||
CHECK(std::holds_alternative<std::string>(attr.value))
|
||||
<< "method should be a string";
|
||||
methodName_ = std::get<std::string>(attr.value);
|
||||
method_ = customClassType_->findMethod(methodName_);
|
||||
|
||||
CHECK(method_ != nullptr) << "method not found: " << methodName_;
|
||||
}
|
||||
|
||||
void CallTorchBindKernel::computeInternal(
|
||||
ExecutionFrame& executionFrame) const {
|
||||
// prepare inputs
|
||||
std::vector<c10::IValue> stack;
|
||||
for (const auto& input : node_->inputs()) {
|
||||
const auto& id = input.value->id();
|
||||
stack.emplace_back(executionFrame.getIValue(id));
|
||||
}
|
||||
|
||||
// call the method
|
||||
method_->run(stack);
|
||||
|
||||
// set outputs
|
||||
const auto& outputs = node_->outputs();
|
||||
TORCH_CHECK_EQ(outputs.size(), stack.size());
|
||||
for (auto&& [i, outputValue] : enumerate(stack)) {
|
||||
executionFrame.setIValue(outputs[i]->id(), std::move(outputValue));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
25
torch/csrc/nativert/kernels/CallTorchBindKernel.h
Normal file
@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
#include "c10/core/Device.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class CallTorchBindKernel : public OpKernel {
|
||||
public:
|
||||
CallTorchBindKernel() = delete; // deleted default constructor
|
||||
CallTorchBindKernel(const Node* node);
|
||||
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
std::string methodName_;
|
||||
torch::jit::Function* method_;
|
||||
|
||||
std::string customClassName_;
|
||||
at::ClassTypePtr customClassType_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
3211
torch/csrc/nativert/kernels/GeneratedStaticDispatchKernels.cpp
Normal file
File diff suppressed because it is too large
116
torch/csrc/nativert/kernels/HigherOrderKernel.cpp
Normal file
@ -0,0 +1,116 @@
|
||||
#include "torch/csrc/nativert/kernels/HigherOrderKernel.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
HigherOrderKernel::HigherOrderKernel(
|
||||
const Node* node,
|
||||
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors)
|
||||
: OpKernel(node), graphExecutors_(std::move(graphExecutors)) {
|
||||
static constexpr std::string_view prefix = "torch.ops.higher_order.";
|
||||
CHECK(starts_with(node->target(), prefix));
|
||||
auto opName = node->target().substr(prefix.size());
|
||||
if (opName == "cond") {
|
||||
opType_ = OpType::COND;
|
||||
// Checking torch.cond schema is as expected:
|
||||
// torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args)
|
||||
// -> Tensor[]
|
||||
TORCH_CHECK_EQ(node_->attributes().size(), 2);
|
||||
TORCH_CHECK_EQ(node_->inputs().size(), 2);
|
||||
} else if (opName == "while_loop") {
|
||||
opType_ = OpType::WHILE_LOOP;
|
||||
// Checking torch.while_loop schema is as expected:
|
||||
// torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[]
// additional) -> Tensor[]
TORCH_CHECK_EQ(node_->attributes().size(), 2);
|
||||
TORCH_CHECK_EQ(node_->inputs().size(), 2);
|
||||
} else if (opName == "run_const_graph") {
|
||||
opType_ = OpType::RUN_CONST_GRAPH;
|
||||
// Checking torch.run_const_graph schema is as expected:
|
||||
// torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[]
|
||||
TORCH_CHECK_GE(node_->attributes().size(), 1);
|
||||
TORCH_CHECK_EQ(node_->inputs().size(), 1);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Unknown higher order op: {}", opName));
|
||||
}
|
||||
}
|
||||
|
||||
void HigherOrderKernel::computeInternal(ExecutionFrame& executionFrame) const {
|
||||
switch (opType_) {
|
||||
case OpType::COND: {
|
||||
auto inputs = executionFrame.getIValue(node_->inputs()[1].value->id())
|
||||
.toList()
|
||||
.vec();
|
||||
std::vector<c10::IValue> outputs;
|
||||
auto cond = executionFrame.getIValue(node_->inputs()[0].value->id());
|
||||
size_t branchIdx = 0;
|
||||
if (cond.isTensor()) {
|
||||
branchIdx = cond.toTensor().item().toBool() ? 0 : 1;
|
||||
} else if (cond.isBool()) {
|
||||
branchIdx = cond.toBool() ? 0 : 1;
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported type for cond predicate");
|
||||
}
|
||||
ExecutionFrame branchFrame(*std::get<std::unique_ptr<Graph>>(
|
||||
node_->attributes()[branchIdx].value));
|
||||
graphExecutors_[branchIdx]->execute(branchFrame, std::move(inputs));
|
||||
auto ret = branchFrame.getUserOutputs();
|
||||
for (size_t i = 0; i < ret.size(); i++) {
|
||||
executionFrame.setIValue(node_->outputs()[i]->id(), std::move(ret[i]));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OpType::WHILE_LOOP: {
|
||||
auto carriedVals =
|
||||
executionFrame.getIValue(node_->inputs()[0].value->id())
|
||||
.toList()
|
||||
.vec();
|
||||
auto additonalVals =
|
||||
executionFrame.getIValue(node_->inputs()[1].value->id())
|
||||
.toList()
|
||||
.vec();
|
||||
size_t numCarriedVals = carriedVals.size();
|
||||
ExecutionFrame condFrame(
|
||||
*std::get<std::unique_ptr<Graph>>(node_->attributes()[0].value));
|
||||
ExecutionFrame bodyFrame(
|
||||
*std::get<std::unique_ptr<Graph>>(node_->attributes()[1].value));
|
||||
while (true) {
|
||||
auto inputs = carriedVals;
|
||||
inputs.insert(inputs.end(), additonalVals.begin(), additonalVals.end());
|
||||
graphExecutors_[0]->execute(condFrame, inputs);
|
||||
auto cond = condFrame.getUserOutputs();
|
||||
|
||||
if (cond.at(0).isTensor() && !cond[0].toTensor().item().toBool()) {
|
||||
break;
|
||||
}
|
||||
if (cond.at(0).isBool() && !cond[0].toBool()) {
|
||||
break;
|
||||
}
|
||||
graphExecutors_[1]->execute(bodyFrame, std::move(inputs));
|
||||
auto out = bodyFrame.getUserOutputs();
|
||||
TORCH_CHECK(out.size() == numCarriedVals);
|
||||
carriedVals = std::move(out);
|
||||
}
|
||||
for (size_t i = 0; i < carriedVals.size(); i++) {
|
||||
executionFrame.setIValue(
|
||||
node_->outputs()[i]->id(), std::move(carriedVals[i]));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OpType::RUN_CONST_GRAPH: {
|
||||
// run_const_graph op is a special case of higher order op which has
|
||||
// been executed during weights loading, therefore at runtime we can
|
||||
// just make this a no-op.
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK(false, "Unknown higher order op");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
28
torch/csrc/nativert/kernels/HigherOrderKernel.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "c10/core/Device.h"
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class HigherOrderKernel : public OpKernel {
|
||||
enum class OpType {
|
||||
UNKNOWN,
|
||||
COND,
|
||||
WHILE_LOOP,
|
||||
RUN_CONST_GRAPH,
|
||||
};
|
||||
|
||||
public:
|
||||
HigherOrderKernel(
|
||||
const Node* node,
|
||||
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors);
|
||||
void computeInternal(ExecutionFrame& executionFrame) const override final;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors_;
|
||||
OpType opType_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
||||
288
torch/csrc/nativert/kernels/KernelFactory.cpp
Normal file
@ -0,0 +1,288 @@
|
||||
#include "torch/csrc/nativert/kernels/KernelFactory.h"
|
||||
#include <string_view>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/kernels/AOTICallDelegateKernel.h"
|
||||
#include "torch/csrc/nativert/kernels/AOTIKernel.h"
|
||||
|
||||
#ifdef __SIGRID_USE_FBA__
|
||||
#include "sigmoid/backend/MTIADelegateExecutor.h"
|
||||
#include "sigmoid/backend/MTIAKernel.h"
|
||||
#endif
|
||||
|
||||
#include "torch/csrc/nativert/common/String.h"
|
||||
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
|
||||
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
#include "torch/csrc/nativert/graph/GraphPasses.h"
|
||||
#include "torch/csrc/nativert/kernels/AutoFunctionalizeKernel.h"
|
||||
#include "torch/csrc/nativert/kernels/C10Kernel.h"
|
||||
#include "torch/csrc/nativert/kernels/CallTorchBindKernel.h"
|
||||
#include "torch/csrc/nativert/kernels/HigherOrderKernel.h"
|
||||
#include "torch/csrc/nativert/kernels/KernelRegistry.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace {
|
||||
|
||||
c10::Device inferTargetDevice(
|
||||
const Node& node,
|
||||
const std::unordered_map<std::string, TensorMeta>& tensorValuesMeta,
|
||||
const Placement& placement) {
|
||||
if (node.target() == "prim.Input" || node.target() == "prim.Output") {
|
||||
return c10::Device(c10::DeviceType::CPU);
|
||||
}
|
||||
|
||||
std::vector<c10::Device> devices;
|
||||
for (auto& output : node.outputs()) {
|
||||
if (output->type() == Type::Tensor) {
|
||||
auto it = tensorValuesMeta.find(std::string{output->name()});
|
||||
if (it != tensorValuesMeta.end()) {
|
||||
devices.emplace_back(it->second.device());
|
||||
}
|
||||
} else if (output->type() == Type::TensorList) {
|
||||
for (const auto& el : output->getListElements()) {
|
||||
auto it = tensorValuesMeta.find(std::string{el->name()});
|
||||
if (it != tensorValuesMeta.end()) {
|
||||
devices.emplace_back(it->second.device());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (devices.empty()) {
|
||||
return c10::Device(c10::DeviceType::CPU);
|
||||
} else {
|
||||
for (size_t i = 1; i < devices.size(); ++i) {
|
||||
if (!isSameDevice(devices[0], devices[i])) {
|
||||
LOG(WARNING) << "Node " << node
|
||||
<< " has outputs on multiple devices: " << devices[0]
|
||||
<< " and " << devices[i];
|
||||
}
|
||||
}
|
||||
|
||||
return placement.getMappedDevice(devices[0]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
inline constexpr std::string_view kSymIntOps[] = {
|
||||
"_operator.floordiv",
|
||||
"_operator.mod",
|
||||
"torch.sym_int",
|
||||
"torch.sym_float",
|
||||
"torch.sym_ite",
|
||||
"torch.sym_max",
|
||||
"torch.sym_min",
|
||||
};
|
||||
|
||||
inline constexpr std::string_view kSymBoolOps[] = {
|
||||
"_operator.eq",
|
||||
"_operator.ne",
|
||||
"_operator.le",
|
||||
"_operator.ge",
|
||||
"_operator.lt",
|
||||
"_operator.gt",
|
||||
"_operator.and_",
|
||||
"torch.sym_not",
|
||||
};
|
||||
|
||||
inline constexpr std::string_view kSymFloatOps[] = {
|
||||
"torch._sym_sqrt",
|
||||
"math.trunc",
|
||||
};
|
||||
|
||||
inline constexpr std::string_view kScalarBinaryOps[] = {
|
||||
"_operator.mul",
|
||||
"_operator.add",
|
||||
"_operator.sub",
|
||||
"_operator.pow",
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
const std::string maybeRevisedStaticDispatchTarget(const Node& node) {
|
||||
auto overloadName = selectScalarOverloadName(node);
|
||||
if (!ends_with(node.target(), overloadName)) {
|
||||
const std::string& newTarget =
|
||||
std::string(node.target())
|
||||
.replace(node.target().rfind('.'), std::string::npos, overloadName);
|
||||
LOG(INFO) << fmt::format(
|
||||
"Converting Tensor to {} for node: {} -> {}",
|
||||
overloadName,
|
||||
node.target(),
|
||||
newTarget);
|
||||
return newTarget;
|
||||
}
|
||||
return std::string(node.target());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ExecutionKernels KernelFactory::initializeNodeKernels(
|
||||
const Graph& graph,
|
||||
std::shared_ptr<Weights> weights,
|
||||
const ExecutorConfig& executorConfig,
|
||||
const Placement& placement,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader) {
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
|
||||
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
|
||||
std::vector<ConstFoldingExecution> constFoldingExecutions;
|
||||
|
||||
std::unordered_set<std::string> opsWithoutStaticDispatch;
|
||||
|
||||
VLOG(1) << "PrimKernelRegistry: " << join(", ", PrimKernelRegistry()->Keys());
|
||||
VLOG(1) << "StaticallyDispatchedCPUKernelRegistry: "
|
||||
<< join(", ", StaticallyDispatchedCPUKernelRegistry()->Keys());
|
||||
|
||||
for (const auto& node : graph.nodes()) {
|
||||
std::string originalTarget = std::string(node.target());
|
||||
const std::string target =
|
||||
(executorConfig.enableStaticCPUKernels &&
|
||||
StaticallyDispatchedCPUKernelRegistry()->Has(originalTarget))
|
||||
? maybeRevisedStaticDispatchTarget(node)
|
||||
: std::move(originalTarget);
|
||||
c10::Device targetDevice =
|
||||
inferTargetDevice(node, graph.tensorValuesMeta(), placement);
|
||||
|
||||
if (PrimKernelRegistry()->Has(target)) {
|
||||
nodeKernels.push_back(PrimKernelRegistry()->Create(target, &node));
|
||||
} else if (
|
||||
executorConfig.enableStaticCPUKernels &&
|
||||
StaticallyDispatchedCPUKernelRegistry()->Has(target) &&
|
||||
targetDevice.is_cpu()) {
|
||||
nodeKernels.push_back(StaticallyDispatchedCPUKernelRegistry()->Create(
|
||||
target, &node, targetDevice));
|
||||
} else if (starts_with(
|
||||
node.target(), "torch.ops.higher_order.call_torchbind")) {
|
||||
nodeKernels.push_back(std::make_unique<CallTorchBindKernel>(&node));
|
||||
} else if (
|
||||
starts_with(
|
||||
node.target(),
|
||||
"torch.ops.higher_order.auto_functionalized") ||
|
||||
starts_with( // TODO Remove this condition once the old
|
||||
// pt2 archives are expired.
|
||||
node.target(),
|
||||
"torch._higher_order_ops.auto_functionalize.auto_functionalized")) {
|
||||
nodeKernels.push_back(
|
||||
std::make_unique<UnsafeAutoFunctionalizeKernel>(&node));
|
||||
} else if (
|
||||
std::find(
|
||||
std::begin(kSymIntOps), std::end(kSymIntOps), node.target()) !=
|
||||
std::end(kSymIntOps)) {
|
||||
nodeKernels.push_back(std::make_unique<SymIntOpKernel>(&node));
|
||||
} else if (
|
||||
std::find(
|
||||
std::begin(kSymBoolOps), std::end(kSymBoolOps), node.target()) !=
|
||||
std::end(kSymBoolOps)) {
|
||||
nodeKernels.push_back(std::make_unique<SymBoolOpKernel>(&node));
|
||||
} else if (
|
||||
std::find(
|
||||
std::begin(kSymFloatOps), std::end(kSymFloatOps), node.target()) !=
|
||||
std::end(kSymFloatOps)) {
|
||||
nodeKernels.push_back(std::make_unique<SymFloatOpKernel>(&node));
|
||||
} else if (
|
||||
std::find(
|
||||
std::begin(kScalarBinaryOps),
|
||||
std::end(kScalarBinaryOps),
|
||||
node.target()) != std::end(kScalarBinaryOps)) {
|
||||
nodeKernels.push_back(std::make_unique<ScalarBinaryOpKernel>(&node));
|
||||
} else if (starts_with(
|
||||
node.target(), "torch.ops.delegate.call_aotinductor")) {
|
||||
const auto pathAttr = node.tryGetAttribute("path");
|
||||
CHECK(pathAttr != nullptr);
|
||||
|
||||
const Constant& pathValue = pathAttr->value;
|
||||
CHECK(std::holds_alternative<std::string>(pathValue));
|
||||
std::string path = std::get<std::string>(pathValue);
|
||||
|
||||
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
|
||||
path, weights, targetDevice, executorConfig, pytorchStreamReader);
|
||||
nodeKernels.push_back(
|
||||
std::make_unique<AOTIKernel>(&node, *delegateExecutor));
|
||||
delegateExecutors.push_back(std::move(delegateExecutor));
|
||||
} else if (starts_with(
|
||||
node.target(),
|
||||
"torch.ops.higher_order.aoti_call_delegate")) {
|
||||
// the first attribute is serialized as the path to the aotinductor
|
||||
const auto pathAttr = node.attributes().begin();
|
||||
const Constant& pathValue = pathAttr->value;
|
||||
CHECK(std::holds_alternative<std::string>(pathValue));
|
||||
std::string path = std::get<std::string>(pathValue);
|
||||
|
||||
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
|
||||
path, weights, targetDevice, executorConfig, pytorchStreamReader);
|
||||
nodeKernels.push_back(
|
||||
std::make_unique<AOTICallDelegateKernel>(&node, *delegateExecutor));
|
||||
delegateExecutors.push_back(std::move(delegateExecutor));
|
||||
} else if (starts_with(node.target(), "torch.ops.delegate.call_mtia")) {
|
||||
#ifdef __SIGRID_USE_FBA__
|
||||
auto delegateExecutor = std::make_unique<MTIADelegateExecutor>(
|
||||
&node, weights, executorConfig, pytorchStreamReader);
|
||||
nodeKernels.push_back(
|
||||
std::make_unique<MTIAKernel>(&node, *delegateExecutor));
|
||||
delegateExecutors.push_back(std::move(delegateExecutor));
|
||||
#endif
|
||||
} else if (starts_with(node.target(), "torch.ops.higher_order")) {
|
||||
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors;
|
||||
for (const auto& attr : node.attributes()) {
|
||||
if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
|
||||
const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
|
||||
auto executionKernels = initializeNodeKernels(
|
||||
*subgraph, weights, executorConfig, placement);
|
||||
CHECK(executionKernels.delegateExecutors.empty())
|
||||
<< "HigherOrderKernel does not support delegates";
|
||||
CHECK(executionKernels.constFoldingExecutions.size() == 0)
|
||||
<< "HigherOrderKernel does not support const folding";
|
||||
if (executorConfig.maxParallelOps > 1) {
|
||||
graphExecutors.emplace_back(
|
||||
std::unique_ptr<GraphExecutorBase>(new ParallelGraphExecutor(
|
||||
*subgraph,
|
||||
std::move(executionKernels.nodeKernels),
|
||||
executorConfig)));
|
||||
} else {
|
||||
graphExecutors.emplace_back(
|
||||
std::unique_ptr<GraphExecutorBase>(new SerialGraphExecutor(
|
||||
*subgraph,
|
||||
std::move(executionKernels.nodeKernels),
|
||||
executorConfig)));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (node.target() == "torch.ops.higher_order.run_const_graph") {
|
||||
constFoldingExecutions.push_back(
|
||||
ConstFoldingExecution{std::move(graphExecutors[0])});
|
||||
}
|
||||
nodeKernels.push_back(std::make_unique<HigherOrderKernel>(
|
||||
&node, std::move(graphExecutors)));
|
||||
} else if (starts_with(node.target(), "torch.ops")) {
|
||||
nodeKernels.push_back(std::make_unique<C10Kernel>(&node, targetDevice));
|
||||
|
||||
opsWithoutStaticDispatch.insert(std::string(node.target()));
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported operator: ", target);
|
||||
}
|
||||
}
|
||||
|
||||
if (executorConfig.enableStaticCPUKernels) {
|
||||
std::stringstream ss;
|
||||
for (const auto& op : opsWithoutStaticDispatch) {
|
||||
ss << op << ", ";
|
||||
}
|
||||
LOG(WARNING) << "Following ops are missing static dispatched kernels: "
|
||||
<< ss.str();
|
||||
}
|
||||
|
||||
return {
|
||||
std::move(nodeKernels),
|
||||
std::move(delegateExecutors),
|
||||
std::move(constFoldingExecutions)};
|
||||
}
|
||||
} // namespace torch::nativert
|
||||
37
torch/csrc/nativert/kernels/KernelFactory.h
Normal file
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
|
||||
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class DelegateExecutor;
|
||||
|
||||
struct ConstFoldingExecution {
|
||||
std::unique_ptr<GraphExecutorBase> executor;
|
||||
};
|
||||
|
||||
struct ExecutionKernels {
|
||||
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
|
||||
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
|
||||
std::vector<ConstFoldingExecution> constFoldingExecutions;
|
||||
};
|
||||
|
||||
class KernelFactory {
|
||||
public:
|
||||
explicit KernelFactory() {}
|
||||
|
||||
ExecutionKernels initializeNodeKernels(
|
||||
const Graph& graph,
|
||||
std::shared_ptr<Weights> weights,
|
||||
const ExecutorConfig& executorConfig,
|
||||
const Placement& placement,
|
||||
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
|
||||
pytorchStreamReader = nullptr);
|
||||
};
|
||||
} // namespace torch::nativert
|
||||
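A hypothetical call site (not part of this diff) for the KernelFactory declared above; `graph`, `weights`, `executorConfig`, and `placement` are assumed to be constructed elsewhere with the types named in this header.

// Sketch only: build one OpKernel per graph node, plus any delegate executors.
torch::nativert::KernelFactory factory;
auto executionKernels = factory.initializeNodeKernels(
    graph, weights, executorConfig, placement);
// executionKernels.nodeKernels, .delegateExecutors and .constFoldingExecutions
// are then handed off to a GraphExecutorBase implementation.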
1191
torch/csrc/nativert/kernels/KernelRegistry.cpp
Normal file
File diff suppressed because it is too large
117
torch/csrc/nativert/kernels/KernelRegistry.h
Normal file
@ -0,0 +1,117 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "torch/csrc/nativert/executor/OpKernel.h"
|
||||
#include "torch/csrc/nativert/graph/Graph.h"
|
||||
#include "torch/csrc/nativert/kernels/C10Kernel.h"
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
#define KernelInput(id) input(id, executionFrame)
|
||||
#define KernelOutput(id) output(id, executionFrame)
|
||||
|
||||
TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);
|
||||
|
||||
#define REGISTER_PRIM_KERNEL(name, id, ...) \
|
||||
class OpKernel_##id : public OpKernel { \
|
||||
public: \
|
||||
OpKernel_##id(const Node* node) : OpKernel(node) { \
|
||||
kind_ = OpKernel::Kind::kPrimKernel; \
|
||||
} \
|
||||
void computeInternal( \
|
||||
ExecutionFrame& executionFrame) const override final { \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}; \
|
||||
C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id);
|
||||
|
||||
TORCH_DECLARE_REGISTRY(
|
||||
StaticallyDispatchedCPUKernelRegistry,
|
||||
OpKernel,
|
||||
const Node*,
|
||||
c10::Device);
|
||||
|
||||
#define REGISTER_CPU_KERNEL(name, id, ...) \
|
||||
class OpKernel_##id : public C10Kernel { \
|
||||
public: \
|
||||
OpKernel_##id(const Node* node, c10::Device device) \
|
||||
: C10Kernel(node, device) { \
|
||||
kind_ = OpKernel::Kind::kStaticDispatchKernel; \
|
||||
} \
|
||||
void computeInternal( \
|
||||
ExecutionFrame& executionFrame) const override final { \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}; \
|
||||
C10_REGISTER_TYPED_CLASS( \
|
||||
StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id);
|
||||
|
||||
inline bool checkResizedDataPtr(at::Tensor& t) {
|
||||
auto const prev_data_ptr = t.data_ptr();
|
||||
t.resize_({0});
|
||||
return prev_data_ptr == t.data_ptr();
|
||||
}
|
||||
|
||||
inline void fastResizeToZero(at::Tensor& t) {
|
||||
t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
|
||||
}
|
||||
|
||||
inline at::Tensor create_empty_from(const at::Tensor& t) {
|
||||
return at::detail::empty_cpu(
|
||||
{0},
|
||||
c10::typeMetaToScalarType(t.dtype()),
|
||||
t.layout(),
|
||||
t.device(),
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
inline at::Tensor create_empty_from(
|
||||
const at::Tensor& t,
|
||||
c10::ScalarType dtype) {
|
||||
return at::detail::empty_cpu(
|
||||
{0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) {
|
||||
return at::detail::empty_cpu(
|
||||
{0},
|
||||
c10::typeMetaToScalarType(t.dtype()),
|
||||
t.layout(),
|
||||
device,
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) {
|
||||
return at::detail::empty_cpu(
|
||||
{0},
|
||||
c10::typeMetaToScalarType(t.dtype()),
|
||||
layout,
|
||||
t.device(),
|
||||
std::nullopt,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
inline at::Tensor create_empty_from(
|
||||
const at::Tensor& t,
|
||||
c10::MemoryFormat memory_format) {
|
||||
return at::detail::empty_cpu(
|
||||
{0},
|
||||
c10::typeMetaToScalarType(t.dtype()),
|
||||
t.layout(),
|
||||
t.device(),
|
||||
std::nullopt,
|
||||
memory_format);
|
||||
}
|
||||
|
||||
inline at::Tensor create_empty_from(
|
||||
const at::Tensor& t,
|
||||
c10::ScalarType dtype,
|
||||
c10::MemoryFormat memory_format) {
|
||||
return at::detail::empty_cpu(
|
||||
{0}, dtype, t.layout(), t.device(), std::nullopt, memory_format);
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
||||
14
torch/csrc/nativert/kernels/NativeKernels.cpp
Normal file
@ -0,0 +1,14 @@
#include "torch/csrc/nativert/kernels/KernelRegistry.h"

namespace torch::nativert {

REGISTER_CPU_KERNEL("torch.ops.aten.slice.Tensor", aten_slice_Tensor, {
  const auto& self = KernelInput(0).toTensor();
  const auto& dim = KernelInput(1).toInt();
  const auto& start = KernelInput(2).toOptional<int64_t>();
  const auto& end = KernelInput(3).toOptional<int64_t>();
  const auto& step = KernelInput(4).toInt();
  KernelOutput(0) = at::native::slice(self, dim, start, end, step);
});

} // namespace torch::nativert
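For illustration only (not part of this diff), registering a static CPU kernel for another op follows the same pattern; the aten::relu target below is an assumed example, not something this change adds.

// Hypothetical additional registration using the same macro:
REGISTER_CPU_KERNEL("torch.ops.aten.relu.default", aten_relu_default, {
  const auto& self = KernelInput(0).toTensor();
  KernelOutput(0) = at::relu(self);
});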
68
torch/csrc/nativert/package/pt2_archive_constants.h
Normal file
@ -0,0 +1,68 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string_view>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
namespace archive_spec {
|
||||
|
||||
#define FORALL_COSTANTS(_) \
|
||||
_(ARCHIVE_ROOT_NAME, "package") \
|
||||
/* Archive format */ \
|
||||
_(ARCHIVE_FORMAT_PATH, "archive_format") \
|
||||
_(ARCHIVE_FORMAT_VALUE, "pt2") \
|
||||
/* Archive version */ \
|
||||
_(ARCHIVE_VERSION_PATH, "archive_version") \
|
||||
_(ARCHIVE_VERSION_VALUE, \
|
||||
"0") /* Sep.4.2024: This is the initial version of the PT2 Archive Spec */ \
|
||||
/* \
|
||||
* ######## Note on updating ARCHIVE_VERSION_VALUE ######## \
|
||||
* When there is a BC breaking change to the PT2 Archive Spec, \
|
||||
* e.g. deleting a folder, or changing the naming convention of the \
|
||||
* following fields it would require bumping the ARCHIVE_VERSION_VALUE \
|
||||
* Archive reader would need corresponding changes to support loading both \
|
||||
* the current and older versions of the PT2 Archive. \
|
||||
*/ \
|
||||
/* Model definitions */ \
|
||||
_(MODELS_DIR, "models/") \
|
||||
_(MODELS_FILENAME_FORMAT, "models/{}.json") /* {model_name} */ \
|
||||
/* AOTInductor artifacts */ \
|
||||
_(AOTINDUCTOR_DIR, "data/aotinductor/") \
|
||||
/* MTIA artifacts */ \
|
||||
_(MTIA_DIR, "data/mtia") \
|
||||
/* weights, including parameters and buffers */ \
|
||||
_(WEIGHTS_DIR, "data/weights/") \
|
||||
_(WEIGHT_FILENAME_PREFIX, "weight_") \
|
||||
/* constants, including tensor_constants, non-persistent buffers and script \
|
||||
* objects */ \
|
||||
_(CONSTANTS_DIR, "data/constants/") \
|
||||
_(TENSOR_CONSTANT_FILENAME_PREFIX, "tensor_") \
|
||||
_(CUSTOM_OBJ_FILENAME_PREFIX, "custom_obj_") \
|
||||
/* example inputs */ \
|
||||
_(SAMPLE_INPUTS_DIR, "data/sample_inputs/") \
|
||||
_(SAMPLE_INPUTS_FILENAME_FORMAT, \
|
||||
"data/sample_inputs/{}.pt") /* {model_name} */ \
|
||||
/* extra folder */ \
|
||||
_(EXTRA_DIR, "extra/") \
|
||||
_(MODULE_INFO_PATH, "extra/module_info.json") \
|
||||
/* xl_model_weights, this folder is used for storing per-feature-weights for \
|
||||
* remote net. Data in this folder is consumed by Predictor, and is not \
|
||||
* intended to be used by Sigmoid */ \
|
||||
_(XL_MODEL_WEIGHTS_DIR, "xl_model_weights/") \
|
||||
_(XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH, "xl_model_weights/model_param_config")
|
||||
|
||||
#define DEFINE_GLOBAL(NAME, VALUE) \
|
||||
inline constexpr std::string_view NAME = VALUE;
|
||||
|
||||
#define DEFINE_ENTRY(NAME, VALUE) std::pair(#NAME, VALUE),
|
||||
|
||||
FORALL_COSTANTS(DEFINE_GLOBAL)
|
||||
|
||||
inline constexpr std::array kAllConstants{FORALL_COSTANTS(DEFINE_ENTRY)};
|
||||
|
||||
#undef DEFINE_ENTRY
|
||||
#undef DEFINE_GLOBAL
|
||||
#undef FORALL_COSTANTS
|
||||
} // namespace archive_spec
|
||||
} // namespace torch::nativert
|
||||
18
torch/csrc/nativert/package/pt2_archive_constants_pybind.cpp
Normal file
@ -0,0 +1,18 @@
#include <pybind11/pybind11.h>

#include "torch/csrc/nativert/package/pt2_archive_constants.h"

namespace torch::nativert {
void initPt2ArchiveConstantsPybind(pybind11::module& m) {
  for (const auto& entry : torch::nativert::archive_spec::kAllConstants) {
    m.attr(entry.first) = entry.second;
  }
}
} // namespace torch::nativert

// TODO Remove this once we fully migrate to OSS build.
#ifdef FBCODE_CAFFE2
PYBIND11_MODULE(pt2_archive_constants_pybind, m) {
  torch::nativert::initPt2ArchiveConstantsPybind(m);
}
#endif
@ -0,0 +1,6 @@
#pragma once
#include <pybind11/pybind11.h>

namespace torch::nativert {
void initPt2ArchiveConstantsPybind(pybind11::module& m);
} // namespace torch::nativert
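A small sketch (not part of this diff) of consuming the archive constants from C++; the helper below is hypothetical and simply scans kAllConstants, which pairs each constant's name with its value.

#include <string_view>
#include "torch/csrc/nativert/package/pt2_archive_constants.h"

// Hypothetical lookup helper: returns an empty view when the key is unknown.
std::string_view findArchiveConstant(std::string_view key) {
  for (const auto& [name, value] : torch::nativert::archive_spec::kAllConstants) {
    if (key == name) {
      return value;
    }
  }
  return {};
}
// findArchiveConstant("WEIGHTS_DIR") yields "data/weights/".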
225
torch/export/experimental/package/__init__.py
Normal file
@ -0,0 +1,225 @@
|
||||
import ctypes
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
###############################################################################
|
||||
#
|
||||
# This file contains the code to package a model for Sigmoid in open source.
|
||||
# Please do not introduce fbcode dependencies here. (e.g. aiplatform, fbgemm, thrift)
|
||||
#
|
||||
###############################################################################
|
||||
import torch
|
||||
|
||||
from torch.export.experimental.package.pt2_archive import PT2ArchiveWriter
|
||||
from torch._C.nativert.pt2_archive_constants import ( # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
|
||||
CONSTANTS_DIR,
|
||||
CUSTOM_OBJ_FILENAME_PREFIX,
|
||||
MODELS_FILENAME_FORMAT,
|
||||
SAMPLE_INPUTS_FILENAME_FORMAT,
|
||||
TENSOR_CONSTANT_FILENAME_PREFIX,
|
||||
WEIGHT_FILENAME_PREFIX,
|
||||
WEIGHTS_DIR,
|
||||
)
|
||||
from torch._export.serde.schema import Model, Program
|
||||
|
||||
from torch._export.serde.serialize import (
|
||||
_enable_graph_inputs_of_type_nn_module,
|
||||
_to_json_bytes,
|
||||
ExportedProgramSerializer,
|
||||
)
|
||||
from torch.export import ExportedProgram
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_raw_tensor_bytes(value: torch.Tensor) -> bytes:
|
||||
# NOTE: don't chain .cpu() with .data_ptr(). If an HtoD copy needs to be
|
||||
# performed, the CPU copy needs to be kept alive when its underlying
|
||||
# memory is accessed.
|
||||
if value.data_ptr():
|
||||
cpu_tensor = value.cpu().contiguous()
|
||||
# we store the raw bytes of tensor. Tensor metadata is stored separately
|
||||
value_bytes = bytes(
|
||||
ctypes.cast(
|
||||
cpu_tensor.data_ptr(),
|
||||
ctypes.POINTER(ctypes.c_ubyte * value.element_size() * value.numel()),
|
||||
).contents
|
||||
)
|
||||
else:
|
||||
# for empty tensor
|
||||
value_bytes = bytes()
|
||||
return value_bytes
|
||||
|
||||
|
||||
def _package_state_dict(
|
||||
exported_program: ExportedProgram,
|
||||
zip_file: PT2ArchiveWriter,
|
||||
) -> Dict[str, str]:
|
||||
idx = zip_file.count_prefix(os.path.join(WEIGHTS_DIR, WEIGHT_FILENAME_PREFIX))
|
||||
|
||||
qual_name_to_id = {} # Map from tensor name to its name in the weights folder
|
||||
|
||||
for name, tensor in exported_program.state_dict.items():
|
||||
if tensor.is_meta:
|
||||
logger.error(
|
||||
f"Skipping state_dict packing of {name} since it's a meta tensor"
|
||||
)
|
||||
continue
|
||||
|
||||
param_name = f"{WEIGHT_FILENAME_PREFIX}{idx}"
|
||||
idx += 1
|
||||
|
||||
qual_name_to_id[name] = param_name
|
||||
|
||||
archive_path = os.path.join(WEIGHTS_DIR, param_name)
|
||||
tensor_bytes = get_raw_tensor_bytes(tensor)
|
||||
zip_file.write_bytes(archive_path, tensor_bytes)
|
||||
|
||||
return qual_name_to_id
|
||||
|
||||
|
||||
def _package_constants(
    exported_program: ExportedProgram,
    zip_file: PT2ArchiveWriter,
) -> Dict[str, Any]:
    tensor_idx = zip_file.count_prefix(
        os.path.join(CONSTANTS_DIR, TENSOR_CONSTANT_FILENAME_PREFIX)
    )
    custom_obj_idx = zip_file.count_prefix(
        os.path.join(CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX)
    )

    qual_name_to_id = {}  # Map from constant name to its name in constants folder

    for name, constant in exported_program.constants.items():
        if isinstance(constant, torch.Tensor):
            # Save the constant tensors the same way we save weights
            tensor_name = f"{TENSOR_CONSTANT_FILENAME_PREFIX}{tensor_idx}"
            tensor_idx += 1

            qual_name_to_id[name] = tensor_name

            archive_path = os.path.join(CONSTANTS_DIR, tensor_name)
            tensor_bytes = get_raw_tensor_bytes(constant)
            zip_file.write_bytes(archive_path, tensor_bytes)

        elif isinstance(constant, torch._C.ScriptObject):
            # CustomClassHolder objects implement their own pickle saving
            # functions.
            logger.info(f"saving script object {name}")

            custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"
            custom_obj_idx += 1

            qual_name_to_id[name] = custom_obj_name

            archive_path = os.path.join(CONSTANTS_DIR, custom_obj_name)

            custom_obj_bytes = torch._C._pickle_save(constant)
            zip_file.write_bytes(archive_path, custom_obj_bytes)

        else:
            raise RuntimeError(f"Serializing constant type {type(constant)} nyi")

    return qual_name_to_id


# `sample_inputs` will be pytree-flattened into a Python list and saved via `torch.save()`
# in the zip archive as "data/sample_inputs/<model_name>.pt".
# In C++, this can be loaded via `torch::pickle_load`.
# See sigmoid::ModelRunner::loadSampleInputs() for more details.
def _package_sample_inputs(
    sample_args: Tuple[pytree.PyTree, ...],
    sample_kwargs: Dict[str, pytree.PyTree],
    zip_file: PT2ArchiveWriter,
    model_name: str,
) -> str:
    sample_inputs_path = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name)
    buffer = io.BytesIO()

    # Convert torch.nn.Parameter to torch.Tensor.
    # This is needed because torch::pickle_load() doesn't support torch.nn.Parameter.
    def get_tensor(x: torch.Tensor) -> torch.Tensor:
        if isinstance(x, torch.nn.Parameter):
            return x.data
        else:
            return x

    # args must be a tuple, not a list
    sample_args = tuple(pytree.tree_map(get_tensor, sample_args))

    # kwargs must be a dict
    sample_kwargs = pytree.tree_map(get_tensor, sample_kwargs)

    torch.save((sample_args, sample_kwargs), buffer)

    zip_file.write_bytes(sample_inputs_path, buffer.getvalue())

    return sample_inputs_path


def package_model(
    exported_program: ExportedProgram,
    model_name: str,
    zip_file: PT2ArchiveWriter,
    delegates: Optional[Dict[str, ExportedProgram]] = None,
) -> None:
    """
    Save the exported program in a format that's compatible with the sigmoid ModelRunner.
    """

    def _make_program(
        ep: ExportedProgram,
    ) -> Program:
        with _enable_graph_inputs_of_type_nn_module(ep.example_inputs):
            return Program(
                methods={
                    "forward": ExportedProgramSerializer()
                    .serialize(ep)
                    .exported_program
                },
            )

    if delegates is None:
        delegates = {}

    # Packaging for weights
    tensor_path_map = _package_state_dict(exported_program, zip_file)
    # Packaging for constants (tensor constants, custom class objects)
    constant_path_map = _package_constants(exported_program, zip_file)
    example_args, example_kwargs = exported_program.example_inputs

    # Packaging for input samples
    assert (
        example_args is not None or example_kwargs is not None
    ), "PT2 Archive requires sample inputs to be present"
    sample_inputs_path = _package_sample_inputs(  # noqa
        example_args, example_kwargs, zip_file, model_name
    )

    model_json = Model(
        name=model_name,
        tensorPaths=tensor_path_map,
        program=_make_program(exported_program),
        delegates={key: _make_program(ep) for key, ep in delegates.items()},
        deviceAllocationMap={},
        constantPaths=constant_path_map,
    )
    model_bytes: bytes = _to_json_bytes(model_json)

    # Packaging for the model definition
    zip_file.write_bytes(MODELS_FILENAME_FORMAT.format(model_name), model_bytes)

    # Include a readable graph for debugging
    zip_file.write_string(
        f"models/debug/{model_name}_readable.txt", str(exported_program)
    )
    zip_file.write_string(
        f"models/debug/{model_name}_device_annotation.txt",
        exported_program.graph_module.print_readable(
            print_output=False, include_device=True
        ),
    )
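
For orientation, a minimal sketch of how package_model above might be driven end to end. The module and file names are hypothetical; the packaging helpers are assumed to be importable alongside PT2ArchiveWriter.

    import torch
    from torch.export import export

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # Export with example inputs; package_model reads them from ep.example_inputs.
    ep = export(Add(), (torch.randn(2), torch.randn(2)))

    with PT2ArchiveWriter("add_model.pt2") as writer:
        # Writes weights, constants, sample inputs, the serialized program,
        # and the readable debug graphs into a single PT2 archive.
        package_model(ep, "add_model", writer)
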
139
torch/export/experimental/package/pt2_archive.py
Normal file
@@ -0,0 +1,139 @@
# pyre-unsafe
import glob
import io
import logging
import os
import zipfile
from typing import BinaryIO, Union

import torch

from torch._C.nativert.pt2_archive_constants import (  # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
    ARCHIVE_FORMAT_PATH,
    ARCHIVE_FORMAT_VALUE,
    ARCHIVE_VERSION_PATH,
    ARCHIVE_VERSION_VALUE,
)

logger: logging.Logger = logging.getLogger(__name__)


def is_sigmoid_package(serialized_model: Union[bytes, str]) -> bool:
    try:
        zip_reader = zipfile.ZipFile(
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
        archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
        if archive_format_path in zip_reader.namelist():
            return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        logger.info(f"Model is not a sigmoid package: {ex}")
    return False


class PT2ArchiveWriter:
    def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
        # pyre-ignore
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)
        # NOTICE: this version is different from archive_version. It is the
        # version of the zip container format used by PyTorchFileWriter and is
        # written to /.data/version; archive_version is the version of the PT2
        # archive spec and is written to /archive_version.
        self.archive_file.set_min_version(6)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        if not self.has_record(ARCHIVE_FORMAT_PATH):
            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)

        if not self.has_record(ARCHIVE_VERSION_PATH):
            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)

        self.close()

    def has_record(self, name: str) -> bool:
        return name in self.archive_file.get_all_written_records()

    def count_prefix(self, prefix: str) -> int:
        return sum(
            1
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        )

    def write_bytes(self, name: str, data: bytes) -> None:
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Copy a folder into the archive.
        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"

        file_paths = filter(
            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
        )
        for file_path in file_paths:
            filename = os.path.relpath(file_path, folder_dir)
            archive_path = os.path.join(archive_dir, filename)
            self.write_file(archive_path, file_path)

    def close(self) -> None:
        self.archive_file.write_end_of_file()


class PT2ArchiveReader:
    def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
        # pyre-ignore
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)
        assert (
            self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
        ), "Invalid archive format"

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # torch._C.PyTorchFileReader doesn't have a close method
        pass

    def read_bytes(self, name: str) -> bytes:
        return self.archive_file.get_record(name)

    def read_string(self, name: str) -> str:
        data = self.read_bytes(name)
        return data.decode()

    def archive_version(self) -> int:
        try:
            archive_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # If the archive_version record is missing, the archive predates
            # explicit versioning; treat it as version 0.
            archive_version = 0

        return int(archive_version)
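
As a quick illustration of the writer and reader above, a minimal round-trip sketch. The file and directory names are hypothetical and the module is assumed to be importable.

    # Write an archive containing one extra text record and the contents of a folder.
    with PT2ArchiveWriter("/tmp/example.pt2") as writer:
        writer.write_string("extra/notes.txt", "hello")
        writer.write_folder("extra/assets", "/tmp/assets_dir")
    # On exit, the archive_format and archive_version records are added
    # automatically and the end-of-file marker is written.

    # Read the records back and check the spec version.
    with PT2ArchiveReader("/tmp/example.pt2") as reader:
        assert reader.read_string("extra/notes.txt") == "hello"
        print(reader.archive_version())
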
139
torch/export/experimental/pt2_archive.py
Normal file
@@ -0,0 +1,139 @@
# pyre-unsafe
import glob
import io
import logging
import os
import zipfile
from typing import BinaryIO, Union

import torch

from torch._C.nativert.package.pt2_archive_constants_pybind import (  # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
    ARCHIVE_FORMAT_PATH,
    ARCHIVE_FORMAT_VALUE,
    ARCHIVE_VERSION_PATH,
    ARCHIVE_VERSION_VALUE,
)

logger: logging.Logger = logging.getLogger(__name__)


def is_sigmoid_package(serialized_model: Union[bytes, str]) -> bool:
    try:
        zip_reader = zipfile.ZipFile(
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
        archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
        if archive_format_path in zip_reader.namelist():
            return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        logger.info(f"Model is not a sigmoid package: {ex}")
    return False


class PT2ArchiveWriter:
    def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
        # pyre-ignore
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)
        # NOTICE: this version is different from archive_version. It is the
        # version of the zip container format used by PyTorchFileWriter and is
        # written to /.data/version; archive_version is the version of the PT2
        # archive spec and is written to /archive_version.
        self.archive_file.set_min_version(6)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        if not self.has_record(ARCHIVE_FORMAT_PATH):
            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)

        if not self.has_record(ARCHIVE_VERSION_PATH):
            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)

        self.close()

    def has_record(self, name: str) -> bool:
        return name in self.archive_file.get_all_written_records()

    def count_prefix(self, prefix: str) -> int:
        return sum(
            1
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        )

    def write_bytes(self, name: str, data: bytes) -> None:
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Copy a folder into the archive.
        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"

        file_paths = filter(
            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
        )
        for file_path in file_paths:
            filename = os.path.relpath(file_path, folder_dir)
            archive_path = os.path.join(archive_dir, filename)
            self.write_file(archive_path, file_path)

    def close(self) -> None:
        self.archive_file.write_end_of_file()


class PT2ArchiveReader:
    def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
        # pyre-ignore
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)
        assert (
            self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
        ), "Invalid archive format"

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # torch._C.PyTorchFileReader doesn't have a close method
        pass

    def read_bytes(self, name: str) -> bytes:
        return self.archive_file.get_record(name)

    def read_string(self, name: str) -> str:
        data = self.read_bytes(name)
        return data.decode()

    def archive_version(self) -> int:
        try:
            archive_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # If the archive_version record is missing, the archive predates
            # explicit versioning; treat it as version 0.
            archive_version = 0

        return int(archive_version)
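
Finally, a small sketch of how the detection helper might be combined with the reader. The archive path is hypothetical and the module is assumed to be importable.

    # is_sigmoid_package accepts either raw bytes or a path to a file on disk.
    with open("/tmp/example.pt2", "rb") as f:
        data = f.read()

    if is_sigmoid_package(data):
        # The archive_format record is present and equals b"pt2".
        with PT2ArchiveReader("/tmp/example.pt2") as reader:
            print("PT2 archive, spec version", reader.archive_version())
    else:
        print("Not a PT2 archive")
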