Compare commits

...

1 Commits

Author SHA1 Message Date
290a4b4fea o.o 2025-03-20 10:59:09 -07:00
89 changed files with 21645 additions and 1 deletions

View File

@@ -587,6 +587,44 @@ jit_sources_full = [
libtorch_core_jit_sources = sorted(jit_sources_full)
libtorch_runtime_sources = [
"torch/csrc/nativert/common/ConfigUtils.cpp",
"torch/csrc/nativert/common/Conv.cpp",
"torch/csrc/nativert/common/FileUtil.cpp",
"torch/csrc/nativert/common/Pytree.cpp",
"torch/csrc/nativert/common/String.cpp",
"torch/csrc/nativert/executor/AOTInductorModelImpl.cpp",
"torch/csrc/nativert/executor/SerialGraphExecutor.cpp",
"torch/csrc/nativert/executor/AOTIDelegateExecutor.cpp",
"torch/csrc/nativert/executor/Executor.cpp",
"torch/csrc/nativert/executor/GraphExecutorBase.cpp",
"torch/csrc/nativert/executor/ConstantFolder.cpp",
"torch/csrc/nativert/executor/DelegateExecutor.cpp",
"torch/csrc/nativert/executor/ExecutionFrame.cpp",
"torch/csrc/nativert/executor/ExecutionPlanner.cpp",
"torch/csrc/nativert/executor/ModelRunnerBase.cpp",
"torch/csrc/nativert/executor/OpKernel.cpp",
"torch/csrc/nativert/executor/ParallelGraphExecutor.cpp",
"torch/csrc/nativert/executor/Placement.cpp",
"torch/csrc/nativert/executor/Weights.cpp",
"torch/csrc/nativert/graph/Graph.cpp",
"torch/csrc/nativert/graph/GraphPasses.cpp",
"torch/csrc/nativert/graph/GraphSignature.cpp",
"torch/csrc/nativert/graph/Serialization.cpp",
"torch/csrc/nativert/graph/TensorMeta.cpp",
"torch/csrc/nativert/kernels/NativeKernels.cpp",
"torch/csrc/nativert/kernels/AOTICallDelegateKernel.cpp",
"torch/csrc/nativert/kernels/AOTIKernel.cpp",
"torch/csrc/nativert/kernels/CallTorchBindKernel.cpp",
"torch/csrc/nativert/kernels/GeneratedStaticDispatchKernels.cpp",
"torch/csrc/nativert/kernels/AutoFunctionalizeKernel.cpp",
"torch/csrc/nativert/kernels/C10Kernel.cpp",
"torch/csrc/nativert/kernels/HigherOrderKernel.cpp",
"torch/csrc/nativert/kernels/KernelFactory.cpp",
"torch/csrc/nativert/kernels/KernelRegistry.cpp",
"torch/csrc/nativert/ModelRunner.cpp",
]
torch_mobile_tracer_sources = [
"torch/csrc/jit/mobile/model_tracer/tracer.cpp",
"torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
@@ -619,7 +657,7 @@ libtorch_lite_cmake_sources = sorted(
torch_mobile_core,
)
libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources
libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources + libtorch_runtime_sources
libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/autograd/TraceTypeManual.cpp",
@@ -935,6 +973,8 @@ libtorch_python_core_sources = [
"torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp",
"torch/csrc/nativert/ModelRunnerPybind.cpp",
"torch/csrc/nativert/package/pt2_archive_constants_pybind.cpp",
] + lazy_tensor_core_python_sources
libtorch_python_distributed_core_sources = [

View File

@@ -0,0 +1,219 @@
# Owner(s): ["oncall: export"]
import copy
import pathlib
import tempfile
import unittest
import torch
from torch._C.nativert import (
PyExecutorType,
PyModelRunner,
PyPlacement,
PyRuntimeConfigs,
)
from torch.export.experimental.package import package_model
from torch.export.experimental.package.pt2_archive import PT2ArchiveWriter
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils import _pytree as pytree
try:
from .. import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export
test_classes = {}
def _use_real_inputs(ep):
ep = copy.copy(ep)
has_fake_tensor = False
def _to_real_tensor(t):
if isinstance(t, torch.nn.Parameter):
return torch.nn.Parameter(_to_real_tensor(t.data))
if isinstance(t, FakeTensor):
nonlocal has_fake_tensor
has_fake_tensor = True
return torch.randn(t.shape, device=t.device, requires_grad=t.requires_grad)
return t
new_example_inputs = pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.example_inputs
)
if has_fake_tensor:
ep.example_inputs = new_example_inputs
ep = ep._update(
ep.graph_module,
ep.graph_signature,
state_dict=pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.state_dict
),
constants=pytree.tree_map_only(
(torch.Tensor, torch.nn.Parameter), _to_real_tensor, ep.constants
),
)
return ep
def _is_supported_types(arg) -> bool:
if isinstance(arg, list):
return (
all(_is_supported_types(a) for a in arg)
and len({type(a) for a in arg}) <= 1
)
elif isinstance(arg, tuple):
return all(_is_supported_types(a) for a in arg)
elif isinstance(arg, dict):
return (
all(_is_supported_types(a) for a in arg.values())
and len({type(a) for a in arg.values()}) <= 1
)
elif isinstance(arg, (torch.Tensor, int, float, bool, str)):
return True
elif arg is None:
return True
else:
return False
def run_with_sigmoid(ep):
# Downstream tests might mutate the exported program in subtle ways, so
# we need to make a copy here.
ep_infer = copy.deepcopy(ep)
MODEL_NAME = "test_export"
ep_infer = _use_real_inputs(ep_infer.run_decompositions())
# TODO Does named tempfile have collision?
with tempfile.NamedTemporaryFile(delete=False) as f:
with PT2ArchiveWriter(f) as archive_writer:
package_model(
ep_infer,
MODEL_NAME,
archive_writer,
)
device = torch.device("cpu")
placement = PyPlacement(device)
runtime_configs = PyRuntimeConfigs()
filename = f.name
try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(e)
model_runner = PyModelRunner(
filename,
MODEL_NAME,
PyExecutorType.INTERPRETER,
runtime_configs,
placement,
)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) == type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree type on the C++ side, which
# cannot be tested in a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
return ep
def mocked_sigmoid_export_strict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=True)
run_with_sigmoid(ep)
return ep
def mocked_sigmoid_export_nonstrict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=False)
run_with_sigmoid(ep)
return ep
def make_dynamic_cls(cls, strict=False):
cls_prefix = "Sigmoid"
if strict:
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.CPP_RUNTIME_STRICT_SUFFIX,
mocked_sigmoid_export_strict,
xfail_prop="_expected_failure_cpp_runtime",
test_only_if_no_xfail=True,
)
else:
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.CPP_RUNTIME_NONSTRICT_SUFFIX,
mocked_sigmoid_export_nonstrict,
xfail_prop="_expected_failure_cpp_runtime_non_strict",
test_only_if_no_xfail=True,
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
tests = [
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test, strict=True)
make_dynamic_cls(test, strict=False)
del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -129,6 +129,9 @@
#endif
#endif
#include <torch/csrc/nativert/ModelRunnerPybind.h>
#include <torch/csrc/nativert/package/pt2_archive_constants_pybind.h>
#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif
@@ -2586,6 +2589,13 @@ Call this whenever a new thread is created in order to propagate values from
#ifdef USE_KINETO
torch::global_kineto_init();
#endif
{
auto py_module_runtime = py_module.def_submodule("nativert");
torch::nativert::initModelRunnerPybind(py_module_runtime);
auto py_module_constants =
py_module_runtime.def_submodule("pt2_archive_constants");
torch::nativert::initPt2ArchiveConstantsPybind(py_module_constants);
}
return module;
END_HANDLE_TH_ERRORS
}

View File

@@ -0,0 +1,113 @@
#include "torch/csrc/nativert/ModelRunner.h"
#include <nlohmann/json.hpp>
#include <caffe2/serialize/file_adapter.h>
#include "torch/csrc/nativert/graph/GraphPasses.h"
#include "torch/csrc/nativert/graph/Serialization.h"
namespace torch::nativert::core {
ModelRunner::ModelRunner(
const std::string& packagePath,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement)
: ModelRunner(
std::make_unique<caffe2::serialize::FileAdapter>(packagePath),
modelName,
executorType,
runtimeConfigs,
placement) {}
ModelRunner::ModelRunner(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement)
: ModelRunner(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(rai),
modelName,
executorType,
runtimeConfigs,
placement) {}
ModelRunner::ModelRunner(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement)
: ModelRunner(
std::move(pytorchStreamReader),
modelName,
executorType,
runtimeConfigs,
[=](const torch::nativert::Graph&) { return placement; }) {}
ModelRunner::ModelRunner(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const std::function<Placement(const torch::nativert::Graph& graph)>&
buildPlacementFn)
: ModelRunnerBase(
pytorchStreamReader,
modelName,
executorType,
runtimeConfigs,
buildPlacementFn) {
std::string modelSerialized = loadSerializedModel(pytorchStreamReader);
model_ = nlohmann::json::parse(modelSerialized)
.template get<torch::_export::Model>();
exportedProgram_ = model_.get_program().get_methods().at("forward");
for (const auto& _ : model_.get_delegates()) {
(void)_;
TORCH_CHECK(false, "Delegates are not supported yet");
// TODO delegates_.emplace(name, delegate.get_methods().at("forward"));
}
stateDictPath_ = model_.get_tensorPaths();
constantPaths_ = model_.get_constantPaths();
TORCH_CHECK_EQ(
exportedProgram_.get_graph_module().get_module_call_graph()[0].get_fqn(),
"");
inputSpec_ = treeSpecLoads(exportedProgram_.get_graph_module()
.get_module_call_graph()[0]
.get_signature()
.value()
.get_in_spec());
outputSpec_ = treeSpecLoads(exportedProgram_.get_graph_module()
.get_module_call_graph()[0]
.get_signature()
.value()
.get_out_spec());
graph_ = jsonToGraph(
model_.get_program().get_methods().at("forward").get_graph_module());
VLOG(1) << "Graph: \n" << *graph_;
placement_ = buildPlacementFn(*graph_);
LOG(INFO) << "Placement: " << placement_;
graph_->applyDevicePlacement(placement_);
selectScalarOverload(graph_.get());
loadNewWeights(pytorchStreamReader);
if (!runtimeConfigs.deferInitialization) {
initialize(pytorchStreamReader);
}
}
std::unique_ptr<Graph> ModelRunner::deserializeDelegateGraph() const {
return {}; // TODO
}
} // namespace torch::nativert::core

View File

@@ -0,0 +1,76 @@
#pragma once
#include "torch/csrc/nativert/executor/ModelRunnerBase.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert::core {
class TORCH_API ModelRunner : public ModelRunnerBase {
public:
ModelRunner(
const std::string& packagePath,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement = Placement());
ModelRunner(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement = Placement());
ModelRunner(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const Placement& placement = Placement());
ModelRunner(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
// functor to build the placement after the graph is loaded, but before
// loading the weights.
const std::function<Placement(const torch::nativert::Graph& graph)>&
buildPlacementFn);
ModelRunner(ModelRunner&&) = default;
ModelRunner& operator=(ModelRunner&&) = default;
ModelRunner(const ModelRunner&) = delete;
ModelRunner& operator=(const ModelRunner&) = delete;
~ModelRunner() override = default;
std::vector<std::string> availableDelegates() const {
std::vector<std::string> delegateNames;
delegateNames.reserve(delegates_.size());
for (const auto& [name, _] : delegates_) {
delegateNames.push_back(name);
}
return delegateNames;
}
template <typename T>
std::vector<T*> getDelegates() {
std::vector<T*> delegates;
for (const auto& delegate : executor_->getDelegates()) {
if (auto* d = dynamic_cast<T*>(delegate)) {
delegates.push_back(d);
}
}
return delegates;
}
private:
std::unique_ptr<Graph> deserializeDelegateGraph() const override;
torch::_export::Model model_;
torch::_export::ExportedProgram exportedProgram_;
std::unordered_map<std::string, torch::_export::ExportedProgram> delegates_;
};
} // namespace torch::nativert::core
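
A construction sketch using the placement-functor overload declared above, which runs after the graph is loaded but before the weights are loaded (illustrative only, not part of this diff; the package path, model name, and CPU-only placement are hypothetical, and the reader setup mirrors the pybind glue later in this change):

// Hypothetical usage sketch; not part of the change above.
#include <memory>
#include <string>
#include <c10/core/Device.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include "torch/csrc/nativert/ModelRunner.h"

std::unique_ptr<torch::nativert::core::ModelRunner> makeCpuRunner(
    const std::string& packagePath) {
  auto reader = std::make_shared<caffe2::serialize::PyTorchStreamReader>(
      std::make_unique<caffe2::serialize::FileAdapter>(packagePath));
  torch::nativert::BaseRuntimeConfigs configs; // default runtime configs
  return std::make_unique<torch::nativert::core::ModelRunner>(
      std::move(reader),
      /*modelName=*/"model", // hypothetical name inside the archive
      torch::nativert::ExecutorType::INTERPRETER,
      configs,
      // Map everything onto CPU, regardless of the serialized devices.
      [](const torch::nativert::Graph& /*graph*/) {
        return torch::nativert::Placement(c10::Device(c10::kCPU));
      });
}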

View File

@@ -0,0 +1,148 @@
#include <unordered_map>
#include <pybind11/pybind11.h>
#include <caffe2/serialize/file_adapter.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h> // @manual=//caffe2:torch-cpp-cpu
#include "c10/core/Device.h"
#include "torch/csrc/nativert/ModelRunner.h"
namespace py = pybind11;
using namespace torch::nativert;
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
namespace torch {
namespace nativert {
void initModelRunnerPybind(py::module& m) {
py::enum_<ExecutorType>(m, "PyExecutorType")
.value("INTERPRETER", ExecutorType::INTERPRETER)
.value("AOTINDUCTOR", ExecutorType::AOTINDUCTOR)
.value("MTIA", ExecutorType::MTIA)
.export_values();
py::class_<Placement>(m, "PyPlacement")
.def(py::init<>())
.def(py::init<std::optional<c10::Device>>(), py::arg("defaultDevice"))
.def(
py::init<
const std::unordered_map<c10::Device, c10::Device>&,
std::optional<c10::Device>>(),
py::arg("deviceMap"),
py::arg("defaultDevice") = std::nullopt)
.def("get_mapped_device", &Placement::getMappedDevice);
py::class_<BaseRuntimeConfigs>(m, "PyRuntimeConfigs")
.def(py::init<>())
.def_readwrite("isDebug", &BaseRuntimeConfigs::isDebug)
.def_readwrite("validateInputs", &BaseRuntimeConfigs::validateInputs)
.def_readwrite(
"enableStaticCPUKernels", &BaseRuntimeConfigs::enableStaticCPUKernels)
.def_readwrite(
"deferInitialization", &BaseRuntimeConfigs::deferInitialization)
.def_readwrite("platformArch", &BaseRuntimeConfigs::platformArch)
.def_readwrite(
"maxNumConcurrentThreads",
&BaseRuntimeConfigs::maxNumConcurrentThreads)
.def_readwrite("maxParallelOps", &BaseRuntimeConfigs::maxParallelOps);
shared_ptr_class_<core::ModelRunner>(m, "PyModelRunner")
.def(
py::init<
const std::string&,
const std::string&,
ExecutorType,
const BaseRuntimeConfigs&,
const Placement&>(),
py::arg("packagePath"),
py::arg("modelName"),
py::arg("executorType"),
py::arg("runtimeConfigs"),
py::arg("placement") = Placement())
.def(py::init<
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>,
const std::string&,
ExecutorType,
const BaseRuntimeConfigs&,
std::function<Placement(const Graph& graph)>&>())
.def(
"run",
[](core::ModelRunner& self,
py::args pyargs,
const py::kwargs& pykwargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
std::unordered_map<std::string, c10::IValue> kwargs;
for (const auto& [key, pyarg] : pykwargs) {
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
kwargs[py::str(key)] = std::move(ivalue);
}
c10::IValue ret = self.run(args, kwargs);
return torch::jit::createPyObjectForStack({ret});
})
.def(
"__call__",
[](core::ModelRunner& self,
py::args pyargs,
const py::kwargs& pykwargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
std::unordered_map<std::string, c10::IValue> kwargs;
for (const auto& [key, pyarg] : pykwargs) {
auto ivalue = torch::jit::toIValue(pyarg, c10::AnyType::get());
kwargs[py::str(key)] = std::move(ivalue);
}
c10::IValue ret = self.run(args, kwargs);
return torch::jit::createPyObjectForStack({ret});
})
.def(
"load_sample_inputs",
[](core::ModelRunner& self,
const std::string& packagePath,
const Placement& placement = Placement()) {
auto reader =
std::make_shared<caffe2::serialize::PyTorchStreamReader>(
std::make_unique<caffe2::serialize::FileAdapter>(
packagePath));
const auto [args, kwargs] = self.loadSampleInputs(reader);
const auto val = argsToIValue(args, kwargs);
return torch::jit::createPyObjectForStack({val});
})
.def(
"run_with_flat_inputs_and_outputs",
[](core::ModelRunner& self, py::args pyargs) {
std::vector<c10::IValue> args;
for (const auto i : c10::irange(pyargs.size())) {
auto ivalue =
torch::jit::toIValue(pyargs[i], c10::AnyType::get());
args.push_back(std::move(ivalue));
}
auto rets = self.runWithFlatInputsAndOutputs(std::move(args));
return torch::jit::createPyObjectForStack(std::move(rets));
});
}
} // namespace nativert
} // namespace torch
// TODO Remove this once we fully migrate to OSS build.
#ifdef FBCODE_CAFFE2
PYBIND11_MODULE(model_runner_pybind, m) {
initModelRunnerPybind(m);
}
#endif

View File

@@ -0,0 +1,11 @@
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace torch {
namespace nativert {
void initModelRunnerPybind(pybind11::module& m);
} // namespace nativert
} // namespace torch

View File

@@ -0,0 +1,125 @@
/*
* Ported from folly/logging/AutoTimer.h
*/
#pragma once
#include <chrono>
#include <optional>
#include <string>
#include <string_view>
#include <c10/util/Logging.h>
#include <fmt/format.h>
namespace torch::nativert {
// Default logger
enum class GoogleLoggerStyle { SECONDS, MILLISECONDS };
template <GoogleLoggerStyle>
struct GoogleLogger;
/**
* Automatically times a block of code, printing a specified log message on
* destruction or whenever the log() method is called. For example:
*
* AutoTimer t("Foo() completed");
* doWork();
* t.log("Do work finished");
* doMoreWork();
*
* This would print something like:
* "Do work finished in 1.2 seconds"
* "Foo() completed in 4.3 seconds"
*
* You can customize what you use as the logger and clock. The logger needs
* to have an operator()(std::string_view, std::chrono::duration<double>) that
* gets a message and a duration. The clock needs to model Clock from
* std::chrono.
*
* The default logger logs using glog. It only logs if the message is
* non-empty, so you can also just use this class for timing, e.g.:
*
* AutoTimer t;
* doWork()
* const auto how_long = t.log();
*/
template <
class Logger = GoogleLogger<GoogleLoggerStyle::MILLISECONDS>,
class Clock = std::chrono::high_resolution_clock>
class AutoTimer final {
public:
using DoubleSeconds = std::chrono::duration<double>;
explicit AutoTimer(
std::string&& msg = "",
const DoubleSeconds& minTimetoLog = DoubleSeconds::zero(),
Logger&& logger = Logger())
: destructionMessage_(std::move(msg)),
minTimeToLog_(minTimetoLog),
logger_(std::move(logger)) {}
// It doesn't really make sense to copy AutoTimer
// Movable to make sure the helper method for creating an AutoTimer works.
AutoTimer(const AutoTimer&) = delete;
AutoTimer(AutoTimer&&) = default;
AutoTimer& operator=(const AutoTimer&) = delete;
AutoTimer& operator=(AutoTimer&&) = default;
~AutoTimer() {
if (destructionMessage_) {
log(destructionMessage_.value());
}
}
DoubleSeconds log(std::string_view msg = "") {
return logImpl(Clock::now(), msg);
}
template <typename... Args>
DoubleSeconds logFormat(fmt::format_string<Args...> fmt, Args&&... args) {
auto now = Clock::now();
return logImpl(now, fmt::format(fmt, std::forward<Args>(args)...));
}
private:
// We take in the current time so that we don't measure time to call
// to<std::string> or format() in the duration.
DoubleSeconds logImpl(
std::chrono::time_point<Clock> now,
std::string_view msg) {
auto duration = now - start_;
if (duration >= minTimeToLog_) {
logger_(msg, duration);
}
start_ = Clock::now(); // Don't measure logging time
return duration;
}
std::optional<std::string> destructionMessage_;
std::chrono::time_point<Clock> start_ = Clock::now();
DoubleSeconds minTimeToLog_;
Logger logger_;
};
template <GoogleLoggerStyle Style>
struct GoogleLogger final {
void operator()(
std::string_view msg,
const std::chrono::duration<double>& sec) const {
if (msg.empty()) {
return;
}
if (Style == GoogleLoggerStyle::SECONDS) {
LOG(INFO) << msg << " in " << sec.count() << " seconds";
} else if (Style == GoogleLoggerStyle::MILLISECONDS) {
LOG(INFO) << msg << " in "
<< std::chrono::duration_cast<
std::chrono::duration<double, std::milli>>(sec)
.count()
<< " milliseconds";
}
}
};
} // namespace torch::nativert
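
A minimal usage sketch of AutoTimer, matching the behavior described in the comment above (illustrative only, not part of this diff; the function, log messages, and include path are assumptions):

// Hypothetical usage sketch; not part of the change above.
#include "torch/csrc/nativert/common/AutoTimer.h"

void loadAndWarmUp() {
  torch::nativert::AutoTimer<> timer("loadAndWarmUp() completed");
  // ... load weights ...
  timer.log("weights loaded"); // logs "weights loaded in N milliseconds"
  // ... run a warm-up iteration ...
} // destructor logs "loadAndWarmUp() completed in N milliseconds"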

View File

@@ -0,0 +1,13 @@
#include "torch/csrc/nativert/common/ConfigUtils.h"
#include <string.h>
namespace torch::nativert {
std::optional<std::string> maybeGetEnv(std::string_view envVar) {
const char* env = getenv(envVar.data());
if (env == nullptr || strlen(env) == 0) {
return std::nullopt;
}
return std::string(env);
}
} // namespace torch::nativert

View File

@@ -0,0 +1,7 @@
#pragma once
#include <optional>
#include <string>
namespace torch::nativert {
std::optional<std::string> maybeGetEnv(std::string_view envVar);
}

View File

@@ -0,0 +1,47 @@
#include "torch/csrc/nativert/common/Conv.h"
#include <charconv>
#include <string>
namespace torch::nativert {
template <>
std::optional<int64_t> tryTo<int64_t>(std::string_view symbol) {
int64_t value;
auto [ptr, ec] =
std::from_chars(symbol.data(), symbol.data() + symbol.size(), value);
if (ec != std::errc()) {
return std::nullopt;
}
if (ptr != symbol.data() + symbol.size()) {
return std::nullopt;
}
return value;
}
template <>
std::optional<double> tryTo<double>(std::string_view symbol) {
double value;
#ifdef __APPLE__
char extra; // to detect any extra characters after the number
// Try to parse the string using sscanf
auto str = std::string{symbol};
if (sscanf(str.c_str(), "%lf %c", &value, &extra) != 1) {
// If sscanf returns anything other than 1, it means parsing failed or there
// were extra characters
return std::nullopt;
}
#else
auto [ptr, ec] =
std::from_chars(symbol.data(), symbol.data() + symbol.size(), value);
if (ec != std::errc()) {
return std::nullopt;
}
if (ptr != symbol.data() + symbol.size()) {
return std::nullopt;
}
#endif
return value;
}
} // namespace torch::nativert

View File

@@ -0,0 +1,27 @@
#pragma once
#include <optional>
#include <string_view>
namespace torch::nativert {
template <typename T>
std::optional<T> tryTo(std::string_view symbol) = delete;
/*
* Convert a string to an integer. Prefixes like "0x" and trailing whitespace
* are not supported. Similarly, an integer string with trailing characters,
* such as "123abc", will be rejected.
*/
template <>
std::optional<int64_t> tryTo<int64_t>(std::string_view symbol);
/*
* Convert a string to a double. Prefixes like "0x" and trailing whitespace
* are not supported. Similarly, a numeric string with trailing characters,
* such as "123abc", will be rejected.
*/
template <>
std::optional<double> tryTo<double>(std::string_view symbol);
} // namespace torch::nativert
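
A small usage sketch of the tryTo<> specializations declared above (illustrative only, not part of this diff; the include path is assumed):

// Hypothetical usage sketch; not part of the change above.
#include <cstdint>
#include <iostream>
#include "torch/csrc/nativert/common/Conv.h"

void parseExamples() {
  using torch::nativert::tryTo;
  std::optional<int64_t> n = tryTo<int64_t>("123");      // 123
  std::optional<int64_t> bad = tryTo<int64_t>("123abc"); // nullopt: trailing characters
  std::optional<double> d = tryTo<double>("2.5");        // 2.5
  if (n && d && !bad) {
    std::cout << *n + *d << std::endl; // prints 125.5
  }
}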

View File

@@ -0,0 +1,153 @@
/*
* Ported from folly/container/Enumerate.h
*/
#pragma once
#include <iterator>
#include <memory>
#include "c10/macros/Macros.h"
/**
* Similar to Python's enumerate(), enumerate() can be used to
* iterate a range with a for-range loop, and it also allows retrieving
* the count of iterations so far. Can be used in constexpr
* context.
*
* For example:
*
* for (auto&& [index, element] : enumerate(vec)) {
* // index is a const reference to a size_t containing the iteration count.
* // element is a reference to the type contained within vec, mutable
* // unless vec is const.
* }
*
* If the binding is const, the element reference is too.
*
* for (const auto&& [index, element] : enumerate(vec)) {
* // element is always a const reference.
* }
*
* It can also be used as follows:
*
* for (auto&& it : enumerate(vec)) {
* // *it is a reference to the current element. Mutable unless vec is const.
* // it->member can be used as well.
* // it.index contains the iteration count.
* }
*
* As before, const auto&& it can also be used.
*/
namespace torch::nativert {
namespace detail {
template <class T>
struct MakeConst {
using type = const T;
};
template <class T>
struct MakeConst<T&> {
using type = const T&;
};
template <class T>
struct MakeConst<T*> {
using type = const T*;
};
template <class Iterator>
class Enumerator {
public:
constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {}
class Proxy {
public:
using difference_type = ssize_t;
using value_type = typename std::iterator_traits<Iterator>::value_type;
using reference = typename std::iterator_traits<Iterator>::reference;
using pointer = typename std::iterator_traits<Iterator>::pointer;
using iterator_category = std::input_iterator_tag;
C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e)
: index(e.idx_), element(*e.it_) {}
// Non-const Proxy: Forward constness from Iterator.
C10_ALWAYS_INLINE constexpr reference operator*() {
return element;
}
C10_ALWAYS_INLINE constexpr pointer operator->() {
return std::addressof(element);
}
// Const Proxy: Force const references.
C10_ALWAYS_INLINE constexpr typename MakeConst<reference>::type operator*()
const {
return element;
}
C10_ALWAYS_INLINE constexpr typename MakeConst<pointer>::type operator->()
const {
return std::addressof(element);
}
public:
const size_t index;
reference element;
};
C10_ALWAYS_INLINE constexpr Proxy operator*() const {
return Proxy(*this);
}
C10_ALWAYS_INLINE constexpr Enumerator& operator++() {
++it_;
++idx_;
return *this;
}
template <typename OtherIterator>
C10_ALWAYS_INLINE constexpr bool operator==(
const Enumerator<OtherIterator>& rhs) const {
return it_ == rhs.it_;
}
template <typename OtherIterator>
C10_ALWAYS_INLINE constexpr bool operator!=(
const Enumerator<OtherIterator>& rhs) const {
return !(it_ == rhs.it_);
}
private:
template <typename OtherIterator>
friend class Enumerator;
Iterator it_;
size_t idx_ = 0;
};
template <class Range>
class RangeEnumerator {
Range r_;
using BeginIteratorType = decltype(std::declval<Range>().begin());
using EndIteratorType = decltype(std::declval<Range>().end());
public:
constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward<Range>(r)) {}
constexpr Enumerator<BeginIteratorType> begin() {
return Enumerator<BeginIteratorType>(r_.begin());
}
constexpr Enumerator<EndIteratorType> end() {
return Enumerator<EndIteratorType>(r_.end());
}
};
} // namespace detail
template <class Range>
constexpr detail::RangeEnumerator<Range> enumerate(Range&& r) {
return detail::RangeEnumerator<Range>(std::forward<Range>(r));
}
} // namespace torch::nativert
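
A compilable sketch of the structured-binding form described in the comment above (illustrative only, not part of this diff; the include path is assumed):

// Hypothetical usage sketch; not part of the change above.
#include <iostream>
#include <string>
#include <vector>
#include "torch/csrc/nativert/common/Enumerate.h"

void printIndexed(const std::vector<std::string>& names) {
  for (auto&& [index, name] : torch::nativert::enumerate(names)) {
    // index is the iteration count; name is a const reference into `names`.
    std::cout << index << ": " << name << "\n";
  }
}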

View File

@@ -0,0 +1,189 @@
#include "torch/csrc/nativert/common/FileUtil.h"
#include <unistd.h>
#include <cerrno>
#include <fmt/core.h>
namespace torch::nativert {
namespace {
inline void incr(ssize_t) {}
template <typename Offset>
inline void incr(ssize_t n, Offset& offset) {
offset += static_cast<Offset>(n);
}
// Wrap call to read/pread/write/pwrite(fd, buf, count, offset?) to retry on
// incomplete reads / writes. The variadic argument magic is there to support
// an additional argument (offset) for pread / pwrite; see the incr() functions
// above which do nothing if the offset is not present and increment it if it
// is.
template <class F, class... Offset>
ssize_t wrapFull(F f, int fd, void* buf, size_t count, Offset... offset) {
char* b = static_cast<char*>(buf);
ssize_t totalBytes = 0;
ssize_t r;
do {
r = f(fd, b, count, offset...);
if (r == -1) {
if (errno == EINTR) {
continue;
}
return r;
}
totalBytes += r;
b += r;
count -= r;
incr(r, offset...);
} while (r != 0 && count); // 0 means EOF
return totalBytes;
}
int filterCloseReturn(int r) {
// Ignore EINTR. On Linux, close() may only return EINTR after the file
// descriptor has been closed, so you must not retry close() on EINTR --
// in the best case, you'll get EBADF, and in the worst case, you'll end up
// closing a different file (one opened from another thread).
//
// Interestingly enough, the Single Unix Specification says that the state
// of the file descriptor is unspecified if close returns EINTR. In that
// case, the safe thing to do is also not to retry close() -- leaking a file
// descriptor is definitely better than closing the wrong file.
if (r == -1 && errno == EINTR) {
return 0;
}
return r;
}
// The following wrapX() functions are private helpers for wrapping file I/O
// against interrupts and partial op completions.
// Wrap call to f(args) in loop to retry on EINTR
template <class F, class... Args>
ssize_t wrapNoInt(F f, Args... args) {
ssize_t r;
do {
r = f(args...);
} while (r == -1 && errno == EINTR);
return r;
}
} // namespace
int openNoInt(const char* name, int flags, mode_t mode) {
// Android NDK bionic with FORTIFY has this definition:
// https://android.googlesource.com/platform/bionic/+/9349b9e51b/libc/include/bits/fortify/fcntl.h
// ```
// __BIONIC_ERROR_FUNCTION_VISIBILITY
// int open(const char* pathname, int flags, mode_t modes, ...) __overloadable
// __errorattr(__open_too_many_args_error);
// ```
// This exists to prevent calling open() with incorrect parameters.
//
// However, combined with the wrapNoInt wrapper above, template deduction fails.
// In that case, we create a custom lambda to bypass the error.
// The solution is referenced from
// https://github.com/llvm/llvm-project/commit/0a0e411204a2baa520fd73a8d69b664f98b428ba
//
auto openWrapper = [&] { return open(name, flags, mode); };
return int(wrapNoInt(openWrapper));
}
int closeNoInt(int fd) {
return filterCloseReturn(close(fd));
}
ssize_t writeFull(int fd, const void* buf, size_t count) {
return wrapFull(write, fd, const_cast<void*>(buf), count);
}
ssize_t readFull(int fd, void* buf, size_t count) {
return wrapFull(read, fd, buf, count);
}
File::File(int fd, bool ownsFd) noexcept : fd_(fd), ownsFd_(ownsFd) {
TORCH_CHECK(fd >= -1, "fd must be -1 or non-negative");
TORCH_CHECK(fd != -1 || !ownsFd, "cannot own -1");
}
File::File(std::string_view name, int flags, mode_t mode)
: fd_(::open(name.data(), flags, mode)), ownsFd_(false) {
if (fd_ == -1) {
throw std::runtime_error(fmt::format(
"open(\"{}\", {}, 0{}) failed with errno {}.",
name,
flags,
mode,
errno));
}
ownsFd_ = true;
}
File::File(File&& other) noexcept : fd_(other.fd_), ownsFd_(other.ownsFd_) {
other.release();
}
File& File::operator=(File&& other) {
closeNoThrow();
swap(other);
return *this;
}
File::~File() {
auto fd = fd_;
if (!closeNoThrow()) { // ignore most errors
TORCH_CHECK(
errno != EBADF,
"closing fd ",
fd,
", it may already ",
"have been closed. Another time, this might close the wrong FD.");
}
}
/* static */ File File::temporary() {
// make a temp file with tmpfile(), dup the fd, then return it in a File.
FILE* tmpFile = tmpfile();
if (!tmpFile) {
throw std::runtime_error("tmpfile() failed");
}
auto guard = c10::make_scope_exit([&]() { fclose(tmpFile); });
int fd = ::dup(fileno(tmpFile));
if (fd == -1) {
throw std::runtime_error("dup() failed");
}
return File(fd, true);
}
int File::release() noexcept {
int released = fd_;
fd_ = -1;
ownsFd_ = false;
return released;
}
void File::swap(File& other) noexcept {
using std::swap;
swap(fd_, other.fd_);
swap(ownsFd_, other.ownsFd_);
}
void File::close() {
if (!closeNoThrow()) {
throw std::runtime_error("close() failed");
}
}
bool File::closeNoThrow() {
int r = ownsFd_ ? ::close(fd_) : 0;
release();
return r == 0;
}
} // namespace torch::nativert

View File

@@ -0,0 +1,230 @@
#pragma once
/*
* Ported from folly/FileUtil.h
*/
#include <limits>
#include <string_view>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include "c10/util/Exception.h"
#include "c10/util/ScopeExit.h"
namespace torch::nativert {
class File {
public:
/**
* Creates an empty File object, for late initialization.
*/
constexpr File() noexcept : fd_(-1), ownsFd_(false) {}
/**
* Create a File object from an existing file descriptor.
*
* @param fd Existing file descriptor
* @param ownsFd Takes ownership of the file descriptor if ownsFd is true.
*/
explicit File(int fd, bool ownsFd = false) noexcept;
/**
* Open and create a file object. Throws on error.
* Owns the file descriptor implicitly.
*/
explicit File(
std::string_view name,
int flags = O_RDONLY,
mode_t mode = 0666);
~File();
/**
* Create and return a temporary, owned file (uses tmpfile()).
*/
static File temporary();
/**
* Return the file descriptor, or -1 if the file was closed.
*/
int fd() const {
return fd_;
}
/**
* Returns 'true' iff the file was successfully opened.
*/
explicit operator bool() const {
return fd_ != -1;
}
/**
* If we own the file descriptor, close the file and throw on error.
* Otherwise, do nothing.
*/
void close();
/**
* Closes the file (if owned). Returns true on success, false (and sets
* errno) on error.
*/
bool closeNoThrow();
/**
* Returns and releases the file descriptor; no longer owned by this File.
* Returns -1 if the File object didn't wrap a file.
*/
int release() noexcept;
/**
* Swap this File with another.
*/
void swap(File& other) noexcept;
// movable
File(File&&) noexcept;
File& operator=(File&&);
private:
// unique
File(const File&) = delete;
File& operator=(const File&) = delete;
int fd_;
bool ownsFd_;
};
/**
* Convenience wrappers around some commonly used system calls. The *NoInt
* wrappers retry on EINTR. The *Full wrappers retry on EINTR and also loop
* until all data is written. Note that *Full wrappers weaken the thread
* semantics of underlying system calls.
*/
int openNoInt(const char* name, int flags, mode_t mode = 0666);
int closeNoInt(int fd);
/**
* Similar to readFull() below, a wrapper around write() that loops until all
* data is written.
*
* Generally, the write() / pwrite() system call may always write fewer bytes
* than requested, just like read(). In certain cases (such as when writing to
* a pipe), POSIX provides stronger guarantees, but not in the general case.
* For example, Linux (even on a 64-bit platform) won't write more than 2GB in
* one write() system call.
*
* Note that writevFull and pwritevFull require iov to be non-const, unlike
* writev and pwritev. The contents of iov after these functions return
* are unspecified.
*
* These functions return -1 on error, or the total number of bytes written
* (which is always the same as the number of requested bytes) on success.
*/
ssize_t writeFull(int fd, const void* buf, size_t count);
/**
* Wrapper around read() (and pread()) that, in addition to retrying on
* EINTR, will loop until all data is read.
*
* This wrapper is only useful for blocking file descriptors (for non-blocking
* file descriptors, you have to be prepared to deal with incomplete reads
* anyway), and only exists because POSIX allows read() to return an incomplete
* read if interrupted by a signal (instead of returning -1 and setting errno
* to EINTR).
*
* Note that this wrapper weakens the thread safety of read(): the file pointer
* is shared between threads, but the system call is atomic. If multiple
* threads are reading from a file at the same time, you don't know where your
* data came from in the file, but you do know that the returned bytes were
* contiguous. You can no longer make this assumption if using readFull().
* You should probably use pread() when reading from the same file descriptor
* from multiple threads simultaneously, anyway.
*
* Note that readvFull and preadvFull require iov to be non-const, unlike
* readv and preadv. The contents of iov after these functions return
* are unspecified.
*/
[[nodiscard]] ssize_t readFull(int fd, void* buf, size_t count);
/**
* Read entire file (if num_bytes is defaulted) or no more than
* num_bytes (otherwise) into container *out. The container is assumed
* to be contiguous, with element size equal to 1, and offer size(),
* reserve(), and random access (e.g. std::vector<char>, std::string,
* fbstring).
*
* Returns: true on success or false on failure. In the latter case
* errno will be set appropriately by the failing system primitive.
*/
template <class Container>
bool readFile(
int fd,
Container& out,
size_t num_bytes = std::numeric_limits<size_t>::max()) {
static_assert(
sizeof(out[0]) == 1,
"readFile: only containers with byte-sized elements accepted");
size_t soFar = 0; // amount of bytes successfully read
auto guard = c10::make_scope_exit([&]() {
assert(out.size() >= soFar); // resize better doesn't throw
out.resize(soFar);
});
// Obtain file size:
struct stat buf;
if (fstat(fd, &buf) == -1) {
return false;
}
// Some files (notably under /proc and /sys on Linux) lie about
// their size, so treat the size advertised by fstat as advisory
// but don't rely on it. In particular, if the size is zero, we
// should attempt to read stuff. If not zero, we'll attempt to read
// one extra byte.
constexpr size_t initialAlloc = 1024 * 4;
out.resize(std::min(
buf.st_size > 0 ? (size_t(buf.st_size) + 1) : initialAlloc, num_bytes));
while (soFar < out.size()) {
const auto actual = readFull(fd, &out[soFar], out.size() - soFar);
if (actual == -1) {
return false;
}
soFar += actual;
if (soFar < out.size()) {
// File exhausted
break;
}
// Ew, allocate more memory. Use exponential growth to avoid
// quadratic behavior. Cap size to num_bytes.
out.resize(std::min(out.size() * 3 / 2, num_bytes));
}
return true;
}
/**
* Same as above, but takes in a file name instead of fd
*/
template <class Container>
bool readFile(
const char* file_name,
Container& out,
size_t num_bytes = std::numeric_limits<size_t>::max()) {
TORCH_CHECK(file_name);
const auto fd = openNoInt(file_name, O_RDONLY | O_CLOEXEC);
if (fd == -1) {
return false;
}
auto guard = c10::make_scope_exit([&]() {
// Ignore errors when closing the file
closeNoInt(fd);
});
return readFile(fd, out, num_bytes);
}
} // namespace torch::nativert
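
A usage sketch for readFile() with a std::string container (illustrative only, not part of this diff; the path is hypothetical and the include path is assumed):

// Hypothetical usage sketch; not part of the change above.
#include <cerrno>
#include <iostream>
#include <string>
#include "torch/csrc/nativert/common/FileUtil.h"

void dumpFileSize() {
  std::string contents;
  // Reads the whole file; on failure errno is set by the failing syscall.
  if (torch::nativert::readFile("/tmp/example.txt", contents)) {
    std::cout << "read " << contents.size() << " bytes\n";
  } else {
    std::cout << "readFile failed, errno=" << errno << "\n";
  }
}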

View File

@@ -0,0 +1,189 @@
#pragma once
#include <c10/util/Exception.h>
namespace torch::nativert {
template <typename T>
class IntrusiveList;
class IntrusiveListHook {
template <typename P, typename T>
friend class ListIterator;
template <typename T>
friend class IntrusiveList;
IntrusiveListHook* next_;
IntrusiveListHook* prev_;
void link_before(IntrusiveListHook* next_node) {
next_ = next_node;
prev_ = next_node->prev_;
next_node->prev_ = this;
prev_->next_ = this;
}
public:
IntrusiveListHook() : next_(this), prev_(this) {}
IntrusiveListHook(const IntrusiveListHook&) = delete;
IntrusiveListHook& operator=(const IntrusiveListHook&) = delete;
IntrusiveListHook(IntrusiveListHook&&) = delete;
IntrusiveListHook& operator=(IntrusiveListHook&&) = delete;
void unlink() {
TORCH_CHECK(is_linked());
next_->prev_ = prev_;
prev_->next_ = next_;
next_ = this;
prev_ = this;
}
~IntrusiveListHook() {
if (is_linked()) {
unlink();
}
}
bool is_linked() const {
return next_ != this;
}
};
template <typename P, typename T>
class ListIterator {
static_assert(std::is_same_v<std::remove_const_t<P>, IntrusiveListHook>);
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
P* ptr_;
friend class IntrusiveList<T>;
public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = std::conditional_t<std::is_const_v<P>, const T, T>;
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;
explicit ListIterator(P* ptr) : ptr_(ptr) {}
ListIterator(const ListIterator&) = default;
ListIterator& operator=(const ListIterator&) = default;
template <
typename Q,
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
ListIterator(const ListIterator<Q, T>& rhs) : ptr_(rhs.ptr_) {}
template <
typename Q,
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
ListIterator& operator=(const ListIterator<Q, T>& rhs) {
ptr_ = rhs.ptr_;
return *this;
}
template <typename Q>
bool operator==(const ListIterator<Q, T>& other) const {
return ptr_ == other.ptr_;
}
template <typename Q>
bool operator!=(const ListIterator<Q, T>& other) const {
return !(*this == other);
}
auto& operator*() const {
return static_cast<reference>(*ptr_);
}
ListIterator& operator++() {
TORCH_CHECK(ptr_);
ptr_ = ptr_->next_;
return *this;
}
ListIterator& operator--() {
TORCH_CHECK(ptr_);
ptr_ = ptr_->prev_;
return *this;
}
auto* operator->() const {
return static_cast<pointer>(ptr_);
}
};
template <typename T>
class IntrusiveList {
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
public:
using iterator = ListIterator<IntrusiveListHook, T>;
using const_iterator = ListIterator<const IntrusiveListHook, T>;
auto begin() const {
return ++const_iterator{&head_};
}
auto begin() {
return ++iterator{&head_};
}
auto end() const {
return const_iterator{&head_};
}
auto end() {
return iterator{&head_};
}
auto rbegin() const {
return std::reverse_iterator{end()};
}
auto rbegin() {
return std::reverse_iterator{end()};
}
auto rend() const {
return std::reverse_iterator{begin()};
}
auto rend() {
return std::reverse_iterator{begin()};
}
auto iterator_to(const T& n) const {
return const_iterator{&n};
}
auto iterator_to(T& n) {
return iterator{&n};
}
iterator insert(iterator pos, T& n) {
n.link_before(pos.ptr_);
return iterator{&n};
}
size_t size() const {
size_t ret = 0;
for ([[maybe_unused]] auto& _ : *this) {
ret++;
}
return ret;
}
~IntrusiveList() {
while (head_.is_linked()) {
head_.next_->unlink();
}
}
private:
IntrusiveListHook head_;
};
} // namespace torch::nativert
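
A usage sketch for IntrusiveList (illustrative only, not part of this diff): nodes derive from IntrusiveListHook and are owned by the caller, while the list only links them. The Item type and the include path are hypothetical.

// Hypothetical usage sketch; not part of the change above.
#include <iostream>
#include "torch/csrc/nativert/common/IntrusiveList.h"

struct Item : torch::nativert::IntrusiveListHook {
  explicit Item(int v) : value(v) {}
  int value;
};

void listExample() {
  Item a{1}, b{2}, c{3};
  torch::nativert::IntrusiveList<Item> list;
  list.insert(list.end(), a);   // [a]
  list.insert(list.end(), b);   // [a, b]
  list.insert(list.begin(), c); // [c, a, b]
  for (const auto& item : list) {
    std::cout << item.value << ' '; // prints: 3 1 2
  }
  // Hooks unlink themselves on destruction, so nodes and list may be destroyed in any order.
}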

View File

@@ -0,0 +1,48 @@
/*
* A simple multi-producer, multi-consumer queue that we rolled ourselves.
*
* This is a wrapper around std::deque that provides non-blocking
* readIfNotEmpty and writeIfNotFull (each call takes a mutex internally).
*/
#pragma once
#include <deque>
#include <mutex>
#include <type_traits>
namespace torch::nativert {
// TODO (zhxchen17) Add wrapper for concurrentqueue.
template <typename T>
class MPMCQueue {
static_assert(!std::is_reference_v<T>);
public:
explicit MPMCQueue(size_t capacity) : capacity_(capacity) {}
bool readIfNotEmpty(T& out) {
std::lock_guard<std::mutex> lock(mutex_);
if (storage_.empty()) {
return false;
}
out = std::move(storage_.front());
storage_.pop_front();
return true;
}
bool writeIfNotFull(T&& in) {
std::lock_guard<std::mutex> lock(mutex_);
if (storage_.size() == capacity_) {
return false;
}
storage_.push_back(std::move(in));
return true;
}
private:
std::mutex mutex_;
std::deque<T> storage_;
const size_t capacity_;
};
} // namespace torch::nativert
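
A usage sketch showing the non-blocking semantics of the queue above (illustrative only, not part of this diff; the include path is assumed):

// Hypothetical usage sketch; not part of the change above.
#include <iostream>
#include "torch/csrc/nativert/common/MPMCQueue.h"

void queueExample() {
  torch::nativert::MPMCQueue<int> queue(/*capacity=*/2);
  queue.writeIfNotFull(1);                    // true
  queue.writeIfNotFull(2);                    // true
  const bool wrote = queue.writeIfNotFull(3); // false: queue is at capacity
  int value = 0;
  while (queue.readIfNotEmpty(value)) {
    std::cout << value << "\n";               // prints 1 then 2
  }
  std::cout << std::boolalpha << wrote << "\n"; // false
}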

View File

@@ -0,0 +1,439 @@
#include "torch/csrc/nativert/common/Pytree.h"
#include "torch/csrc/nativert/common/RecordFunction.h"
#include <iterator>
#include <string_view>
#include <ATen/core/ivalue.h>
#include <c10/util/Synchronized.h>
#include <nlohmann/json.hpp> // @manual=fbsource//third-party/nlohmann-json:nlohmann-json
namespace torch::nativert {
namespace {
inline constexpr int kDefaultTreeSpecSerializationProtocol = 1;
c10::IValue dynamicToIValue(const nlohmann::json& obj) {
if (obj.is_string()) {
return obj.get<std::string>();
} else if (obj.is_number_integer()) {
return obj.get<int64_t>();
} else {
TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
}
}
void treeFlatten(
const c10::IValue& tree,
const TreeSpec& spec,
std::vector<c10::IValue>& leaves) {
if (spec.isLeaf()) {
leaves.push_back(tree);
return;
}
auto flattenFn = spec.nodeDefCache().flattenFn;
flattenFn(tree, spec, leaves);
}
class PytreeNodeRegistry {
public:
PytreeNodeRegistry() {
// Add some law of physics here.
registerNode(
"builtins.tuple",
NodeDef{
[](const c10::IValue& tree,
const TreeSpec& spec,
std::vector<c10::IValue>& leaves) {
const auto& tuple = tree.toTupleRef().elements();
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
for (size_t i = 0; i < tuple.size(); i++) {
treeFlatten(tuple[i], spec.children(i), leaves);
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
return c10::ivalue::Tuple::create(std::move(flats));
},
[](TreeMapNoReturnFn fn,
const c10::IValue& tree,
const TreeSpec& spec) {
const auto& tuple = tree.toTupleRef().elements();
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
for (size_t i = 0; i < tuple.size(); i++) {
leafApply(fn, tuple[i], spec.children(i));
}
}});
const auto& tupleNodeDef = getNodeDef("builtins.tuple");
registerNode(
"collections.namedtuple",
NodeDef{
tupleNodeDef.flattenFn,
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
return c10::ivalue::Tuple::create(std::move(flats));
},
tupleNodeDef.leafApplyFn,
[](std::string_view context) { return nlohmann::json{context}; }});
registerNode(
"builtins.list",
NodeDef{
[](const c10::IValue& tree,
const TreeSpec& spec,
std::vector<c10::IValue>& leaves) {
auto list = tree.toList();
for (size_t i = 0; i < list.size(); i++) {
treeFlatten(list[i], spec.children(i), leaves);
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
c10::List<c10::IValue> list(c10::AnyType::get());
list.reserve(flats.size());
for (auto& flat : flats) {
list.push_back(std::move(flat));
}
return list;
},
[](TreeMapNoReturnFn fn,
const c10::IValue& tree,
const TreeSpec& spec) {
auto list = tree.toList();
for (size_t i = 0; i < list.size(); i++) {
leafApply(fn, list[i], spec.children(i));
}
}});
registerNode(
"torch.fx.immutable_collections.immutable_list",
getNodeDef("builtins.list"));
registerNode(
"builtins.dict",
NodeDef{
[](const c10::IValue& tree,
const TreeSpec& spec,
std::vector<c10::IValue>& leaves) {
auto dict = tree.toGenericDict();
const auto& context = spec.context();
TORCH_CHECK_EQ(dict.size(), context.size());
size_t i = 0;
for (const auto& keyObj : context) {
auto key = dynamicToIValue(keyObj);
auto it = dict.find(key);
if (it != dict.end()) {
treeFlatten(it->value(), spec.children(i), leaves);
} else {
// when we have a dict with missing keys, we fill the missing
// leaves with c10::IValue()
for (size_t j = 0; j < spec.children(i).numLeaves(); ++j) {
leaves.emplace_back();
}
}
i++;
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
c10::Dict<c10::IValue, c10::IValue> dict(
c10::AnyType::get(), c10::AnyType::get());
TORCH_CHECK(obj.is_array());
TORCH_CHECK_EQ(obj.size(), flats.size());
dict.reserve(flats.size());
for (size_t i = 0; i < flats.size(); i++) {
dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
}
return dict;
},
[](TreeMapNoReturnFn fn,
const c10::IValue& tree,
const TreeSpec& spec) {
auto dict = tree.toGenericDict();
const auto& context = spec.context();
TORCH_CHECK(
dict.size() <= context.size(),
"input dict has more keys than treeSepc");
size_t i = 0;
for (const auto& keyObj : context) {
auto key = dynamicToIValue(keyObj);
auto it = dict.find(key);
if (it != dict.end()) {
leafApply(fn, it->value(), spec.children(i));
} else {
// when we have a dict with missing keys, we run fn
// on leaves with value of c10::IValue()
for (size_t j = 0; j < spec.children(i).numLeaves(); ++j) {
fn(c10::IValue());
}
}
i++;
}
}});
registerNode(
"torch.fx.immutable_collections.immutable_dict",
getNodeDef("builtins.dict"));
}
bool hasNodeDef(std::string_view typeName) const {
return registry_.find(std::string{typeName}) != registry_.end();
}
const NodeDef& getNodeDef(std::string_view typeName) const {
return registry_.at(std::string{typeName});
}
void registerNode(std::string_view typeName, NodeDef nodeDef) {
TORCH_CHECK(!hasNodeDef(typeName));
registry_.emplace(typeName, std::move(nodeDef));
}
private:
std::unordered_map<std::string, NodeDef> registry_;
};
c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
static auto* registry = new c10::Synchronized<PytreeNodeRegistry>();
return *registry;
}
TreeSpec makeTreeSpec(const nlohmann::json& obj) {
TORCH_CHECK(obj.is_object());
TORCH_CHECK(obj.find("type") != obj.end());
if (obj["type"].is_null()) {
TORCH_CHECK_EQ(obj["children_spec"].size(), 0);
TORCH_CHECK(obj["context"].is_null());
return TreeSpec{};
}
const auto& name = obj["type"].get<std::string>();
NodeDef nodeDefCache;
getPytreeNodeRegistry().withLock([&](auto& registry) {
TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
nodeDefCache = registry.getNodeDef(name);
});
auto context = nodeDefCache.contextLoadFn(obj["context"].get<std::string>());
const auto& childrenSpec = obj["children_spec"];
TORCH_CHECK(childrenSpec.is_array());
std::vector<TreeSpec> children;
for (const auto& child : childrenSpec) {
children.push_back(makeTreeSpec(child));
}
return TreeSpec(name, context, std::move(children), std::move(nodeDefCache));
}
} // namespace
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
getPytreeNodeRegistry().withLock([&](auto& registry) {
registry.registerNode(typeName, std::move(nodeDef));
});
}
TreeSpec treeSpecLoads(std::string_view json) {
const auto obj = nlohmann::json::parse(json);
TORCH_CHECK(obj.is_array());
TORCH_CHECK_EQ(obj.size(), 2);
TORCH_CHECK_EQ(obj[0].get<int64_t>(), kDefaultTreeSpecSerializationProtocol);
return makeTreeSpec(obj[1]);
}
c10::IValue treeUnflatten(
std::vector<c10::IValue> leaves,
const TreeSpec& spec) {
RecordFunction recordFunction("nativert::treeUnflatten");
TORCH_CHECK_EQ(leaves.size(), spec.numLeaves());
if (spec.isLeaf()) {
return std::move(leaves[0]);
}
auto unflattenFn = spec.nodeDefCache().unflattenFn;
if (spec.allLeaves()) {
return unflattenFn(std::move(leaves), spec.context());
}
size_t start = 0;
std::vector<c10::IValue> childrenPytrees;
for (const auto& child : spec.children()) {
if (child.isLeaf()) {
childrenPytrees.push_back(std::move(leaves[start]));
start++;
continue;
}
size_t numLeaves = child.numLeaves();
std::vector<c10::IValue> slice(
std::make_move_iterator(leaves.begin() + start),
std::make_move_iterator(leaves.begin() + start + numLeaves));
childrenPytrees.push_back(treeUnflatten(std::move(slice), child));
start += numLeaves;
}
return unflattenFn(std::move(childrenPytrees), spec.context());
}
std::vector<c10::IValue> treeFlatten(
const c10::IValue& tree,
const TreeSpec& spec) {
std::vector<c10::IValue> leaves;
leaves.reserve(spec.numLeaves());
treeFlatten(tree, spec, leaves);
return leaves;
}
std::vector<c10::IValue> treeFlattenFromArgs(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec) {
RecordFunction recordFunction("nativert::treeFlattenFromArgs");
TORCH_CHECK(!spec.isLeaf());
TORCH_CHECK_EQ(spec.children().size(), 2);
std::vector<c10::IValue> leaves;
leaves.reserve(spec.numLeaves());
const auto& specArgs = spec.children(0);
TORCH_CHECK(!specArgs.isLeaf());
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
for (size_t i = 0; i < args.size(); i++) {
treeFlatten(args[i], specArgs.children(i), leaves);
}
const auto& specKwargs = spec.children(1);
TORCH_CHECK(!specKwargs.isLeaf());
TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size());
for (size_t i = 0; i < specKwargs.context().size(); i++) {
treeFlatten(
kwargs.at(specKwargs.context()[i].get<std::string>()),
specKwargs.children(i),
leaves);
}
return leaves;
}
void leafApplyFromArgs(
TreeMapNoReturnFn fn,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec) {
RecordFunction recordFunction("nativert::leafApplyFromArgs");
TORCH_CHECK(!spec.isLeaf());
TORCH_CHECK_EQ(spec.children().size(), 2);
std::vector<c10::IValue> leaves;
leaves.reserve(spec.numLeaves());
const auto& specArgs = spec.children(0);
TORCH_CHECK(!specArgs.isLeaf());
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
for (size_t i = 0; i < args.size(); i++) {
leafApply(fn, args[i], specArgs.children(i));
}
const auto& specKwargs = spec.children(1);
TORCH_CHECK(!specKwargs.isLeaf());
TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size());
for (size_t i = 0; i < specKwargs.context().size(); i++) {
leafApply(
fn,
kwargs.at(specKwargs.context()[i].get<std::string>()),
specKwargs.children(i));
}
}
std::vector<at::Tensor> treeFlattenToTensorList(
const c10::IValue& tree,
const TreeSpec& spec) {
auto flats = treeFlatten(tree, spec);
std::vector<at::Tensor> tensors;
tensors.reserve(flats.size());
for (const auto& flat : flats) {
tensors.push_back(flat.toTensor());
}
return tensors;
}
c10::IValue
treeMap(TreeMapFn f, const c10::IValue& tree, const TreeSpec& spec) {
const auto flats = treeFlatten(tree, spec);
std::vector<c10::IValue> mapped;
mapped.reserve(flats.size());
for (const auto& flat : flats) {
mapped.push_back(f(flat));
}
return treeUnflatten(std::move(mapped), spec);
}
c10::IValue argsToIValue(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
c10::Dict<c10::IValue, c10::IValue> dict(
c10::StringType::get(), c10::AnyType::get());
for (const auto& [key, arg] : kwargs) {
dict.insert(key, arg);
}
return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
}
std::
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
treeMapArgs(
TreeMapFn f,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec) {
const auto val = argsToIValue(args, kwargs);
const auto mapVal = treeMap(f, val, spec);
auto mapArgs =
mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
std::unordered_map<std::string, c10::IValue> mapKwargs;
for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
mapKwargs.emplace(entry.key().toStringRef(), entry.value());
}
return {std::move(mapArgs), std::move(mapKwargs)};
}
void leafApply(
TreeMapNoReturnFn fn,
const c10::IValue& tree,
const TreeSpec& spec) {
if (spec.isLeaf()) {
fn(tree);
return;
}
auto leafApplyFn = spec.nodeDefCache().leafApplyFn;
leafApplyFn(fn, tree, spec);
}
nlohmann::json defaultContextLoadFn(std::string_view context) {
return nlohmann::json::parse(context);
}
c10::TypePtr TreeSpec::toAtenType() const {
if (isLeaf()) {
return c10::AnyType::get();
} else if (uniformName_ == "builtins.tuple") {
std::vector<c10::TypePtr> childrenType;
for (const auto& childrenSpec : children_) {
childrenType.emplace_back(childrenSpec.toAtenType());
}
return c10::TupleType::create(std::move(childrenType));
} else if (
uniformName_ == "builtins.list" ||
uniformName_ == "torch.fx.immutable_collections.immutable_list") {
if (children_.empty()) {
return c10::ListType::create(c10::AnyType::get());
} else {
return c10::ListType::create(children_[0].toAtenType());
}
} else if (
uniformName_ == "builtins.dict" ||
uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
if (children_.empty()) {
return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
} else {
return c10::DictType::create(
dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
}
} else {
TORCH_CHECK(false, "Unsupported uniform name: ", uniformName_.value());
}
}
} // namespace torch::nativert

View File

@@ -0,0 +1,156 @@
#pragma once
#include <optional>
#include <string_view>
#include <unordered_map>
#include <vector>
#include <ATen/core/ivalue.h>
#include <nlohmann/json.hpp>
namespace torch::nativert {
class TreeSpec;
using TreeFlattenFn =
void (*)(const c10::IValue&, const TreeSpec&, std::vector<c10::IValue>&);
using TreeUnflattenFn =
c10::IValue (*)(std::vector<c10::IValue>, const nlohmann::json&);
using ContextLoadFn = nlohmann::json (*)(std::string_view);
using TreeMapFn = c10::function_ref<c10::IValue(const c10::IValue&)>;
using TreeMapNoReturnFn = c10::function_ref<void(const c10::IValue&)>;
using LeafApplyFn =
void (*)(TreeMapNoReturnFn, const c10::IValue&, const TreeSpec&);
nlohmann::json defaultContextLoadFn(std::string_view);
struct NodeDef {
TreeFlattenFn flattenFn;
TreeUnflattenFn unflattenFn;
LeafApplyFn leafApplyFn;
ContextLoadFn contextLoadFn = defaultContextLoadFn;
};
class TreeSpec {
public:
// Leaf node.
TreeSpec() : numLeaves_(1) {}
// Non leaf node.
TreeSpec(
std::string_view uniformName,
nlohmann::json context,
std::vector<TreeSpec> children,
NodeDef nodeDefCache)
: uniformName_(uniformName),
context_(std::move(context)),
children_(std::move(children)),
nodeDefCache_(nodeDefCache),
numLeaves_(0) {
for (auto& child : children_) {
numLeaves_ += child.numLeaves();
allLeaves_ &= child.isLeaf();
}
}
bool isLeaf() const {
return !uniformName_;
}
std::string_view uniformName() const {
TORCH_CHECK(uniformName_);
return uniformName_.value();
}
const nlohmann::json& context() const {
return context_;
}
const auto& children() const {
return children_;
}
const TreeSpec& children(size_t i) const {
return children_[i];
}
const NodeDef& nodeDefCache() const {
return nodeDefCache_;
}
size_t numLeaves() const {
return numLeaves_;
}
bool allLeaves() const {
return allLeaves_;
}
c10::TypePtr toAtenType() const;
private:
// Only non leaf nodes have names.
// Examples of uniform name: "builtins.tuple", "builtins.dict".
std::optional<std::string> uniformName_;
nlohmann::json context_;
std::vector<TreeSpec> children_;
// Cached fields.
NodeDef nodeDefCache_;
size_t numLeaves_;
bool allLeaves_ = true;
};
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef);
// The serialized JSON tree spec should come directly from treespec_dumps()
// in torch.utils._pytree.
TreeSpec treeSpecLoads(std::string_view json);
c10::IValue treeUnflatten(
std::vector<c10::IValue> leaves,
const TreeSpec& spec);
std::vector<c10::IValue> treeFlatten(
const c10::IValue& tree,
const TreeSpec& spec);
std::vector<c10::IValue> treeFlattenFromArgs(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec);
std::vector<at::Tensor> treeFlattenToTensorList(
const c10::IValue& tree,
const TreeSpec& spec);
c10::IValue treeMap(TreeMapFn f, const c10::IValue& tree, const TreeSpec& spec);
c10::IValue TORCH_API argsToIValue(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);
std::
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
treeMapArgs(
TreeMapFn f,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec);
void leafApply(
TreeMapNoReturnFn f,
const c10::IValue& tree,
const TreeSpec& spec);
void leafApplyFromArgs(
TreeMapNoReturnFn fn,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& spec);
} // namespace torch::nativert
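// Illustrative round trip for the API declared above: a spec serialized on the
// Python side with torch.utils._pytree.treespec_dumps() is loaded here, the
// pytree is flattened into leaves and then reconstructed. The spec string is
// hypothetical and the include path is assumed from the source listing.
#include "torch/csrc/nativert/common/Pytree.h"

c10::IValue roundTripSketch(const std::string& specJson, const c10::IValue& tree) {
  using namespace torch::nativert;
  const TreeSpec spec = treeSpecLoads(specJson);
  std::vector<c10::IValue> leaves = treeFlatten(tree, spec);
  TORCH_CHECK(leaves.size() == spec.numLeaves());
  return treeUnflatten(std::move(leaves), spec);
}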

View File

@ -0,0 +1,32 @@
#pragma once
#include <torch/library.h> //@manual=//caffe2:libtorch
#include "torch/csrc/autograd/record_function_ops.h" //@manual=//caffe2:libtorch
namespace torch::nativert {
/**
* RAII-style wrapper that behaves similarly to torch.profiler.record_function.
*/
class RecordFunction {
public:
RecordFunction() = delete;
RecordFunction(const RecordFunction&) = default;
RecordFunction& operator=(const RecordFunction&) = default;
RecordFunction(RecordFunction&&) = default;
RecordFunction& operator=(RecordFunction&&) = default;
explicit RecordFunction(const std::string& name) {
recordFunction_ =
torch::autograd::profiler::record_function_enter_new(name);
}
~RecordFunction() {
recordFunction_->record.end();
}
private:
c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction>
recordFunction_;
};
} // namespace torch::nativert
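// Minimal usage sketch of the RAII behavior described above: the profiler
// range opens when the guard is constructed and closes when it goes out of
// scope. The kernel name is a placeholder.
#include "torch/csrc/nativert/common/RecordFunction.h"

void profiledSectionSketch() {
  torch::nativert::RecordFunction guard("nativert::example_kernel");
  // ... the actual work goes here; the range ends when `guard` is destroyed.
}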

View File

@ -0,0 +1,38 @@
#pragma once
#include <cstddef>
// To use moodycamel semaphore, we need to include the header file
// for concurrentqueue first. Hiding implementation detail here.
#ifdef BLOCK_SIZE
#pragma push_macro("BLOCK_SIZE")
#undef BLOCK_SIZE
#include "torch/csrc/nativert/common/concurrentqueue.h"
#pragma pop_macro("BLOCK_SIZE")
#else
#include "torch/csrc/nativert/common/concurrentqueue.h"
#endif
#include "torch/csrc/nativert/common/lightweightsemaphore.h"
namespace torch::nativert {
// A textbook semaphore implementation. Nothing fancy.
// In the future, we can consider using C++20's semaphore.
class Semaphore {
moodycamel::LightweightSemaphore impl_;
public:
void release() {
impl_.signal();
}
void release(size_t n) {
impl_.signal(n);
}
void acquire() {
impl_.wait();
}
};
} // namespace torch::nativert
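// Sketch of pairing a producer with a blocking consumer using the wrapper
// above; the thread setup is hypothetical and the include path assumes the
// header lives alongside the other common/ headers.
#include <thread>
#include "torch/csrc/nativert/common/Semaphore.h"

void producerConsumerSketch() {
  torch::nativert::Semaphore work;
  std::thread consumer([&work] {
    work.acquire();  // blocks until the producer signals
    // ... handle one unit of work ...
  });
  work.release();  // wake the consumer
  consumer.join();
}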

View File

@ -0,0 +1,46 @@
#include "torch/csrc/nativert/common/String.h"
#include <sstream>
namespace torch::nativert {
std::vector<std::string_view> split(std::string_view target, char delimiter) {
std::vector<std::string_view> atoms;
std::string_view buffer = target;
while (buffer.size() > 0) {
auto i = buffer.find(delimiter);
if (i == std::string_view::npos) {
atoms.push_back(buffer);
buffer.remove_prefix(buffer.size());
} else {
atoms.push_back(buffer.substr(0, i));
buffer.remove_prefix(i + 1);
}
}
return atoms;
}
std::string join(
std::string_view delimiter,
const std::vector<std::string>& keys) {
std::ostringstream result;
for (size_t i = 0; i < keys.size(); i++) {
result << keys[i];
if (i != keys.size() - 1) {
result << delimiter;
}
}
return result.str();
}
bool starts_with(std::string_view str, std::string_view prefix) {
return str.size() >= prefix.size() &&
str.compare(0, prefix.size(), prefix) == 0;
}
bool ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
} // namespace torch::nativert
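// Small sketch of the helpers above; the delimiter and inputs are arbitrary.
#include <cassert>
#include <string>
#include "torch/csrc/nativert/common/String.h"

void stringHelpersSketch() {
  using namespace torch::nativert;
  const auto parts = split("torch.ops.aten.add", '.');  // {"torch", "ops", "aten", "add"}
  assert(parts.size() == 4);
  const std::string joined =
      join("::", std::vector<std::string>{"torch", "nativert"});  // "torch::nativert"
  assert(starts_with(joined, "torch") && ends_with(joined, "nativert"));
}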

View File

@ -0,0 +1,19 @@
#pragma once
#include <string_view>
#include <vector>
namespace torch::nativert {
std::vector<std::string_view> split(std::string_view target, char delimiter);
std::string join(
std::string_view delimiter,
const std::vector<std::string>& keys);
// These helpers should be replaced by std::string_view::starts_with and
// std::string_view::ends_with once C++20 is available.
bool starts_with(std::string_view target, std::string_view prefix);
bool ends_with(std::string_view target, std::string_view suffix);
} // namespace torch::nativert

File diff suppressed because it is too large

View File

@ -0,0 +1,427 @@
// Provides an efficient implementation of a semaphore (LightweightSemaphore).
// This is an extension of Jeff Preshing's semaphore implementation (licensed
// under the terms of its separate zlib license) that has been adapted and
// extended by Cameron Desrochers.
#pragma once
#include <atomic>
#include <cstddef> // For std::size_t
#include <type_traits> // For std::make_signed<T>
#if defined(_WIN32)
// Avoid including windows.h in a header; we only need a handful of
// items, so we'll redeclare them here (this is relatively safe since
// the API generally has to remain stable between Windows versions).
// I know this is an ugly hack but it still beats polluting the global
// namespace with thousands of generic names or adding a .cpp for nothing.
extern "C" {
struct _SECURITY_ATTRIBUTES;
__declspec(dllimport) void* __stdcall CreateSemaphoreW(
_SECURITY_ATTRIBUTES* lpSemaphoreAttributes,
long lInitialCount,
long lMaximumCount,
const wchar_t* lpName);
__declspec(dllimport) int __stdcall CloseHandle(void* hObject);
__declspec(dllimport) unsigned long __stdcall WaitForSingleObject(
void* hHandle,
unsigned long dwMilliseconds);
__declspec(dllimport) int __stdcall ReleaseSemaphore(
void* hSemaphore,
long lReleaseCount,
long* lpPreviousCount);
}
#elif defined(__MACH__)
#include <mach/mach.h> // @manual
#elif defined(__MVS__)
#include <zos-semaphore.h> // @manual
#elif defined(__unix__)
#include <semaphore.h>
#if defined(__GLIBC_PREREQ) && defined(_GNU_SOURCE)
#if __GLIBC_PREREQ(2, 30)
#define MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
#endif
#endif
#endif
namespace moodycamel {
namespace details {
// Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's
// portable + lightweight semaphore implementations, originally from
// https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h
// LICENSE:
// Copyright (c) 2015 Jeff Preshing
//
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgement in the product documentation would be
// appreciated but is not required.
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
// 3. This notice may not be removed or altered from any source distribution.
#if defined(_WIN32)
class Semaphore {
private:
void* m_hSema;
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
public:
Semaphore(int initialCount = 0) {
assert(initialCount >= 0);
const long maxLong = 0x7fffffff;
m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr);
assert(m_hSema);
}
~Semaphore() {
CloseHandle(m_hSema);
}
bool wait() {
const unsigned long infinite = 0xffffffff;
return WaitForSingleObject(m_hSema, infinite) == 0;
}
bool try_wait() {
return WaitForSingleObject(m_hSema, 0) == 0;
}
bool timed_wait(std::uint64_t usecs) {
return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) == 0;
}
void signal(int count = 1) {
while (!ReleaseSemaphore(m_hSema, count, nullptr))
;
}
};
#elif defined(__MACH__)
//---------------------------------------------------------
// Semaphore (Apple iOS and OSX)
// Can't use POSIX semaphores due to
// http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html
//---------------------------------------------------------
class Semaphore {
private:
semaphore_t m_sema;
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
public:
Semaphore(int initialCount = 0) {
assert(initialCount >= 0);
kern_return_t rc = semaphore_create(
mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount);
assert(rc == KERN_SUCCESS);
(void)rc;
}
~Semaphore() {
semaphore_destroy(mach_task_self(), m_sema);
}
bool wait() {
return semaphore_wait(m_sema) == KERN_SUCCESS;
}
bool try_wait() {
return timed_wait(0);
}
bool timed_wait(std::uint64_t timeout_usecs) {
mach_timespec_t ts;
ts.tv_sec = static_cast<unsigned int>(timeout_usecs / 1000000);
ts.tv_nsec = static_cast<int>((timeout_usecs % 1000000) * 1000);
// added in OSX 10.10:
// https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html
kern_return_t rc = semaphore_timedwait(m_sema, ts);
return rc == KERN_SUCCESS;
}
void signal() {
while (semaphore_signal(m_sema) != KERN_SUCCESS)
;
}
void signal(int count) {
while (count-- > 0) {
while (semaphore_signal(m_sema) != KERN_SUCCESS)
;
}
}
};
#elif defined(__unix__) || defined(__MVS__)
//---------------------------------------------------------
// Semaphore (POSIX, Linux, zOS)
//---------------------------------------------------------
class Semaphore {
private:
sem_t m_sema;
Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION;
public:
Semaphore(int initialCount = 0) {
assert(initialCount >= 0);
int rc = sem_init(&m_sema, 0, static_cast<unsigned int>(initialCount));
assert(rc == 0);
(void)rc;
}
~Semaphore() {
sem_destroy(&m_sema);
}
bool wait() {
// http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error
int rc;
do {
rc = sem_wait(&m_sema);
} while (rc == -1 && errno == EINTR);
return rc == 0;
}
bool try_wait() {
int rc;
do {
rc = sem_trywait(&m_sema);
} while (rc == -1 && errno == EINTR);
return rc == 0;
}
bool timed_wait(std::uint64_t usecs) {
struct timespec ts;
const int usecs_in_1_sec = 1000000;
const int nsecs_in_1_sec = 1000000000;
#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
clock_gettime(CLOCK_MONOTONIC, &ts);
#else
clock_gettime(CLOCK_REALTIME, &ts);
#endif
ts.tv_sec += (time_t)(usecs / usecs_in_1_sec);
ts.tv_nsec += (long)(usecs % usecs_in_1_sec) * 1000;
// sem_timedwait bombs if you have more than 1e9 in tv_nsec
// so we have to clean things up before passing it in
if (ts.tv_nsec >= nsecs_in_1_sec) {
ts.tv_nsec -= nsecs_in_1_sec;
++ts.tv_sec;
}
int rc;
do {
#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC
rc = sem_clockwait(&m_sema, CLOCK_MONOTONIC, &ts);
#else
rc = sem_timedwait(&m_sema, &ts);
#endif
} while (rc == -1 && errno == EINTR);
return rc == 0;
}
void signal() {
while (sem_post(&m_sema) == -1)
;
}
void signal(int count) {
while (count-- > 0) {
while (sem_post(&m_sema) == -1)
;
}
}
};
#else
#error Unsupported platform! (No semaphore wrapper available)
#endif
} // end namespace details
//---------------------------------------------------------
// LightweightSemaphore
//---------------------------------------------------------
class LightweightSemaphore {
public:
typedef std::make_signed<std::size_t>::type ssize_t;
private:
std::atomic<ssize_t> m_count;
details::Semaphore m_sema;
int m_maxSpins;
bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) {
ssize_t oldCount;
int spin = m_maxSpins;
while (--spin >= 0) {
oldCount = m_count.load(std::memory_order_relaxed);
if ((oldCount > 0) &&
m_count.compare_exchange_strong(
oldCount,
oldCount - 1,
std::memory_order_acquire,
std::memory_order_relaxed))
return true;
std::atomic_signal_fence(
std::memory_order_acquire); // Prevent the compiler from collapsing
// the loop.
}
oldCount = m_count.fetch_sub(1, std::memory_order_acquire);
if (oldCount > 0)
return true;
if (timeout_usecs < 0) {
if (m_sema.wait())
return true;
}
if (timeout_usecs > 0 && m_sema.timed_wait((std::uint64_t)timeout_usecs))
return true;
// At this point, we've timed out waiting for the semaphore, but the
// count is still decremented indicating we may still be waiting on
// it. So we have to re-adjust the count, but only if the semaphore
// wasn't signaled enough times for us too since then. If it was, we
// need to release the semaphore too.
while (true) {
oldCount = m_count.load(std::memory_order_acquire);
if (oldCount >= 0 && m_sema.try_wait())
return true;
if (oldCount < 0 &&
m_count.compare_exchange_strong(
oldCount,
oldCount + 1,
std::memory_order_relaxed,
std::memory_order_relaxed))
return false;
}
}
ssize_t waitManyWithPartialSpinning(
ssize_t max,
std::int64_t timeout_usecs = -1) {
assert(max > 0);
ssize_t oldCount;
int spin = m_maxSpins;
while (--spin >= 0) {
oldCount = m_count.load(std::memory_order_relaxed);
if (oldCount > 0) {
ssize_t newCount = oldCount > max ? oldCount - max : 0;
if (m_count.compare_exchange_strong(
oldCount,
newCount,
std::memory_order_acquire,
std::memory_order_relaxed))
return oldCount - newCount;
}
std::atomic_signal_fence(std::memory_order_acquire);
}
oldCount = m_count.fetch_sub(1, std::memory_order_acquire);
if (oldCount <= 0) {
if ((timeout_usecs == 0) || (timeout_usecs < 0 && !m_sema.wait()) ||
(timeout_usecs > 0 &&
!m_sema.timed_wait((std::uint64_t)timeout_usecs))) {
while (true) {
oldCount = m_count.load(std::memory_order_acquire);
if (oldCount >= 0 && m_sema.try_wait())
break;
if (oldCount < 0 &&
m_count.compare_exchange_strong(
oldCount,
oldCount + 1,
std::memory_order_relaxed,
std::memory_order_relaxed))
return 0;
}
}
}
if (max > 1)
return 1 + tryWaitMany(max - 1);
return 1;
}
public:
LightweightSemaphore(ssize_t initialCount = 0, int maxSpins = 10000)
: m_count(initialCount), m_maxSpins(maxSpins) {
assert(initialCount >= 0);
assert(maxSpins >= 0);
}
bool tryWait() {
ssize_t oldCount = m_count.load(std::memory_order_relaxed);
while (oldCount > 0) {
if (m_count.compare_exchange_weak(
oldCount,
oldCount - 1,
std::memory_order_acquire,
std::memory_order_relaxed))
return true;
}
return false;
}
bool wait() {
return tryWait() || waitWithPartialSpinning();
}
bool wait(std::int64_t timeout_usecs) {
return tryWait() || waitWithPartialSpinning(timeout_usecs);
}
// Acquires between 0 and (greedily) max, inclusive
ssize_t tryWaitMany(ssize_t max) {
assert(max >= 0);
ssize_t oldCount = m_count.load(std::memory_order_relaxed);
while (oldCount > 0) {
ssize_t newCount = oldCount > max ? oldCount - max : 0;
if (m_count.compare_exchange_weak(
oldCount,
newCount,
std::memory_order_acquire,
std::memory_order_relaxed))
return oldCount - newCount;
}
return 0;
}
// Acquires at least one, and (greedily) at most max
ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) {
assert(max >= 0);
ssize_t result = tryWaitMany(max);
if (result == 0 && max > 0)
result = waitManyWithPartialSpinning(max, timeout_usecs);
return result;
}
ssize_t waitMany(ssize_t max) {
ssize_t result = waitMany(max, -1);
assert(result > 0);
return result;
}
void signal(ssize_t count = 1) {
assert(count >= 0);
ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release);
ssize_t toRelease = -oldCount < count ? -oldCount : count;
if (toRelease > 0) {
m_sema.signal((int)toRelease);
}
}
std::size_t availableApprox() const {
ssize_t count = m_count.load(std::memory_order_relaxed);
return count > 0 ? static_cast<std::size_t>(count) : 0;
}
};
} // end namespace moodycamel
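// Sketch of the fast path described above, assuming the header above is
// included: tryWait spins on the atomic count and the kernel-level semaphore
// is only touched when a caller must block. The counts are illustrative.
#include <cassert>

void lightweightSemaphoreSketch() {
  moodycamel::LightweightSemaphore sem(/*initialCount=*/2);
  const bool gotOne = sem.tryWait();             // fast path, no kernel call
  const auto gotMany = sem.waitMany(/*max=*/4);  // greedily takes what is left
  sem.signal(3);                                 // hand three units back
  assert(gotOne && gotMany == 1);
  assert(sem.availableApprox() >= 3);
}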

View File

@ -0,0 +1,140 @@
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
#include "torch/csrc/nativert/common/RecordFunction.h"
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/executor/Weights.h"
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
namespace {
template <typename T>
std::optional<at::ScalarType> parsePrecision(
const std::optional<T>& precision) {
if (precision) {
return static_cast<at::ScalarType>(*precision);
}
return std::nullopt;
}
} // namespace
AOTIDelegateExecutor::AOTIDelegateExecutor(
const std::string& path,
std::shared_ptr<Weights> weights,
c10::Device device,
const ExecutorConfig& executorConfig,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader) {
std::string aotInductorModelFileName = path + "/aotinductor_pickle_data.json";
LOG(INFO) << "Loading aotinductor model from archive path: "
<< aotInductorModelFileName;
CHECK(packageReader) << "Package reader cannot be null for lowered modules";
CHECK(packageReader->hasRecord(aotInductorModelFileName))
<< "Missing record " << aotInductorModelFileName;
const auto& [aotInductorModelData, aotInductorModelSize] =
packageReader->getRecord(aotInductorModelFileName);
const std::string aotInductorModelSerialized{
reinterpret_cast<char*>(aotInductorModelData.get()),
aotInductorModelSize};
LOG(INFO) << "Loaded aot_inductor_model: " << aotInductorModelSerialized;
auto aotInductorModel =
nlohmann::json::parse(aotInductorModelSerialized)
.template get<torch::_export::AOTInductorModelPickleData>();
std::string tmpDir = extractToTemporaryFolder(packageReader, path);
LOG(INFO) << "Extracted aot_inductor model to: " << tmpDir;
std::string modelName = aotInductorModel.get_library_basename();
std::string modelPath = tmpDir + "/" + modelName;
std::string externKernelNodesPath =
tmpDir + "/" + modelName.substr(0, modelName.size() - 3) + ".json";
LOG(INFO) << "Creating AOTInductorModelImpl with device " << device.str();
// We have to read the custom_objs_config.json,
// because the weights->customObjs_ keys are not the same as the arg names
// in the externKernelNodesPath.
std::string customObjsJsonPath = tmpDir + "/custom_objs_config.json";
std::ifstream customObjsJsonFile(customObjsJsonPath);
std::unordered_map<std::string, c10::IValue> custom_objs;
if (!customObjsJsonFile.is_open()) {
// BC-compatible with old files that don't have custom_objs_config.json
LOG(INFO) << "Unable to open file " + customObjsJsonPath;
} else {
LOG(INFO) << "Load custom object mapping from: " << customObjsJsonPath;
nlohmann::json customObjsJson;
customObjsJsonFile >> customObjsJson;
// Populate custom_objs, keyed by the custom object names from the json
// file, with the corresponding c10::IValues from the weights.
for (auto& [customObjName, file_name] : customObjsJson.items()) {
custom_objs[customObjName] = weights->getCustomObjByFileName(
std::string(archive_spec::CONSTANTS_DIR) +
file_name.get<std::string>());
LOG(INFO) << "Copy custom object to FbProxyExecutor: " << customObjName
<< " from " << file_name;
}
}
aotInductorModelImpl_ =
std::make_unique<torch::aot_inductor::AOTInductorModelImpl>(
modelPath,
tmpDir,
aotInductorModel.get_input_names(),
aotInductorModel.get_output_names(),
parsePrecision(aotInductorModel.get_floating_point_input_dtype()),
parsePrecision(aotInductorModel.get_floating_point_output_dtype()),
externKernelNodesPath,
device.str(),
/*num_runtimes*/ executorConfig.maxNumConcurrentThreads,
/*custom_objs*/ std::move(custom_objs));
auto constantInfos = aotInductorModelImpl_->getConstantInfos();
for (const auto& [name, constantInfo] : constantInfos) {
if (weights->contains(constantInfo.originalFqn)) {
weightsNameMap_[constantInfo.originalFqn] = name;
} else {
LOG(WARNING)
<< "AOTI's Constant " << constantInfo.originalFqn
<< " is not found in weights, it's likely a constant created by AOTI constant folding. "
<< "Valid weight FQNs are " << weights->toString();
}
}
// AOTI's DelegateExecutor doesn't need to call processWeights or
// commitWeights here because they are already invoked from Executor's ctor.
}
void AOTIDelegateExecutor::processWeights(std::shared_ptr<Weights> weights) {
LOG(INFO) << "AOTIDelegateExecutor processing weights";
std::unordered_map<std::string, torch::Tensor*> newWeights;
for (const auto& [original_fqn, name] : weightsNameMap_) {
newWeights.emplace(name, &weights->at(original_fqn));
}
aotInductorModelImpl_->updateInactiveConstantBuffer(std::move(newWeights));
aotInductorModelImpl_->runConstantFolding(/*use_inactive*/ true);
}
void AOTIDelegateExecutor::commitWeights() {
LOG(INFO) << "AOTIDelegateExecutor committing weights";
aotInductorModelImpl_->swapConstantBuffers();
}
std::vector<at::Tensor> AOTIDelegateExecutor::run(
std::vector<at::Tensor>& inputs) {
RecordFunction func("nativert::AOTIDelegateExecutor::run");
std::vector<at::Tensor> outputs = aotInductorModelImpl_->forward(inputs);
return outputs;
}
} // namespace torch::nativert

View File

@ -0,0 +1,37 @@
#pragma once
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
#include "torch/csrc/nativert/executor/AOTInductorModelImpl.h" // @manual=//sigmoid/core/executor:aoti_model_impl
namespace torch::nativert {
class Weights;
class Node;
class AOTIDelegateExecutor : public DelegateExecutor {
public:
explicit AOTIDelegateExecutor(
const std::string& path,
std::shared_ptr<Weights> weights,
c10::Device device,
const ExecutorConfig& executorConfig,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader);
~AOTIDelegateExecutor() override {}
void processWeights(std::shared_ptr<Weights> weights) override;
void commitWeights() override;
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) override;
private:
std::unique_ptr<torch::aot_inductor::AOTInductorModelImpl>
aotInductorModelImpl_;
// key is weight's original fqn, value is weight's name in AOTI
std::unordered_map<std::string, std::string> weightsNameMap_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,605 @@
#include "torch/csrc/nativert/executor/AOTInductorModelImpl.h" // @manual
// TODO Always use OSS proxy executor.
#ifdef FBCODE_CAFFE2
#include "deeplearning/aot_inductor/fb/FbProxyExecutor.h"
#else
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h> // @manual
#endif
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h> // @manual
#include <type_traits>
#include <dlfcn.h>
#include <cstdlib> // for getenv
#include <filesystem>
#include <sstream>
#include <string>
#include <vector>
#include "ATen/Context.h" // @manual
#if defined(__SIGRID_USE_GPU__)
#include "ATen/cuda/CUDAContext.h" // @manual
#include "c10/cuda/CUDAStream.h" // @manual
#endif // __SIGRID_USE_GPU__
#include "torch/csrc/nativert/common/FileUtil.h"
namespace torch::aot_inductor {
namespace {
template <typename T>
struct GetLastArgType;
template <typename T>
struct tag {
using type = T;
};
template <typename Function, typename... Args>
struct GetLastArgType<Function(Args...)> {
using last_arg_type = typename decltype((tag<Args>{}, ...))::type;
};
template <typename T>
struct AOTInductorCallImpl;
template <typename... Args>
struct AOTInductorCallImpl<
AOTIRuntimeError(AOTInductorModelContainerHandle*, Args...)> {
// Special version for ModelContainer creation
void operator()(
AOTIRuntimeError (*f)(AOTInductorModelContainerHandle*, Args...),
AOTInductorModelContainerHandle* handle,
Args... args) {
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args...));
}
};
template <typename... Args>
struct AOTInductorCallImpl<
AOTIRuntimeError(AOTInductorModelContainerHandle, Args...)> {
using Function = AOTIRuntimeError(AOTInductorModelContainerHandle, Args...);
template <typename... ArgsWithoutLastArgument>
auto operator()(
Function* f,
AOTInductorModelContainerHandle handle,
ArgsWithoutLastArgument... args) {
std::remove_pointer_t<typename GetLastArgType<Function>::last_arg_type>
result;
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args..., &result));
return result;
}
void operator()(
AOTIRuntimeError (*f)(AOTInductorModelContainerHandle, Args...),
AOTInductorModelContainerHandle handle,
Args... args) {
AOTI_RUNTIME_ERROR_CODE_CHECK(f(handle, args...));
}
};
template <typename Function, typename... Args>
auto AOTInductorCall(
Function* f,
AOTInductorModelContainerHandle handle,
Args... args) {
return AOTInductorCallImpl<Function>()(f, handle, args...);
}
template <typename Function>
auto AOTInductorCallCreate(
Function* f,
AOTInductorModelContainerHandle* handle,
size_t num_runtimes,
bool is_cpu,
const char* cubin_dir) {
return AOTInductorCallImpl<Function>()(
f, handle, num_runtimes, is_cpu, cubin_dir);
}
template <typename Function>
auto AOTInductorCallCreateWithDevice(
Function* f,
AOTInductorModelContainerHandle* handle,
size_t num_runtimes,
const char* device_str,
const char* cubin_dir) {
return AOTInductorCallImpl<Function>()(
f, handle, num_runtimes, device_str, cubin_dir);
}
std::string getFileBasename(const std::string& filename) {
const auto slash = filename.rfind('/');
return slash != std::string::npos ? filename.substr(slash + 1) : filename;
}
// TODO: can we simply use std::filesystem::exists?
inline bool fileExists(const std::string& name) {
const auto fd =
torch::nativert::openNoInt(name.c_str(), O_RDONLY | O_CLOEXEC);
if (fd == -1) {
return false;
}
torch::nativert::closeNoInt(fd);
return true;
}
std::unique_ptr<ProxyExecutor> makeProxyExecutor(
const std::string& filename,
bool is_cpu,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
#ifdef FBCODE_CAFFE2
return std::make_unique<FbProxyExecutor>(
filename, is_cpu, std::move(custom_objs));
#else
return std::make_unique<OSSProxyExecutor>(filename, is_cpu);
#endif
}
} // namespace
// private static
std::vector<std::string> AOTInductorModelImpl::library_search_paths_;
AOTInductorModelImpl::AOTInductorModelImpl(
const std::string& model_path,
std::optional<std::string> cubin_dir,
std::vector<std::string> input_names,
std::vector<std::string> output_names,
std::optional<at::ScalarType> input_dtype,
std::optional<at::ScalarType> output_dtype,
std::optional<std::string> extern_kernel_nodes_path,
const std::string& device_str,
int64_t num_runtimes,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs)
: // handle_(dlopen(model_path.c_str(), RTLD_LAZY | RTLD_LOCAL)),
libraryBasename_(getFileBasename(model_path)),
libraryPath_(model_path),
inputNames_(std::move(input_names)),
outputNames_(std::move(output_names)),
floatingPointInputDtype_(input_dtype),
floatingPointOutputDtype_(output_dtype),
deviceStr_(device_str) {
LOG(INFO) << "Loading .so lib from " << model_path
<< " onto device: " << device_str;
handle_.reset(dlopen(model_path.c_str(), RTLD_NOW | RTLD_LOCAL));
TORCH_CHECK(
handle_ != nullptr, "could not dlopen ", model_path, ": ", dlerror());
TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive");
if (extern_kernel_nodes_path.has_value()) {
const std::string& filename = extern_kernel_nodes_path.value();
if (fileExists(filename)) {
LOG(INFO) << "Loading extern_kernel_nodes .json file from " << filename;
proxyExecutor_ =
makeProxyExecutor(filename, is_cpu(), std::move(custom_objs));
}
}
#if defined(__SIGRID_USE_GPU__)
// It's not clear what stream we want to use yet. Create a new one.
// We could alternatively use the default stream, but that could cause extra
// synchronization.
using StreamGuard = std::unique_ptr<
std::remove_pointer_t<cudaStream_t>,
decltype(&cudaStreamDestroy)>;
std::optional<StreamGuard> creation_stream_guard = [&] {
if (is_cpu()) {
return std::optional<StreamGuard>();
}
cudaStream_t creation_stream;
TORCH_CHECK(
cudaStreamCreateWithFlags(&creation_stream, cudaStreamNonBlocking) ==
cudaSuccess);
return std::make_optional<StreamGuard>(creation_stream, cudaStreamDestroy);
}();
#endif // __SIGRID_USE_GPU__
#define LOAD_SYMBOL(var, name_str) \
var = reinterpret_cast<decltype(var)>(dlsym(handle_.get(), name_str)); \
TORCH_CHECK(var, "could not dlsym " name_str);
LOAD_SYMBOL(deleteFunc_, "AOTInductorModelContainerDelete");
LOAD_SYMBOL(runFunc_, "AOTInductorModelContainerRun");
LOAD_SYMBOL(getOutputNameFunc_, "AOTInductorModelContainerGetOutputName");
LOAD_SYMBOL(getCallSpecFunc_, "AOTInductorModelContainerGetCallSpec");
// We never call these functions again after the constructor returns, so
// there's no point in caching them in member variables.
decltype(&AOTInductorModelContainerCreate) createFunc;
decltype(&AOTInductorModelContainerGetInputName) getInputNameFunc;
decltype(&AOTInductorModelContainerGetNumInputs) getNumInputsFunc;
decltype(&AOTInductorModelContainerGetNumOutputs) getNumOutputsFunc;
LOAD_SYMBOL(createFunc, "AOTInductorModelContainerCreate");
LOAD_SYMBOL(getInputNameFunc, "AOTInductorModelContainerGetInputName");
LOAD_SYMBOL(getNumInputsFunc, "AOTInductorModelContainerGetNumInputs");
LOAD_SYMBOL(getNumOutputsFunc, "AOTInductorModelContainerGetNumOutputs");
#undef LOAD_SYMBOL
#define LOAD_SYMBOL_WARN(var, name_str) \
var = reinterpret_cast<decltype(var)>(dlsym(handle_.get(), name_str)); \
if (!var) { \
LOG(WARNING) << "Could not dlsym " << name_str; \
}
// "AOTInductorModelContainerCreateWithDevice" is only available in the binary
// compiled after Jan.15.2024
decltype(&AOTInductorModelContainerCreateWithDevice) createFuncWithDevice;
LOAD_SYMBOL_WARN(
createFuncWithDevice, "AOTInductorModelContainerCreateWithDevice");
LOAD_SYMBOL_WARN(
getNumConstantsFunc_, "AOTInductorModelContainerGetNumConstants");
LOAD_SYMBOL_WARN(
getConstantNameFunc_, "AOTInductorModelContainerGetConstantName");
LOAD_SYMBOL_WARN(
getConstantOriginalFQNFunc_,
"AOTInductorModelContainerGetConstantOriginalFQN");
LOAD_SYMBOL_WARN(
getConstantFromFoldedFunc_,
"AOTInductorModelContainerGetConstantFromFolded");
LOAD_SYMBOL_WARN(
getConstantTypeFunc_, "AOTInductorModelContainerGetConstantType");
LOAD_SYMBOL_WARN(
getConstantDtypeFunc_, "AOTInductorModelContainerGetConstantDtype");
LOAD_SYMBOL_WARN(
runConstantFoldingFunc_, "AOTInductorModelContainerRunConstantFolding");
LOAD_SYMBOL_WARN(
updateConstantBufferFunc_,
"AOTInductorModelContainerUpdateConstantBuffer");
LOAD_SYMBOL_WARN(
updateInactiveConstantBufferFunc_,
"AOTInductorModelContainerUpdateInactiveConstantBuffer");
LOAD_SYMBOL_WARN(
swapConstantBufferFunc_, "AOTInductorModelContainerSwapConstantBuffer");
#undef LOAD_SYMBOL_WARN
if (createFuncWithDevice) {
AOTInductorCallCreateWithDevice(
createFuncWithDevice,
&containerHandle_,
num_runtimes,
deviceStr_.c_str(),
cubin_dir ? cubin_dir->c_str() : nullptr);
} else {
AOTInductorCallCreate(
createFunc,
&containerHandle_,
num_runtimes,
is_cpu(),
cubin_dir ? cubin_dir->c_str() : nullptr);
}
const auto num_inputs = AOTInductorCall(getNumInputsFunc, containerHandle_);
const auto num_outputs = AOTInductorCall(getNumOutputsFunc, containerHandle_);
TORCH_CHECK(
inputNames_.size() == num_inputs,
"the size of input_names is ",
inputNames_.size(),
", but the model expects ",
num_inputs);
TORCH_CHECK(
outputNames_.size() == num_outputs,
"the size of output_names is ",
outputNames_.size(),
", but the model expects ",
num_outputs);
for (const auto idx : c10::irange(num_inputs)) {
inputNameToIndex_.emplace(
AOTInductorCall(getInputNameFunc, containerHandle_, idx), idx);
}
for (const auto idx : c10::irange(num_outputs)) {
outputNameToIndex_.emplace(
AOTInductorCall(getOutputNameFunc_, containerHandle_, idx), idx);
}
}
std::vector<torch::Tensor> AOTInductorModelImpl::processInputs(
std::vector<torch::Tensor>& python_inputs) {
RECORD_USER_SCOPE("AOTInductorModel::ProcessInputs");
const auto num_inputs = inputNameToIndex_.size();
TORCH_CHECK(
python_inputs.size() == num_inputs,
"User passed ",
python_inputs.size(),
" inputs, but the model expects ",
num_inputs);
std::vector<torch::Tensor> inputs(python_inputs.size());
for (int python_input_idx = 0; python_input_idx < inputNames_.size();
python_input_idx++) {
auto input_name = inputNames_[python_input_idx];
auto& input = python_inputs[python_input_idx];
if (floatingPointInputDtype_ != std::nullopt && input.is_floating_point()) {
// Need to keep input alive; cannot just stash result of to()
// call in a local!
input = input.to(*floatingPointInputDtype_);
}
// FIXME: get correct aot_input_idx once we figure out name-mapping
// in AOTInductor.
// Currently, we have a strong assumption that python_input_idx
// (fx inputs) is the same as aot_input_idx.
// const auto aot_input_idx = input_name_to_index_.at(input_name);
const auto aot_input_idx = python_input_idx;
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
inputs[aot_input_idx] = input;
}
return inputs;
}
std::vector<torch::Tensor> AOTInductorModelImpl::processOutputs(
std::vector<torch::Tensor>&& outputs) {
if (floatingPointOutputDtype_.has_value()) {
for (auto& output : outputs) {
if (output.is_floating_point()) {
output = output.to(*floatingPointOutputDtype_);
}
}
}
return std::move(outputs);
}
std::vector<torch::Tensor> AOTInductorModelImpl::forward(
std::vector<torch::Tensor>& python_inputs) {
RECORD_USER_SCOPE("AOTInductorModel::Forward");
TORCH_CHECK(!python_inputs.empty());
std::vector<torch::Tensor> input_tensors = processInputs(python_inputs);
// For outputs, we only allocate a vector to hold returned tensor handles,
// not allocating the actual output tensor storage here
const auto num_outputs = outputNameToIndex_.size();
std::vector<AtenTensorHandle> output_handles(num_outputs);
{
auto input_handles =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
input_tensors);
#if defined(__SIGRID_USE_GPU__)
const auto device = python_inputs[0].device();
AOTInductorStreamHandle stream_handle = is_cpu() ? nullptr : [&] {
const auto& cuda_stream = at::cuda::getCurrentCUDAStream(device.index());
const auto stream_id = cuda_stream.stream();
return reinterpret_cast<AOTInductorStreamHandle>(stream_id);
}();
#else
AOTInductorStreamHandle stream_handle = nullptr;
#endif // __SIGRID_USE_GPU__
AOTIProxyExecutorHandle proxy_executor_handle =
reinterpret_cast<AOTIProxyExecutorHandle>(proxyExecutor_.get());
RECORD_USER_SCOPE("AOTInductorModel::AOTInductorRuntime");
AOTIRuntimeError run_result = runFunc_(
containerHandle_,
input_handles.data(),
input_tensors.size(),
output_handles.data(),
output_handles.size(),
stream_handle,
proxy_executor_handle);
if (run_result != AOTI_RUNTIME_SUCCESS) {
std::stringstream ss;
ss << "AOTInductorModel run failed with input spec: ";
for (const auto& i : python_inputs) {
ss << i.sizes() << ":" << i.dtype() << ", ";
}
TORCH_CHECK(false, ss.str());
}
return processOutputs(
torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
output_handles.data(), output_handles.size()));
}
}
std::vector<const char*> AOTInductorModelImpl::get_call_spec() {
std::vector<const char*> call_spec = {nullptr, nullptr};
getCallSpecFunc_(containerHandle_, call_spec.data(), &call_spec[1]);
return call_spec;
}
std::unordered_map<std::string, ConstantInfo>
AOTInductorModelImpl::getConstantInfos() const {
TORCH_CHECK(
getNumConstantsFunc_, "getNumConstantsFunc_ was not loaded from .so");
TORCH_CHECK(
getConstantNameFunc_, "getConstantNameFunc_ was not loaded from .so");
TORCH_CHECK(
getConstantOriginalFQNFunc_,
"getConstantOriginalFQNFunc_ was not loaded from .so");
TORCH_CHECK(
getConstantDtypeFunc_, "getConstantDtypeFunc_ was not loaded from .so");
std::unordered_map<std::string, ConstantInfo> result;
auto num_constants = AOTInductorCall(getNumConstantsFunc_, containerHandle_);
for (size_t i = 0; i < num_constants; ++i) {
const auto name =
AOTInductorCall(getConstantNameFunc_, containerHandle_, i);
const auto original_fqn =
AOTInductorCall(getConstantOriginalFQNFunc_, containerHandle_, i);
const auto dtype =
AOTInductorCall(getConstantDtypeFunc_, containerHandle_, i);
ConstantType constant_type = ConstantType::Unknown;
if (getConstantTypeFunc_) {
constant_type = static_cast<ConstantType>(
AOTInductorCall(getConstantTypeFunc_, containerHandle_, i));
}
if (getConstantFromFoldedFunc_ &&
AOTInductorCall(getConstantFromFoldedFunc_, containerHandle_, i)) {
continue;
}
TORCH_CHECK(original_fqn, "Cannot find original FQN of constant ", name);
result.emplace(
name,
ConstantInfo{
static_cast<at::ScalarType>(dtype), original_fqn, constant_type});
}
return result;
}
void AOTInductorModelImpl::runConstantFolding(bool use_inactive) {
if (!runConstantFoldingFunc_) {
// Just return if runtime constant folding isn't available in this .so;
// only models compiled after Feb 2024 have this capability.
return;
}
#if defined(__SIGRID_USE_GPU__)
AOTInductorStreamHandle stream_handle = is_cpu() ? nullptr : [&] {
const auto& cuda_stream = at::cuda::getCurrentCUDAStream();
const auto stream_id = cuda_stream.stream();
return reinterpret_cast<AOTInductorStreamHandle>(stream_id);
}();
#else
AOTInductorStreamHandle stream_handle = nullptr;
#endif // __SIGRID_USE_GPU__
AOTIProxyExecutorHandle proxy_executor_handle =
reinterpret_cast<AOTIProxyExecutorHandle>(proxyExecutor_.get());
auto result = runConstantFoldingFunc_(
containerHandle_, use_inactive, stream_handle, proxy_executor_handle);
TORCH_CHECK(
result == AOTI_RUNTIME_SUCCESS, "Unable to run constant folding.");
}
void AOTInductorModelImpl::updateConstantBuffer(
std::unordered_map<std::string, torch::Tensor*>&& constants,
bool use_inactive,
bool validate_full_update) {
TORCH_CHECK(
updateConstantBufferFunc_,
"updateConstantBufferFunc_ was not loaded from .so");
auto result = updateConstantBufferFunc_(
containerHandle_,
(AOTInductorConstantMapHandle)&constants,
use_inactive,
validate_full_update);
TORCH_CHECK(
result == AOTI_RUNTIME_SUCCESS, "Unable to update constant buffer");
}
void AOTInductorModelImpl::updateInactiveConstantBuffer(
std::unordered_map<std::string, torch::Tensor*>&& constants) {
TORCH_CHECK(
updateInactiveConstantBufferFunc_,
"updateInactiveConstantBufferFunc_ was not loaded from .so");
auto result = updateInactiveConstantBufferFunc_(
containerHandle_, (AOTInductorConstantMapHandle)&constants);
TORCH_CHECK(
result == AOTI_RUNTIME_SUCCESS,
"Unable to update inactive constant buffer");
}
void AOTInductorModelImpl::swapConstantBuffers() {
TORCH_CHECK(
swapConstantBufferFunc_,
"swapConstantBufferFunc_ was not loaded from .so");
auto result = swapConstantBufferFunc_(containerHandle_);
TORCH_CHECK(
result == AOTI_RUNTIME_SUCCESS, "Unable to swap constant buffers");
}
thread_local std::unordered_map<std::string, std::string>
AOTInductorModelImpl::lib_name_to_path_;
thread_local bool AOTInductorModelImpl::deserialize_pickled_model_{true};
thread_local std::optional<std::string> AOTInductorModelImpl::cubin_dir_;
thread_local std::unordered_map<std::string, std::string>
AOTInductorModelImpl::extern_kernels_spec_name_to_path_;
void AOTInductorModelImpl::registerLibraryNameToPathMap(
std::unordered_map<std::string, std::string> map) {
std::ostringstream ss;
ss << "{\n";
for (const auto& [k, v] : map) {
ss << " " << k << " => " << v << ",\n";
}
ss << "}";
LOG(INFO) << "Registering .so lib paths: " << ss.str();
lib_name_to_path_ = std::move(map);
}
std::string AOTInductorModelImpl::getFullPathForLibraryName(
const std::string& name) {
auto path = lib_name_to_path_.find(name);
std::ostringstream ss;
ss << "{\n";
for (const auto& [k, v] : lib_name_to_path_) {
ss << " " << k << " => " << v << ",\n";
}
if ((path == lib_name_to_path_.end()) ||
(!std::filesystem::exists(path->second))) {
for (const auto& lib_path : library_search_paths_) {
std::string fullpath =
lib_path + std::filesystem::path::preferred_separator + name;
if (std::filesystem::exists(fullpath)) {
return fullpath;
}
ss << " searched for " << name << " at " << lib_path << ",\n";
}
}
ss << "}";
TORCH_CHECK(
path != lib_name_to_path_.end(),
"could not find full path for AOTInductor model .so named ",
name,
". available paths: ",
ss.str());
return path->second;
}
void AOTInductorModelImpl::setCubinDir(std::optional<std::string> cubin_dir) {
cubin_dir_ = cubin_dir;
}
std::optional<std::string> AOTInductorModelImpl::getCubinDir() {
return cubin_dir_;
}
void AOTInductorModelImpl::registerExternKernelsSpecNameToPathMap(
std::unordered_map<std::string, std::string> map) {
std::ostringstream ss;
ss << "{\n";
for (const auto& [k, v] : map) {
ss << " " << k << " => " << v << ",\n";
}
ss << "}";
LOG(INFO) << "Registering extern kernels spec paths: " << ss.str();
extern_kernels_spec_name_to_path_ = std::move(map);
}
std::optional<std::string>
AOTInductorModelImpl::getFullPathForExternKernelsSpecName(
const std::string& name) {
auto it = extern_kernels_spec_name_to_path_.find(name);
if (it == extern_kernels_spec_name_to_path_.end()) {
LOG(INFO) << "Didn't find extern kernels spec file for " << name;
return {};
}
if (!std::filesystem::exists(it->second)) {
TORCH_CHECK(false, "Extern kernels spec file doesn't exist: ", it->second);
}
return it->second;
}
bool AOTInductorModelImpl::getDeserializePickledModel() {
return deserialize_pickled_model_;
}
// Sets a thread-local flag that disables actually loading from the .so file,
// so the same module can be reused later.
void AOTInductorModelImpl::setDeserializePickledModel(
bool deserializePickledModel) {
deserialize_pickled_model_ = deserializePickledModel;
}
} // namespace torch::aot_inductor

View File

@ -0,0 +1,201 @@
#pragma once
#include <dlfcn.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h> // @manual
#include <torch/csrc/inductor/aoti_runtime/model.h> // @manual
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h> // @manual
#include <torch/torch.h> // @manual=//caffe2:torch-cpp
#include <memory>
#include <optional>
#include "c10/util/FbcodeMaps.h"
namespace torch::aot_inductor {
struct ConstantInfo {
at::ScalarType dtype;
std::string originalFqn;
ConstantType type;
};
class AOTInductorModelImpl {
public:
explicit AOTInductorModelImpl(
const std::string& model_path,
std::optional<std::string> cubin_dir,
std::vector<std::string> input_names,
std::vector<std::string> output_names,
std::optional<at::ScalarType> input_dtype,
std::optional<at::ScalarType> output_dtype,
std::optional<std::string> extern_kernel_nodes_path,
const std::string& device_str,
int64_t num_runtimes = 2,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs =
std::nullopt);
~AOTInductorModelImpl() {
if (containerHandle_) {
deleteFunc_(containerHandle_);
}
}
std::vector<torch::Tensor> forward(std::vector<torch::Tensor>& inputs);
std::vector<const char*> get_call_spec();
std::unordered_map<std::string, ConstantInfo> getConstantInfos() const;
void updateConstantBuffer(
std::unordered_map<std::string, torch::Tensor*>&& constants,
bool use_inactive,
bool validate_full_update);
void updateInactiveConstantBuffer(
std::unordered_map<std::string, torch::Tensor*>&& constants);
void runConstantFolding(bool use_inactive);
void swapConstantBuffers();
void profile(
std::vector<torch::Tensor>& inputs,
const std::string& filename,
size_t num_iters);
// If we need to move or copy this object, then we should just
// define a unique_ptr with deleter for the handle.
AOTInductorModelImpl(const AOTInductorModelImpl&) = delete;
AOTInductorModelImpl& operator=(const AOTInductorModelImpl&) = delete;
static void registerLibraryNameToPathMap(
std::unordered_map<std::string, std::string> map);
static std::string getFullPathForLibraryName(const std::string& name);
static void setCubinDir(std::optional<std::string> cubin_dir);
static std::optional<std::string> getCubinDir();
static void registerExternKernelsSpecNameToPathMap(
std::unordered_map<std::string, std::string> mapping);
static std::optional<std::string> getFullPathForExternKernelsSpecName(
const std::string& name);
static bool getDeserializePickledModel();
static void setDeserializePickledModel(bool deserializePickledModel);
/*
* Returns the path to the .so file (either relative or absolute).
*/
const std::string& libraryPath() const {
return libraryPath_;
}
const std::string& libraryBasename() const {
return libraryBasename_;
}
const std::vector<std::string>& inputNames() const {
return inputNames_;
}
const std::vector<std::string>& outputNames() const {
return outputNames_;
}
const std::optional<at::ScalarType> floatingPointInputDtype() const {
return floatingPointInputDtype_;
}
const std::optional<at::ScalarType> floatingPointOutputDtype() const {
return floatingPointOutputDtype_;
}
static void add_library_search_path(const std::string& path) {
library_search_paths_.push_back(path);
}
bool is_cpu() const {
return deviceStr_ == "cpu";
}
private:
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
static std::vector<std::string> library_search_paths_;
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
static thread_local std::unordered_map<std::string, std::string>
lib_name_to_path_;
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
static thread_local std::optional<std::string> cubin_dir_;
/*
* Example:
* {
* "aaa.json": "/tmp/abcdef/aaa.json",
* "bbb.json": "/tmp/abcdef/bbb.json",
* }
*/
// @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration
static thread_local std::unordered_map<std::string, std::string>
extern_kernels_spec_name_to_path_;
static thread_local bool deserialize_pickled_model_;
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
dlclose(p);
}
}
};
std::vector<torch::Tensor> processInputs(
std::vector<torch::Tensor>& python_inputs);
std::vector<torch::Tensor> processOutputs(
std::vector<torch::Tensor>&& outputs);
std::unique_ptr<void, DlcloseDeleter> handle_ = nullptr;
AOTInductorModelContainerHandle containerHandle_;
decltype(&AOTInductorModelContainerDelete) deleteFunc_ = nullptr;
decltype(&AOTInductorModelContainerRun) runFunc_ = nullptr;
decltype(&AOTInductorModelContainerGetOutputName) getOutputNameFunc_ =
nullptr;
decltype(&AOTInductorModelContainerGetCallSpec) getCallSpecFunc_ = nullptr;
decltype(&AOTInductorModelContainerGetNumConstants) getNumConstantsFunc_{
nullptr};
decltype(&AOTInductorModelContainerGetConstantName) getConstantNameFunc_{
nullptr};
decltype(&AOTInductorModelContainerGetConstantOriginalFQN)
getConstantOriginalFQNFunc_{nullptr};
decltype(&AOTInductorModelContainerGetConstantFromFolded)
getConstantFromFoldedFunc_{nullptr};
decltype(&AOTInductorModelContainerGetConstantType) getConstantTypeFunc_{
nullptr};
decltype(&AOTInductorModelContainerGetConstantDtype) getConstantDtypeFunc_{
nullptr};
decltype(&AOTInductorModelContainerRunConstantFolding)
runConstantFoldingFunc_{nullptr};
decltype(&AOTInductorModelContainerUpdateConstantBuffer)
updateConstantBufferFunc_{nullptr};
decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer)
updateInactiveConstantBufferFunc_{nullptr};
decltype(&AOTInductorModelContainerSwapConstantBuffer)
swapConstantBufferFunc_{nullptr};
const std::string libraryBasename_;
const std::string libraryPath_;
const std::vector<std::string> inputNames_;
const std::vector<std::string> outputNames_;
const std::optional<at::ScalarType> floatingPointInputDtype_;
const std::optional<at::ScalarType> floatingPointOutputDtype_;
c10::FastMap<const char*, size_t> inputNameToIndex_;
c10::FastMap<const char*, size_t> outputNameToIndex_;
std::unique_ptr<ProxyExecutor> proxyExecutor_;
std::string deviceStr_;
};
} // namespace torch::aot_inductor
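// Sketch of the thread-local registration flow exposed above; the library and
// spec names/paths are placeholders.
#include <string>
#include <unordered_map>
#include "torch/csrc/nativert/executor/AOTInductorModelImpl.h"

void registerArtifactPathsSketch() {
  using torch::aot_inductor::AOTInductorModelImpl;
  AOTInductorModelImpl::registerLibraryNameToPathMap(
      {{"model.so", "/tmp/extracted/model.so"}});
  AOTInductorModelImpl::registerExternKernelsSpecNameToPathMap(
      {{"model.json", "/tmp/extracted/model.json"}});
  const std::string libPath =
      AOTInductorModelImpl::getFullPathForLibraryName("model.so");
  (void)libPath;
}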

View File

@ -0,0 +1,166 @@
#include "torch/csrc/nativert/executor/ConstantFolder.h"
#include <algorithm>
#include <queue>
#include "torch/csrc/nativert/common/Enumerate.h"
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include "torch/csrc/nativert/executor/Weights.h"
namespace torch::nativert {
/*
side effects:
1. nodes deemed const-foldable are unlinked from the graph.
they are still owned by the graph (i.e., show up in graph.nodeOwner_)
but are not accessible through the node iterator.
2. kernels associated with const-foldable nodes are removed from the
'kernels' input
3. values deemed foldable are marked as such, and their producers are removed
*/
void ConstantFolder::unlinkConstants(
std::vector<std::unique_ptr<OpKernel>>& kernels) {
TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size())
<< "graph node count and kernel count should be equal";
unlinked_ = true;
/* resolve all of the nodes that are const foldable */
c10::FastMap<Node*, uint32_t> nodeDynInputs;
nodeDynInputs.reserve(graph_.nodes().size());
c10::FastMap<const Node*, std::unique_ptr<OpKernel>*> nodeKernels;
nodeKernels.reserve(graph_.nodes().size());
const auto* input = &*graph_.nodes().begin();
const auto* output = &*graph_.nodes().end();
{ // ignore prim.Input and prim.Output
auto ct = 0;
for (auto& n : graph_.nodes()) {
if (&n == input || &n == output) {
continue;
}
nodeDynInputs[&n] = n.numInputs();
nodeKernels[&n] = &kernels[++ct];
}
}
const auto& inputsToWeights = graph_.signature().inputsToWeights();
for (const auto& [inputName, weightName] : inputsToWeights) {
for (auto* user : graph_.getValue(inputName)->users()) {
if (user == input || user == output) {
continue;
}
nodeDynInputs[user] -= 1;
}
}
// set of foldable nodes for dedupe purposes
c10::FastSet<const Node*> foldable;
std::queue<Node*> constFoldableCandidates;
for (auto& [node, ct] : nodeDynInputs) {
if (ct++ /* will be decremented once dequeued */ == 0) {
constFoldableCandidates.push(node);
}
}
while (!constFoldableCandidates.empty()) {
auto* candidate = constFoldableCandidates.front();
constFoldableCandidates.pop();
if (auto& ct = nodeDynInputs[candidate]; --ct == 0) {
foldable.insert(candidate);
foldables_.push_back(Foldable{
.node = candidate, .kernel = std::move(*nodeKernels[candidate])});
candidate->unlink();
for (auto* user : candidate->users()) {
if (user == output) {
continue;
}
if (foldable.find(user) == foldable.end()) {
constFoldableCandidates.push(user);
}
}
for (auto* out : candidate->outputs()) {
auto* value = graph_.getValue(out->name());
value->setIsFolded();
// we only store folded values if there is a non-foldable user
if (const auto& users = value->users();
std::any_of(users.begin(), users.end(), [&](const auto* u) {
return foldable.find(u) == foldable.end();
})) {
foldedOutputValueIds_.insert(value->id());
}
}
}
}
for (const auto& f : foldables_) {
VLOG(1) << "Const-folded node: " << *f.node;
}
// remove moved (i.e., associated w/ const-folded nodes) kernels
// from the input kernel vector
kernels.erase(
std::remove_if(
kernels.begin(),
kernels.end(),
[](const auto& k) { return k == nullptr; }),
kernels.end());
graph_.renumberValues();
graph_.finalize();
graph_.lint();
return;
}
/*
side effects:
1. weights whose users are ONLY const-foldable nodes will be removed
from the 'weights' input
*/
void ConstantFolder::evaluate(Weights& weights) {
CHECK(unlinked_)
<< "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants";
weights.validateAllWeightsLoaded();
ExecutionFrame frame(graph_);
frame.setWeights(weights);
c10::FastMap<ValueId, c10::IValue> foldedValues;
for (const auto& f : foldables_) {
f.kernel->compute(frame);
for (auto&& [i, out] : nativert::enumerate(f.node->outputs())) {
if (foldedOutputValueIds_.find(out->id()) !=
foldedOutputValueIds_.end()) {
foldedValues[out->id()] = f.kernel->output(i, frame);
}
}
}
for (auto it = std::make_move_iterator(foldedValues.begin());
it != std::make_move_iterator(foldedValues.end());
++it) {
auto [v, iv] = std::move(*it);
weights.setConstFoldedValue(v, std::move(iv));
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,55 @@
#pragma once
#include <memory>
#include <vector>
#include "torch/csrc/nativert/executor/OpKernel.h"
#include "torch/csrc/nativert/executor/Weights.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
struct Foldable {
Node* node;
std::unique_ptr<OpKernel> kernel;
};
class ConstantFolder {
public:
explicit ConstantFolder(Graph& graph) : graph_(graph) {}
/*
1. identify nodes without dynamic inputs and mark them as foldable
2. traverse the nodes deemed foldable as if they were being evaluated,
pushing nodes that become foldable once their inputs have been traversed.
unlink foldable nodes from the graph in the topological order in which
they were traversed, storing each node and its associated kernel (moved
from 'kernels') as a Foldable in ConstantFolder
*/
void unlinkConstants(
/* kernels for const-foldable nodes will be removed from this vector */
std::vector<std::unique_ptr<OpKernel>>& kernels);
/*
1. execute foldables_ on an execution frame initialized with the passed-in
weights, calling Weights::setConstFoldedValue if the folded value is
consumed by a non-foldable node
*/
void evaluate(Weights& weights);
private:
Graph& graph_;
// unlinked nodes sorted in their topological order
// s.t., they can be evaluated sequentially
std::vector<Foldable> foldables_;
bool unlinked_{false};
c10::FastSet<ValueId> foldedOutputValueIds_;
};
} // namespace torch::nativert
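// Standalone sketch of the foldability propagation described above, on a toy
// graph: each node tracks how many of its inputs are still dynamic, a node
// whose count reaches zero is folded, and folding it may unlock its users.
// ToyNode is hypothetical and unrelated to torch::nativert::Node.
#include <cstddef>
#include <queue>
#include <vector>

struct ToyNode {
  std::size_t dynamicInputs = 0;   // inputs not produced by weights/constants
  std::vector<std::size_t> users;  // indices of downstream nodes
};

std::vector<bool> foldableNodesSketch(std::vector<ToyNode> nodes) {
  std::vector<bool> foldable(nodes.size(), false);
  std::queue<std::size_t> ready;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (nodes[i].dynamicInputs == 0) {
      ready.push(i);
    }
  }
  while (!ready.empty()) {
    const std::size_t n = ready.front();
    ready.pop();
    foldable[n] = true;
    for (std::size_t user : nodes[n].users) {
      // One more of `user`'s inputs is now a folded constant.
      if (--nodes[user].dynamicInputs == 0) {
        ready.push(user);
      }
    }
  }
  return foldable;
}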

View File

@ -0,0 +1,52 @@
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include <unistd.h>
#include "c10/util/Logging.h"
#include "torch/csrc/nativert/common/FileUtil.h"
#include "torch/csrc/nativert/common/String.h"
namespace torch::nativert {
std::string extractToTemporaryFolder(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader,
const std::string& targetPath) {
char outputDir[] = "/tmp/model_XXXXXX";
char* tempdir = mkdtemp(outputDir);
TORCH_CHECK(
tempdir != nullptr,
"error creating temporary directory for compiled model. errno: ",
errno);
std::vector<std::string> allRecords = packageReader->getAllRecords();
for (const auto& path : allRecords) {
if (!starts_with(path, targetPath) || ends_with(path, "/")) {
continue;
}
TORCH_CHECK(
packageReader->hasRecord(path), path, " not present in model package");
auto [dataPointer, dataSize] = packageReader->getRecord(path);
std::string fileName = path.substr(path.rfind('/') + 1);
std::string extractedFilename = std::string(outputDir) + "/" + fileName;
VLOG(1) << "Extracting " << extractedFilename
<< " from archive path: " << path << " size: " << dataSize;
File extracted(extractedFilename, O_CREAT | O_WRONLY, 0640);
const auto bytesWritten =
writeFull(extracted.fd(), dataPointer.get(), dataSize);
TORCH_CHECK(
bytesWritten != -1,
"failure copying from archive path ",
path,
" to temporary file");
}
return std::string(outputDir);
}
} // namespace torch::nativert

View File

@ -0,0 +1,47 @@
#pragma once
#include <memory>
#include <vector>
#include <torch/script.h>
namespace torch::nativert {
class Weights;
std::string extractToTemporaryFolder(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader,
const std::string& targetPath);
// This is the extension point for delegation backends.
// Please refer to AOTIDelegateExecutor as an example.
class DelegateExecutor {
public:
virtual ~DelegateExecutor() {}
// Runtime calls processWeights() to pass the weights to the delegate backend.
// Typically, a backend would perform some form of validation and processing,
// such as constant folding. The processed weights stays in the inactivate
// state until commitWeights() is called.
//
// Weights tensors are co-owned by the runtime and the delegate backend.
// In the regular inference run() path, neither Runtime or Delegate backend
// can modify the weights tensor.
// To support inplace weight update, weight tensors are be exposed by
// ModelRunner::getWeights() to an external caller. The external caller can
// then modify the weight tensors in-place. Such mutation would instantly
// affect the weight tensors in the delegate backend.
// When a weight tensor is no longer used by the delegate backend, the backend
// must release it by decreasing its refcount. The Runtime would
// also release its refcount on a weight tensor once it's no longer active. The
// underlying storage for weight tensors will be freed when the refcount
// reaches 0.
virtual void processWeights(std::shared_ptr<Weights> weights) = 0;
// This call activates the processed weights.
virtual void commitWeights() = 0;
virtual std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) = 0;
};
} // namespace torch::nativert
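
A hedged sketch of the weight-update protocol described in the comments above; the delegate and weights objects are assumed to come from an already-initialized runtime.

// Illustrative only: process new weights first, then commit to activate them.
void updateDelegateWeights(
    torch::nativert::DelegateExecutor& delegate,
    std::shared_ptr<torch::nativert::Weights> newWeights) {
  delegate.processWeights(newWeights); // validated/folded, but still inactive
  delegate.commitWeights();            // switch the backend to the new weights
}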

View File

@ -0,0 +1,141 @@
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
namespace torch::nativert {
ExecutionFrame::ExecutionFrame(const Graph& graph)
: graph_(graph),
allValues_(graph.numValues()),
persistent_(graph.numValues()) {
// load constant SymInts into execution frame
for (const auto& [valueId, constSymintValue] :
graph_.getConstantSymIntValues()) {
setPersistentIValue(valueId, constSymintValue);
}
for (const Node& node : graph_.nodes()) {
if (node.target() == "torch.ops.higher_order.run_const_graph") {
const auto& const_graph =
std::get<std::unique_ptr<Graph>>(node.attributes().at(0).value);
for (size_t i = 0; i < node.outputs().size(); ++i) {
foldedConstIds_[std::string{const_graph->outputs().at(i)->name()}] =
node.outputs()[i]->id();
}
}
}
}
ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
: ExecutionFrame(graph) {
setWeights(weights);
}
void ExecutionFrame::setWeights(const Weights& weights) {
weightVersion_ = weights.version();
const auto& inputsToWeights = graph_.signature().inputsToWeights();
for (const auto& [inputName, weightName] : inputsToWeights) {
const Value* value = graph_.getValue(inputName);
setPersistentIValue(value->id(), weights.at(weightName));
}
const auto& inputsToCustomObjs = graph_.signature().inputsToCustomObjs();
for (const auto& [inputName, customObjName] : inputsToCustomObjs) {
const Value* value = graph_.getValue(inputName);
setPersistentIValue(value->id(), weights.getCustomObj(customObjName));
}
for (const auto& [value, tensor] : weights.getFoldedConsts()) {
setPersistentIValue(foldedConstIds_.at(value), tensor);
}
for (const auto& [v, iv] : weights.getConstFoldedValues()) {
setPersistentIValue(v, iv);
}
}
ExecutionFrame::ExecutionFrame(
const Graph& graph,
size_t numValues,
const std::vector<ValueId>&,
const std::vector<ValueId>&)
: graph_(graph) {
allValues_.resize(numValues);
}
void ExecutionFrame::setIValue(ValueId id, c10::IValue ivalue) {
DCHECK(id < allValues_.size());
allValues_[id] = std::move(ivalue);
}
at::Tensor ExecutionFrame::getTensor(ValueId id) const {
const auto& ivalue = getIValue(id);
if (C10_LIKELY(ivalue.isTensor())) {
return ivalue.toTensor();
} else {
throw std::runtime_error("getTensor called on non-tensor value");
}
}
std::vector<c10::IValue> ExecutionFrame::getUserOutputs() const {
std::vector<c10::IValue> ret;
ret.reserve(graph_.userOutputs().size());
for (const auto& outputValue : graph_.userOutputs()) {
if (std::holds_alternative<Value*>(outputValue)) {
Value* valuePtr = std::get<Value*>(outputValue);
if (valuePtr) {
const auto& id = valuePtr->id();
ret.push_back(getIValue(id));
}
} else if (std::holds_alternative<Constant>(outputValue)) {
const Constant& constValue = std::get<Constant>(outputValue);
ret.push_back(constantToIValue(constValue));
}
}
return ret;
}
c10::List<c10::IValue> ExecutionFrame::getUserOutputsAsTensorList() const {
c10::List<c10::IValue> ret(c10::TensorType::get());
ret.reserve(graph_.userOutputs().size());
for (const auto& outputValue : graph_.userOutputs()) {
if (std::holds_alternative<Value*>(outputValue)) {
Value* valuePtr = std::get<Value*>(outputValue);
if (valuePtr) {
const auto& id = valuePtr->id();
ret.push_back(getIValue(id));
}
} else if (std::holds_alternative<Constant>(outputValue)) {
const Constant& constValue = std::get<Constant>(outputValue);
ret.push_back(constantToIValue(constValue));
}
}
return ret;
}
std::unordered_map<std::string, at::Tensor> ExecutionFrame::getAllOutputs()
const {
std::unordered_map<std::string, at::Tensor> ret;
for (const auto& outputValue : graph_.outputs()) {
const auto& name = outputValue->name();
const auto& id = outputValue->id();
ret.emplace(name, getTensor(id));
}
return ret;
}
std::unordered_map<std::string, at::Tensor> ExecutionFrame::getBufferMutations()
const {
// key is buffer name, value is tensor to be written to buffer
std::unordered_map<std::string, at::Tensor> ret;
const auto& buffersToMutate = graph_.signature().buffersToMutate();
for (auto& [mutationOutputName, bufferName] : buffersToMutate) {
const auto& id = graph_.getValue(mutationOutputName)->id();
ret.emplace(bufferName, getTensor(id));
}
return ret;
}
} // namespace torch::nativert

View File

@ -0,0 +1,131 @@
#pragma once
#include <unordered_map>
#include "torch/csrc/nativert/executor/Weights.h"
#include "torch/csrc/nativert/graph/Graph.h"
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual
namespace torch::nativert {
/**
* This class encapsulates the stateful values of an execution,
* most notably, the tensor values passed between nodes, aka intermediate
* activations.
*/
class ExecutionFrame {
public:
// Constructor for weight-less graph, used for higher order ops, e.g.
// torch.cond
explicit ExecutionFrame(const Graph& graph);
explicit ExecutionFrame(const Graph& graph, const Weights& weights);
// Constructor for testing purpose
explicit ExecutionFrame(
const Graph& graph,
size_t numValues,
const std::vector<ValueId>& graphInputIds,
const std::vector<ValueId>& graphOutputIds);
~ExecutionFrame() {}
std::vector<c10::IValue> getUserOutputs() const;
c10::List<c10::IValue> getUserOutputsAsTensorList() const;
std::unordered_map<std::string, at::Tensor> getBufferMutations() const;
std::unordered_map<std::string, at::Tensor> getAllOutputs() const;
const c10::IValue& getIValue(ValueId id, bool allowNone = true) const {
const auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
c10::IValue& getIValue(ValueId id, bool allowNone = true) {
auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
void setIValue(ValueId id, c10::IValue ivalue);
at::Tensor getTensor(ValueId id) const;
std::vector<at::Tensor> getTensorVector(ValueId id) const {
return getIValue(id).toTensorVector();
}
int64_t getSymInt(ValueId id) const {
return getIValue(id).toInt();
}
double getSymFloat(ValueId id) const {
return getIValue(id).toDouble();
}
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
setIValue(id, std::move(ivalue));
persistent_[id] = true;
}
void releaseValue(ValueId id) {
CHECK(!persistent_[id]) << "Cannot release persistent value";
allValues_[id] = c10::IValue();
}
void releaseUserOutputs() {
for (const auto& outputValue : graph_.userOutputs()) {
if (std::holds_alternative<Value*>(outputValue)) {
Value* valuePtr = std::get<Value*>(outputValue);
if (valuePtr) {
const auto& id = valuePtr->id();
if (!persistent_[id]) {
releaseValue(id);
}
}
}
}
}
void setWork(int64_t workId, const c10::intrusive_ptr<c10d::Work>& work) {
work_[workId] = work;
}
c10::intrusive_ptr<c10d::Work> getWork(int64_t workId) const {
CHECK(work_.find(workId) != work_.end())
<< "Couldn't find work with Id: " << workId;
return work_.at(workId);
}
WeightVersion weightVersion() const {
return weightVersion_;
}
void setWeights(const Weights& weights);
private:
const Graph& graph_;
WeightVersion weightVersion_ = -1;
// All the intermediate values for the entire graph, including graph inputs
// and outputs. This table is fixed once constructed.
std::vector<c10::IValue> allValues_;
std::vector<bool> persistent_;
std::unordered_map<int64_t, c10::intrusive_ptr<c10d::Work>> work_;
std::unordered_map<std::string, ValueId> foldedConstIds_;
};
} // namespace torch::nativert
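
As a rough illustration of how the frame is used (a sketch, not code from this patch): build a frame from a graph and its weights, bind the flattened user inputs, run the kernels in order, and read the flattened user outputs.

std::vector<c10::IValue> runOnce(
    const torch::nativert::Graph& graph,
    const torch::nativert::Weights& weights,
    std::vector<c10::IValue> userInputs,
    const std::vector<std::unique_ptr<torch::nativert::OpKernel>>& kernels) {
  torch::nativert::ExecutionFrame frame(graph, weights);
  // Bind the flattened user inputs to their Value ids.
  const auto& inputValues = graph.userInputs();
  for (size_t i = 0; i < inputValues.size(); ++i) {
    if (inputValues[i]) {
      frame.setIValue(inputValues[i]->id(), std::move(userInputs[i]));
    }
  }
  // Execute every kernel against the shared frame (serial order).
  for (const auto& kernel : kernels) {
    kernel->compute(frame);
  }
  return frame.getUserOutputs();
}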

View File

@ -0,0 +1,117 @@
#include <unordered_map>
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/common/Enumerate.h"
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
namespace torch::nativert {
std::unique_ptr<ExecutionPlan> ExecutionPlanner::createPlan() {
auto plan = std::make_unique<ExecutionPlan>();
// The current implementation assumes that nodes will be executed
// in the same order as they appear in the thrift graph.
// In the future, we can plan the execution order, as long as it
// complies with topological order
generateDeallocationPlan(*plan);
return plan;
}
/* static */ c10::FastSet<ValueId> ExecutionPlanner::staticValues(
const Graph& graph) {
c10::FastSet<ValueId> staticValues;
// Filter lastUsedBy by graph inputs
// parameters/buffer values should not be freed
// It's a policy decision whether to free user inputs. For now, we don't
// free user inputs.
// TODO: It should be fine to "free" the user inputs. If the user holds a ref
// to it, it won't be deallocated.
for (const auto* input : graph.inputs()) {
if (input) {
const auto& id = input->id();
staticValues.insert(id);
}
}
// Filter lastUsedBy by graph outputs, as they still need to be returned
for (const auto& output : graph.outputs()) {
const auto& id = output->id();
staticValues.insert(id);
}
for (const auto& [id, _] : graph.getConstantSymIntValues()) {
staticValues.insert(id);
}
for (const Node& node : graph.nodes()) {
if (node.target() == "torch.ops.higher_order.run_const_graph") {
for (const auto& output : node.outputs()) {
// Do not free the outputs of run_const_graph, as they are newly
// produced folded constants
staticValues.insert(output->id());
}
} else {
for (const auto& input : node.inputs()) {
if (input.value->isFolded()) {
staticValues.insert(input.value->id());
}
}
}
}
return staticValues;
}
void ExecutionPlanner::generateDeallocationPlan(ExecutionPlan& plan) {
const auto& nodes = graph_.nodes();
size_t numNodes = nodes.size();
std::unordered_map<ValueId, NodeIndex> lastUsedBy;
// Traverse from the last node to the first node
// For each Value, find out which is the last node that uses it
// the Value can be freed after executing that node
size_t nodeIdx = nodes.size() - 1;
for (auto it = std::rbegin(nodes); it != std::rend(nodes); it++) {
const auto& inputs = it->inputs();
for (const auto& input : inputs) {
const auto& id = input.value->id();
if (lastUsedBy.find(id) == lastUsedBy.end()) {
lastUsedBy.insert({id, nodeIdx});
}
}
nodeIdx--;
}
std::vector<std::vector<ValueId>> valuesToFree(numNodes);
const auto& statics = staticValues(graph_);
for (auto& [id, nodeIndex] : lastUsedBy) {
if (statics.find(id) == statics.end()) {
valuesToFree[nodeIndex].push_back(id);
}
}
plan.valuesToFree = std::move(valuesToFree);
// print deallocation plan
VLOG(2) << plan;
return;
}
std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan) {
out << "****** Deallocation Plan ******\n";
for (auto&& [i, values] : enumerate(plan.valuesToFree)) {
out << "Node #" << i << ", valuesToFree = [";
for (const auto& value : values) {
out << value << ", ";
}
out << "]\n";
}
return out;
}
} // namespace torch::nativert

View File

@ -0,0 +1,41 @@
#pragma once
#include "torch/csrc/nativert/graph/Graph.h"
#include <torch/script.h>
namespace torch::nativert {
class Graph;
// ExecutionPlan is the result produced by ExecutionPlanner
// ATM, it only contains the value deallocation plan.
// In the future, it can include an execution order plan, an allocation plan for
// parameter/gradient alignment, a static memory plan for activation buffer reuse,
// etc.
struct ExecutionPlan {
// The i-th entry in this list holds the Values that can be freed *after*
// executing the i-th node
std::vector<std::vector<ValueId>> valuesToFree;
};
class ExecutionPlanner {
public:
explicit ExecutionPlanner(const Graph& graph) : graph_(graph) {}
std::unique_ptr<ExecutionPlan> createPlan();
// get list of values we can't free
static c10::FastSet<ValueId> staticValues(const Graph& graph);
private:
void generateDeallocationPlan(ExecutionPlan& plan);
// NYI
void generatedMemoryPlan(ExecutionPlan& plan);
const Graph& graph_;
};
std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan);
} // namespace torch::nativert
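
A sketch (illustrative only; it assumes a frame and a node-ordered kernel list as in the files above) of how an executor can consume the deallocation plan, freeing each intermediate right after the last node that uses it.

void executeWithDeallocation(
    const torch::nativert::Graph& graph,
    torch::nativert::ExecutionFrame& frame,
    const std::vector<std::unique_ptr<torch::nativert::OpKernel>>& kernels) {
  auto plan = torch::nativert::ExecutionPlanner{graph}.createPlan();
  for (size_t nodeIdx = 0; nodeIdx < kernels.size(); ++nodeIdx) {
    kernels[nodeIdx]->compute(frame);
    // Free every intermediate whose last use was this node.
    for (const auto valueId : plan->valuesToFree[nodeIdx]) {
      frame.releaseValue(valueId);
    }
  }
}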

View File

@ -0,0 +1,273 @@
#include "torch/csrc/nativert/executor/Executor.h"
#include <memory>
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/common/AutoTimer.h"
#include "torch/csrc/nativert/common/Enumerate.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
#include "torch/csrc/nativert/executor/Weights.h"
namespace torch::nativert {
Executor::Executor(
ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
std::shared_ptr<Weights> weights,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader)
: executorConfig_(std::move(executorConfig)),
graph_(std::move(graph)),
placement_(placement),
constantFolder_(
executorConfig_.runConstFolding
? std::optional<ConstantFolder>(*graph_)
: std::nullopt),
executionFrames_(executorConfig_.maxNumConcurrentThreads),
numExecutionFrames_(0) {
if (weights) {
initialize(std::move(weights), std::move(pytorchStreamReader));
}
}
void Executor::initialize(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) {
AutoTimer t("Initialization completed");
auto executionKernels = KernelFactory().initializeNodeKernels(
*graph_,
weights,
executorConfig_,
placement_,
std::move(pytorchStreamReader));
if (constantFolder_.has_value()) {
constantFolder_->unlinkConstants(executionKernels.nodeKernels);
}
if (executorConfig_.maxParallelOps > 1) {
graphExecutor_ = std::make_unique<ParallelGraphExecutor>(
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
} else {
graphExecutor_ = std::make_unique<SerialGraphExecutor>(
*graph_, std::move(executionKernels.nodeKernels), executorConfig_);
}
delegateExecutors_ = std::move(executionKernels.delegateExecutors);
constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions);
// initialize weights_
processWeights(weights);
atomicSwapWeights(std::move(weights));
}
void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
weights_.withLock([&](auto& w) { w = std::move(weights); });
// update weights in delegate executors
for (auto& delegateExecutor : delegateExecutors_) {
delegateExecutor->commitWeights();
}
}
void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
for (auto& execution : constFoldingExecutions_) {
ExecutionFrame constFoldingFrame(execution.executor->graph());
std::vector<c10::IValue> inputs;
inputs.reserve(graph_->signature().inputsToWeights().size());
for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
inputs.push_back(weights->at(name));
}
auto outputs = execution.executor->execute(constFoldingFrame, inputs);
for (const auto& [idx, value] :
enumerate(execution.executor->graph().outputs())) {
weights->updateFoldedConst(value->name(), outputs.at(idx));
}
}
}
void Executor::processWeights(std::shared_ptr<Weights> weights) {
maybeRunConstantFolding(weights);
if (constantFolder_.has_value()) {
constantFolder_->evaluate(*weights);
}
for (auto& delegateExecutor : delegateExecutors_) {
delegateExecutor->processWeights(weights);
}
}
namespace {
void validateInput(
const std::string& inputName,
const at::Tensor& inputTensor,
const TensorMeta& tensorValueMeta) {
CHECK(inputTensor.dtype() == tensorValueMeta.dtype())
<< "Input tensor dtype mismatch for " << inputName << ", expecting "
<< c10::toString(tensorValueMeta.dtype()) << " but got "
<< inputTensor.dtype().name();
CHECK(inputTensor.device() == tensorValueMeta.device())
<< "Input tensor device mismatch for " << inputName << ", expecting "
<< tensorValueMeta.device().str() << " but got "
<< inputTensor.device().str();
}
} // namespace
// validate input tensor's dtype matches tensorMeta
void Executor::validateInputs(const std::vector<c10::IValue>& inputs) const {
const auto& inputValues = graph_->userInputs();
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
TORCH_CHECK(inputs.size() == inputValues.size(), "Input size mismatch");
for (auto&& [i, actualInput] : enumerate(inputs)) {
if (actualInput.isTensor()) {
const auto& inputName = std::string(inputValues[i]->name());
auto it = tensorValuesMeta.find(inputName);
CHECK(it != tensorValuesMeta.end())
<< "Couldn't find " << inputName << " in tensorValuesMeta";
validateInput(inputName, actualInput.toTensor(), it->second);
}
}
}
std::unique_ptr<ExecutionFrame> Executor::getExecutorFrameFromPool() {
std::shared_ptr<Weights> weights;
weights_.withLock([&](auto& w) { weights = w; });
std::unique_ptr<ExecutionFrame> frame;
while (!executionFrames_.readIfNotEmpty(frame)) {
int numFrames = numExecutionFrames_.load();
if (numFrames < executorConfig_.maxNumConcurrentThreads) {
if (numExecutionFrames_.compare_exchange_strong(
numFrames, numFrames + 1)) {
return std::make_unique<ExecutionFrame>(*graph_, *weights);
}
} else {
sem_.acquire();
}
}
if (frame->weightVersion() != weights->version()) {
frame->setWeights(*weights);
}
return frame;
}
void Executor::returnExecutorFrameToPool(
std::unique_ptr<ExecutionFrame> frame) {
if (executorConfig_.enableStaticCPUKernels) {
frame->releaseUserOutputs();
}
CHECK(executionFrames_.writeIfNotFull(std::move(frame)))
<< "ExecutionFrame pool full";
sem_.release();
}
std::unique_ptr<ExecutionFrame> Executor::executeInternal(
std::vector<c10::IValue> inputs) {
if (executorConfig_.validateInputs) {
validateInputs(inputs);
}
auto executionFrame = getExecutorFrameFromPool();
try {
graphExecutor_->execute(*executionFrame, std::move(inputs));
} catch (...) {
returnExecutorFrameToPool(std::move(executionFrame));
throw;
}
return executionFrame;
}
std::vector<c10::IValue> Executor::execute(std::vector<c10::IValue> inputs) {
auto executionFrame = executeInternal(std::move(inputs));
auto outputs = executionFrame->getUserOutputs();
returnExecutorFrameToPool(std::move(executionFrame));
return outputs;
}
std::vector<c10::IValue> Executor::execute(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& inputTreeSpec) {
auto executionFrame = getExecutorFrameFromPool();
std::optional<std::vector<c10::IValue>> outputs;
try {
const auto& userInputs = graph_->userInputs();
const auto& tensorValuesMeta = graph_->tensorValuesMeta();
TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numLeaves());
size_t input_idx = 0;
auto executionFrameFillUserInputs = [&](const c10::IValue& leaf) {
auto value = userInputs[input_idx];
// skip if the value is not used
if (value && value->users().size() > 0) {
// validate input tensor's dtype and device matches tensorMeta
if (executorConfig_.validateInputs && leaf.isTensor()) {
const auto& inputName = std::string(value->name());
auto it = tensorValuesMeta.find(inputName);
CHECK(it != tensorValuesMeta.end())
<< "Couldn't find " << inputName << " in tensorValuesMeta";
validateInput(inputName, leaf.toTensor(), it->second);
}
executionFrame->setIValue(value->id(), leaf);
}
input_idx++;
};
leafApplyFromArgs(
executionFrameFillUserInputs, args, kwargs, inputTreeSpec);
outputs = graphExecutor_->executeWithPrefilledFrame(*executionFrame);
} catch (...) {
returnExecutorFrameToPool(std::move(executionFrame));
throw;
}
returnExecutorFrameToPool(std::move(executionFrame));
return *outputs;
}
ProfileMetrics Executor::benchmarkIndividualNodes(
std::vector<std::vector<c10::IValue>> inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns) {
CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";
for (const auto& inputs : inputsList) {
for (const auto& input : inputs) {
CHECK(input.isTensor() || input.isCustomClass())
<< "For now, all graph inputs should be tensor, or custom class object, but got "
<< input.tagKind();
}
if (executorConfig_.validateInputs) {
validateInputs(inputs);
}
}
auto executionFrame = getExecutorFrameFromPool();
auto benchmarkResults = graphExecutor_->benchmarkIndividualNodes(
*executionFrame, inputsList, warmupRuns, mainRuns);
returnExecutorFrameToPool(std::move(executionFrame));
return benchmarkResults;
}
std::vector<DelegateExecutor*> Executor::getDelegates() {
std::vector<DelegateExecutor*> delegates;
for (const auto& delegateExecutor : delegateExecutors_) {
delegates.push_back(delegateExecutor.get());
}
return delegates;
}
} // namespace torch::nativert

View File

@ -0,0 +1,119 @@
#pragma once
#include <atomic>
#include <memory>
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "torch/csrc/nativert/common/MPMCQueue.h"
#include "torch/csrc/nativert/common/Semaphore.h"
#include "torch/csrc/nativert/executor/ConstantFolder.h"
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
#include "torch/csrc/nativert/executor/Placement.h"
#include "torch/csrc/nativert/graph/Graph.h"
#include "torch/csrc/nativert/graph/GraphSignature.h"
#include "torch/csrc/nativert/kernels/KernelFactory.h"
#include "torch/csrc/nativert/common/Pytree.h"
namespace torch::nativert {
class Weights;
struct DistributedRunConfig;
/**
* A very dumb executor. Basically just runs each node in order and contains a
* giant unordered map for every intermediate, no optimizations applied.
*/
class Executor {
public:
// Constructor used for the inference path
Executor(
ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
std::shared_ptr<Weights> weights,
const Placement& placement = Placement(),
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader = nullptr);
std::shared_ptr<Weights> getWeights() {
std::shared_ptr<Weights> ret;
weights_.withLock([&](auto& w) { ret = w; });
return ret;
}
void processWeights(std::shared_ptr<Weights> weights);
void atomicSwapWeights(std::shared_ptr<Weights> weights);
// This API only returns the flattened UserOutputs,
// intended to be used for Inference path
std::vector<c10::IValue> execute(std::vector<c10::IValue> inputs);
std::vector<c10::IValue> execute(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const TreeSpec& inputTreeSpec);
ProfileMetrics benchmarkIndividualNodes(
std::vector<std::vector<c10::IValue>> inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns);
const GraphSignature& graphSignature() const {
return graph_->signature();
}
static std::string className() {
return "Executor.v0";
}
const ExecutorConfig& executorConfig() const {
return executorConfig_;
}
std::vector<DelegateExecutor*> getDelegates();
protected:
ExecutorConfig executorConfig_;
std::shared_ptr<Graph> graph_;
// manages the parameters, buffers and tensor constants
c10::Synchronized<std::shared_ptr<Weights>> weights_;
std::unique_ptr<ExecutionFrame> executeInternal(
std::vector<c10::IValue> inputs);
void initialize(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader);
std::unique_ptr<ExecutionFrame> getExecutorFrameFromPool();
void returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame> frame);
private:
void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
void validateInputs(const std::vector<c10::IValue>& inputs) const;
std::unique_ptr<GraphExecutorBase> graphExecutor_;
const Placement placement_;
// NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;
std::vector<ConstFoldingExecution> constFoldingExecutions_;
std::optional<ConstantFolder> constantFolder_;
Semaphore sem_;
MPMCQueue<std::unique_ptr<ExecutionFrame>> executionFrames_;
std::atomic_int numExecutionFrames_;
};
} // namespace torch::nativert
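
End-to-end, the executor is used roughly as sketched below (illustrative only; the graph and weights are assumed to have been deserialized from a PT2 archive as in ModelRunnerBase).

std::vector<c10::IValue> runFlat(
    std::shared_ptr<torch::nativert::Graph> graph,
    std::shared_ptr<torch::nativert::Weights> weights,
    std::vector<c10::IValue> flatInputs) {
  torch::nativert::ExecutorConfig config;
  config.validateInputs = true;
  config.maxParallelOps = 1; // >1 selects ParallelGraphExecutor internally
  torch::nativert::Executor executor(
      config, std::move(graph), std::move(weights));
  // Inputs and outputs are flattened (pytree leaves), as in the inference path.
  return executor.execute(std::move(flatInputs));
}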

View File

@ -0,0 +1,27 @@
#pragma once
#include <torch/script.h>
namespace torch::nativert {
struct ExecutorConfig {
bool validateInputs = false;
bool debugNan = false;
// allows up to max number of concurrent threads.
int64_t maxNumConcurrentThreads = 8;
// allows up to max number of parallel ops.
int64_t maxParallelOps = 1;
bool enableStaticCPUKernels = false;
bool enableStaticMemoryPlanning = false;
std::string modelName = "unknown";
bool runConstFolding = false;
};
} // namespace torch::nativert

View File

@ -0,0 +1,120 @@
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
#include "torch/csrc/nativert/common/RecordFunction.h"
#include <c10/util/Logging.h>
#include <caffe2/core/timer.h>
namespace torch::nativert {
GraphExecutorBase::GraphExecutorBase(
const Graph& graph,
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
const ExecutorConfig& executorConfig)
: graph_(graph),
nodeKernels_(std::move(nodeKernels)),
executorConfig_(executorConfig),
execPlan_(ExecutionPlanner{graph_}.createPlan()){};
void GraphExecutorBase::fillUserInputs(
ExecutionFrame& frame,
std::vector<c10::IValue> inputs) {
RecordFunction recordFunction("Executor::fillUserInputs");
const auto& inputValues = graph_.userInputs();
TORCH_CHECK_EQ(inputValues.size(), inputs.size());
// load user input tensor into execution frame
for (size_t i = 0; i < inputValues.size(); i++) {
if (inputValues[i]) {
frame.setIValue(inputValues[i]->id(), std::move(inputs[i]));
}
}
}
ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
ExecutionFrame& executionFrame,
std::vector<std::vector<c10::IValue>> inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns) {
// TODO: add support for memory profiling
TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1);
ProfileMetrics results;
const auto numNodes = static_cast<uint32_t>(nodeKernels_.size());
results.timePerNode.resize(numNodes, 0);
if (inputsList.empty()) {
auto i = 0;
for (const auto& nodeKernel : nodeKernels_) {
std::string target(nodeKernel->node()->target());
results.timePerNode[i] = 0;
results.timePerNodeType[target] = 0;
results.instancesPerNodeType[target]++;
if (nodeKernel->hasPrimKernel()) {
results.primNodesCount++;
results.primNodes.insert(target);
} else if (nodeKernel->hasStaticDispatch()) {
results.staticDispatchNodesCount++;
results.staticDispatchNodes.insert(target);
}
i++;
}
results.totalNodesCount = numNodes;
for (const auto& p : results.timePerNodeType) {
const std::string& kind = p.first;
results.percentPerNodeType[kind] = 0;
}
return results;
}
// Warmup
for (auto i = 0; i < warmupRuns; i++) {
for (const auto& inputs : inputsList) {
execute(executionFrame, inputs);
}
}
// Execute kernels
caffe2::Timer timer;
for (auto i = 0; i < mainRuns; i++) {
for (auto inputs : inputsList) {
const auto& inputValues = graph_.userInputs();
TORCH_CHECK_EQ(inputValues.size(), inputs.size());
for (size_t j = 0; j < inputValues.size(); j++) {
executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j]));
}
for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) {
timer.Start();
nodeKernels_[nodeIdx]->compute(executionFrame);
float millis = timer.MilliSeconds();
results.timePerNode[nodeIdx] += millis;
}
}
}
// Summarize results
const float numTotalIters =
(static_cast<float>(mainRuns) * static_cast<float>(inputsList.size()));
for (const auto i : c10::irange(numNodes)) {
const Node* node = nodeKernels_[i]->node();
std::string target(node->target());
results.timePerNode[i] /= numTotalIters;
results.timePerNodeType[target] += results.timePerNode[i];
results.instancesPerNodeType[target]++;
if (nodeKernels_[i]->hasPrimKernel()) {
results.primNodes.insert(target);
results.primNodesCount++;
} else if (nodeKernels_[i]->hasStaticDispatch()) {
results.staticDispatchNodes.insert(target);
results.staticDispatchNodesCount++;
}
results.totalTime += results.timePerNode[i];
}
results.totalNodesCount = numNodes;
for (const auto& r : results.timePerNodeType) {
const std::string& target = r.first;
results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
}
return results;
}
} // namespace torch::nativert

View File

@ -0,0 +1,79 @@
#pragma once
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
#include "torch/csrc/nativert/graph/Graph.h"
#include "torch/csrc/nativert/graph/GraphSignature.h"
namespace torch::nativert {
struct ProfileMetrics {
size_t primNodesCount{0};
size_t staticDispatchNodesCount{0};
size_t totalNodesCount{0};
std::vector<float> timePerNode;
std::unordered_map<std::string, float> timePerNodeType;
std::unordered_map<std::string, float> percentPerNodeType;
std::unordered_map<std::string, int> instancesPerNodeType;
std::unordered_set<std::string> staticDispatchNodes;
std::unordered_set<std::string> primNodes;
float totalTime{0};
};
/**
* GraphExecutor is a lightweight abstraction to execute a graph with
* execution frames without owning either the graph or the weights. This is
* introduced to decouple the state management of the top-level runtime from the
* kernel executions so that subgraphs from higher-order ops can be supported.
*/
class GraphExecutorBase {
public:
GraphExecutorBase(
const Graph& graph,
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
const ExecutorConfig& executorConfig);
virtual ~GraphExecutorBase() = default;
const Graph& graph() const {
return graph_;
}
// This API only returns the flattened UserOutputs,
// intended to be used for Inference path
virtual std::vector<c10::IValue> execute(
ExecutionFrame& frame,
std::vector<c10::IValue> inputs) = 0;
virtual std::vector<c10::IValue> executeWithPrefilledFrame(
ExecutionFrame& frame) = 0;
ProfileMetrics benchmarkIndividualNodes(
ExecutionFrame& executionFrame,
std::vector<std::vector<c10::IValue>> inputs,
const uint32_t warmup_runs,
const uint32_t main_runs);
std::vector<std::unique_ptr<OpKernel>> stealKernels() {
return std::move(nodeKernels_);
}
void setKernels(std::vector<std::unique_ptr<OpKernel>>&& kernels) {
nodeKernels_ = std::move(kernels);
}
protected:
void fillUserInputs(ExecutionFrame& frame, std::vector<c10::IValue> inputs);
const Graph& graph_;
// cache of the constructed kernels to avoid reconstruction per execution
std::vector<std::unique_ptr<OpKernel>> nodeKernels_;
const ExecutorConfig& executorConfig_;
std::unique_ptr<ExecutionPlan> execPlan_;
};
} // namespace torch::nativert
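
For profiling, the per-node benchmark is typically driven from the top-level Executor, which forwards to GraphExecutorBase::benchmarkIndividualNodes. A hedged example of reading the resulting ProfileMetrics (the helper below is not part of this patch; it assumes glog-style logging as used elsewhere in these files):

void printNodeTypeTimes(
    torch::nativert::Executor& executor,
    std::vector<std::vector<c10::IValue>> inputsList) {
  auto metrics = executor.benchmarkIndividualNodes(
      std::move(inputsList), /*warmupRuns=*/2, /*mainRuns=*/10);
  for (const auto& [target, ms] : metrics.timePerNodeType) {
    LOG(INFO) << target << ": " << ms << " ms ("
              << metrics.percentPerNodeType.at(target) << "%)";
  }
  LOG(INFO) << "total: " << metrics.totalTime << " ms over "
            << metrics.totalNodesCount << " nodes";
}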

View File

@ -0,0 +1,389 @@
#include "torch/csrc/nativert/executor/ModelRunnerBase.h"
#include "torch/csrc/nativert/common/RecordFunction.h"
#include <c10/util/Logging.h>
#include <caffe2/serialize/inline_container.h>
#include <fmt/format.h>
#include "torch/csrc/nativert/executor/Weights.h"
#include "torch/csrc/nativert/graph/TensorMeta.h"
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
namespace torch::nativert {
ModelRunnerBase::ModelRunnerBase(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
const std::function<Placement(const torch::nativert::Graph& graph)>&
buildPlacementFn)
: modelName_(modelName),
executorType_(executorType),
runtimeConfigs_(runtimeConfigs) {}
void ModelRunnerBase::setExecutorType(
ExecutorType type,
const std::string& platformArch) {
LOG(INFO) << fmt::format(
"Setting executor type to {} with platformArch='{}'", type, platformArch);
executorType_ = type;
if (type == ExecutorType::AOTINDUCTOR) {
runtimeConfigs_.platformArch = platformArch;
} else if (type == ExecutorType::MTIA) {
// TODO: hardcoded for now (Sigmoid packages specify platformArch as
// "mtia")
runtimeConfigs_.platformArch = "mtia";
}
}
const std::string& ModelRunnerBase::getModelName() const {
return modelName_;
}
std::shared_ptr<Weights> ModelRunnerBase::getWeights() {
if (executor_ != nullptr) {
return executor_->getWeights();
} else if (newWeights_ != nullptr) {
return newWeights_;
} else {
TORCH_CHECK(
false, "ModelRunner is not initialized, and no weights are loaded.");
}
}
std::shared_ptr<Weights> ModelRunnerBase::loadNewWeights(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageStreamReader,
std::function<bool(const std::string&)> skipSizeCheck,
std::function<bool(const std::string&)> skipDtypeCheck) {
LOG(INFO) << "ModelRunner loading new weights";
newWeights_ = std::make_shared<Weights>(
graph_.get(),
packageStreamReader,
stateDictPath_,
archive_spec::WEIGHTS_DIR,
constantPaths_,
archive_spec::CONSTANTS_DIR,
placement_,
std::move(skipSizeCheck),
std::move(skipDtypeCheck));
return newWeights_;
}
void ModelRunnerBase::commitNewWeights() {
CHECK(newWeights_) << "No new weights loaded";
CHECK(executor_) << "ModelRunner not initialized";
executor_->processWeights(newWeights_);
executor_->atomicSwapWeights(std::move(newWeights_));
newWeights_ = nullptr;
}
bool ModelRunnerBase::loadExtraFiles(
ExtraFilesMap& extraFiles,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) {
auto filesExist = false;
for (const auto& kv : extraFiles) {
const auto key = std::string{archive_spec::EXTRA_DIR} + kv.first;
if (pytorchStreamReader->hasRecord(key)) {
auto [metaPtr, metaSize] = pytorchStreamReader->getRecord(key);
extraFiles[kv.first] =
std::string(static_cast<char*>(metaPtr.get()), metaSize);
filesExist = true;
}
}
return filesExist;
}
c10::IValue ModelRunnerBase::run(
const std::unordered_map<std::string, c10::IValue>& kwargs,
const RunConfigs& runConfigs) {
return run({}, kwargs, runConfigs);
}
c10::IValue ModelRunnerBase::run(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const RunConfigs& runConfigs) {
RecordFunction recordFunction("nativert::ModelRunner::run");
CHECK(executor_) << "ModelRunner not initialized";
// ModelRunner is only used for inference
c10::InferenceMode mode;
return treeUnflatten(
executor_->execute(args, kwargs, inputSpec_), outputSpec_);
}
std::vector<c10::IValue> ModelRunnerBase::runWithFlatInputsAndOutputs(
std::vector<c10::IValue>&& flatInputs,
const RunConfigs& /* runConfigs */) {
RecordFunction recordFunction(
"nativert::ModelRunner::runWithFlatInputsAndOutputs");
CHECK(executor_) << "ModelRunner not initialized";
// ModelRunner is only used for inference
c10::InferenceMode mode;
return executor_->execute(std::move(flatInputs));
}
void ModelRunnerBase::benchmarkIndividualNodes(
const std::vector<std::vector<c10::IValue>>& argsList,
const std::vector<std::unordered_map<std::string, c10::IValue>>& kwargsList,
const uint32_t warmupRuns,
const uint32_t mainRuns,
const bool printPerNodeTime,
const RunConfigs& runConfigs) {
std::vector<std::vector<c10::IValue>> flatInputsList;
for (const auto& args : argsList) {
if (!kwargsList.empty()) {
for (const auto& kwargs : kwargsList) {
flatInputsList.emplace_back(
treeFlattenFromArgs(args, kwargs, inputSpec_));
}
} else {
flatInputsList.emplace_back(treeFlattenFromArgs(args, {}, inputSpec_));
}
}
c10::InferenceMode mode;
auto results =
executor_->benchmarkIndividualNodes(flatInputsList, warmupRuns, mainRuns);
if (printPerNodeTime) {
size_t i = 0;
for (const auto& node : graph_->nodes()) {
LOG(INFO) << "Node #" << i << ": " << node.toString()
<< "\n Time: " << results.timePerNode[i] << " ms/iter, ";
i++;
}
}
std::vector<std::pair<std::string, double>> sortedTimePerOp{
results.timePerNodeType.begin(), results.timePerNodeType.end()};
if (argsList.empty()) {
// sort by number of instances per node type
std::sort(
sortedTimePerOp.begin(),
sortedTimePerOp.end(),
[&results](auto& left, auto& right) {
return results.instancesPerNodeType[left.first] >
results.instancesPerNodeType[right.first];
});
} else {
// sort by time
std::sort(
sortedTimePerOp.begin(),
sortedTimePerOp.end(),
[](auto& left, auto& right) { return left.second > right.second; });
}
LOG(INFO) << "Time per node type:" << '\n';
std::ostringstream unsupportedNodeKinds;
for (const auto& p : sortedTimePerOp) {
const std::string& kind = p.first;
const double ms = p.second;
std::ostringstream oss;
oss << std::setw(15) << ms << " ms. " << std::setw(10)
<< results.percentPerNodeType[kind] << "%. " << kind << " ("
<< results.instancesPerNodeType[kind] << " nodes";
if (results.primNodes.count(kind)) {
oss << ", prim) \n";
} else if (results.staticDispatchNodes.count(kind)) {
oss << ", static dispatch) \n";
} else {
unsupportedNodeKinds << kind << ", ";
oss << ")\n";
}
LOG(INFO) << oss.str();
}
LOG(INFO) << std::setw(15) << results.totalTime << " ms. in Total" << '\n';
LOG(INFO) << "Number of nodes: " << graph_->nodes().size() << '\n';
auto unsupportedCount = results.totalNodesCount -
results.staticDispatchNodesCount - results.primNodesCount;
LOG(INFO) << "Total number of static dispatch nodes/total number of nodes: "
<< results.staticDispatchNodesCount << "/"
<< results.totalNodesCount << " ("
<< 100.0 * static_cast<float>(results.staticDispatchNodesCount) /
static_cast<float>(results.totalNodesCount)
<< "%)" << '\n';
LOG(INFO) << "Total number of prim nodes/total number of nodes: "
<< results.primNodesCount << "/" << results.totalNodesCount << " ("
<< 100.0 * static_cast<float>(results.primNodesCount) /
static_cast<float>(results.totalNodesCount)
<< "%)" << '\n';
LOG(INFO)
<< "Total number of nodes not covered by static dispatch/total number of nodes: "
<< unsupportedCount << "/" << results.totalNodesCount << " ("
<< 100.0 * static_cast<float>(unsupportedCount) /
static_cast<float>(results.totalNodesCount)
<< "%)" << "\n Uncovered node kinds: " << unsupportedNodeKinds.str()
<< '\n';
}
std::vector<std::optional<c10::Device>>
ModelRunnerBase::getUserInputTargetDevices() const {
std::vector<std::optional<c10::Device>> devices;
for (const auto& tensorMeta : graph_->userInputsMeta()) {
c10::Device targetDevice = placement_.getMappedDevice(tensorMeta.device());
devices.push_back(targetDevice);
}
return devices;
}
std::pair<
std::vector<c10::IValue>,
std::unordered_map<std::string, c10::IValue>>
ModelRunnerBase::loadSampleInputs(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const Placement& placement) {
LOG(INFO) << "Loading sample inputs in ModelRunner for model " << modelName_;
std::string sampleInputsPath =
fmt::format(archive_spec::SAMPLE_INPUTS_FILENAME_FORMAT, modelName_);
CHECK(pytorchStreamReader->hasRecord(sampleInputsPath))
<< sampleInputsPath << " is not found in package";
size_t size = pytorchStreamReader->getRecordSize(sampleInputsPath);
std::vector<char> buffers(size);
size_t sizeRead =
pytorchStreamReader->getRecord(sampleInputsPath, buffers.data(), size);
CHECK(sizeRead == size);
c10::IValue value = torch::jit::pickle_load(buffers);
// Move userInputs on the target device
std::vector<TensorMeta> userInputsMeta = graph_->userInputsMeta();
size_t tensorInputId = 0;
value = treeMap(
[&](const c10::IValue& inputVal) -> c10::IValue {
if (inputVal.isTensor()) {
auto& tensorMeta = userInputsMeta[tensorInputId];
c10::Device targetDevice =
placement.getMappedDevice(tensorMeta.device());
auto r = inputVal.toTensor().to(targetDevice);
VLOG(1) << "input #" << tensorInputId << " has been placed on "
<< targetDevice;
tensorInputId++;
return r;
} else {
return inputVal;
}
},
value,
inputSpec_);
CHECK(value.isTuple());
CHECK(value.toTupleRef().elements().size() == 2);
const auto& argsVal = value.toTupleRef().elements().at(0);
const auto& kwargsVal = value.toTupleRef().elements().at(1);
CHECK(argsVal.isTuple());
CHECK(kwargsVal.isGenericDict());
std::vector<c10::IValue> args;
for (const auto& arg : argsVal.toTupleRef().elements()) {
args.push_back(arg);
}
std::unordered_map<std::string, c10::IValue> kwargs;
for (const auto& entry : kwargsVal.toGenericDict()) {
kwargs[entry.key().toStringRef()] = entry.value();
}
return {std::move(args), std::move(kwargs)};
}
const std::vector<std::string>& ModelRunnerBase::getArgumentNames() const {
return graph_->signature().userInputsVec();
}
c10::ArrayRef<const Value*> ModelRunnerBase::getArguments() const {
return graph_->userInputs();
}
const TreeSpec& ModelRunnerBase::getOutputSpec() const {
return outputSpec_;
}
const TreeSpec& ModelRunnerBase::getInputSpec() const {
return inputSpec_;
}
std::string ModelRunnerBase::loadSerializedModel(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader)
const {
std::string modelFilePath =
fmt::format(archive_spec::MODELS_FILENAME_FORMAT, modelName_);
LOG(INFO) << "Loading model from: " << modelFilePath;
CHECK(pytorchStreamReader->hasRecord(modelFilePath))
<< modelFilePath << " not found in package";
const auto& [modelData, modelSize] =
pytorchStreamReader->getRecord(modelFilePath);
const std::string modelSerialized{
reinterpret_cast<char*>(modelData.get()), modelSize};
return modelSerialized;
}
void ModelRunnerBase::initialize(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) {
if (executor_ != nullptr) {
LOG(WARNING)
<< "ModelRunner already initialized, re-initialization is an no op.";
return;
}
initializeExecutor(newWeights_, std::move(pytorchStreamReader));
newWeights_ = nullptr;
}
void ModelRunnerBase::initializeExecutor(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) {
CHECK(executor_ == nullptr) << "ModelRunner already initialized";
weights->validateAllWeightsLoaded();
ExecutorConfig config;
config.maxParallelOps = runtimeConfigs_.maxParallelOps;
config.validateInputs = runtimeConfigs_.validateInputs;
config.enableStaticCPUKernels = runtimeConfigs_.enableStaticCPUKernels;
config.maxNumConcurrentThreads = runtimeConfigs_.maxNumConcurrentThreads;
config.enableStaticMemoryPlanning =
runtimeConfigs_.enableStaticMemoryPlanning;
config.runConstFolding = runtimeConfigs_.enableRuntimeConstFolding;
if (executorType_ == ExecutorType::INTERPRETER) {
executor_ = std::make_unique<Executor>(
config, graph_, weights, placement_, pytorchStreamReader);
} else if (executorType_ == ExecutorType::AOTINDUCTOR) {
delegateGraph_ = deserializeDelegateGraph();
delegateGraph_->applyDevicePlacement(placement_);
VLOG(1) << "Delegate graph: \n" << *delegateGraph_;
executor_ = std::make_unique<Executor>(
config, delegateGraph_, weights, placement_, pytorchStreamReader);
} else if (executorType_ == ExecutorType::MTIA) {
delegateGraph_ = deserializeDelegateGraph();
delegateGraph_->applyDevicePlacement(placement_);
VLOG(1) << "Delegate graph: \n" << *delegateGraph_;
config.modelName = modelName_;
executor_ = std::make_unique<Executor>(
config, delegateGraph_, weights, placement_, pytorchStreamReader);
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,260 @@
#pragma once
#include <unordered_map>
#include <vector>
#include <fmt/format.h>
#include "torch/csrc/nativert/common/Pytree.h"
#include "torch/csrc/nativert/executor/Executor.h"
#include "torch/csrc/nativert/executor/Placement.h"
namespace torch::nativert {
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
enum class ExecutorType {
INTERPRETER = 0,
AOTINDUCTOR = 1,
MTIA = 2,
};
struct BaseRuntimeConfigs {
bool isDebug = false;
bool validateInputs = false;
// use static kernels
bool enableStaticCPUKernels = false;
// whether to enable static memory planning
bool enableStaticMemoryPlanning = false;
// whether to load each node's metadata, e.g. stacktrace.
// This is only used for debugging purposes. For production, we commonly set
// this to false, as it would incur extra memory usage.
bool loadNodeMetadata = false;
// whether to initialize the executor in the constructor
// In some cases, Weights are not available when the constructor is called.
// In this case, the executor should be initialized later.
bool deferInitialization = false;
// platform arch for delegate, e.g "cpu", "sm80_x86" etc
// See https://fburl.com/code/3pym9ipj for supported platforms
std::string platformArch;
// allows up to max number of concurrent threads.
int64_t maxNumConcurrentThreads = 8;
// allows up to max number of parallel ops.
int64_t maxParallelOps = 1;
// whether to enable runtime const folding
bool enableRuntimeConstFolding = false;
};
struct RunConfigs {};
class TORCH_API ModelRunnerBase {
public:
ModelRunnerBase(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader,
const std::string& modelName,
ExecutorType executorType,
const BaseRuntimeConfigs& runtimeConfigs,
// functor to build the placement after the graph is loaded, but before
// loading the weights.
const std::function<Placement(const torch::nativert::Graph& graph)>&
buildPlacementFn);
ModelRunnerBase(ModelRunnerBase&&) = default;
ModelRunnerBase& operator=(ModelRunnerBase&&) = default;
ModelRunnerBase(const ModelRunnerBase&) = delete;
ModelRunnerBase& operator=(const ModelRunnerBase&) = delete;
virtual ~ModelRunnerBase() = default;
const std::string& getModelName() const;
std::shared_ptr<Weights> getWeights();
// loadNewWeights() loads the weights from the specified model into
// the newWeights_ buffer. The weights stay in a shadow (inactive) state until
// they are activated by the commitNewWeights() function.
std::shared_ptr<Weights> loadNewWeights(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
packageStreamReader,
std::function<bool(const std::string&)> skipSizeCheck = {},
std::function<bool(const std::string&)> skipDtypeCheck = {});
void commitNewWeights();
c10::IValue run(
const std::unordered_map<std::string, c10::IValue>& kwargs,
const RunConfigs& runConfigs = RunConfigs());
c10::IValue run(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const RunConfigs& runConfigs = RunConfigs());
/**
* A low level API which expects user to always pass in flattened inputs.
* The ownership of the entire input list must be transferred to the
* executor via std::move or in-place construction.
*/
std::vector<c10::IValue> runWithFlatInputsAndOutputs(
std::vector<c10::IValue>&& flatInputs,
const RunConfigs& runConfigs = RunConfigs());
void benchmarkIndividualNodes(
const std::vector<std::vector<c10::IValue>>& argsList,
const std::vector<std::unordered_map<std::string, c10::IValue>>&
kwargsList,
const uint32_t warmupRuns,
const uint32_t mainRuns,
const bool printPerNodeTime,
const RunConfigs& runConfigs);
std::vector<std::optional<c10::Device>> getUserInputTargetDevices() const;
std::pair<
std::vector<c10::IValue>,
std::unordered_map<std::string, c10::IValue>>
loadSampleInputs(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader,
const Placement& placement = Placement());
const std::vector<std::string>& getArgumentNames() const;
c10::ArrayRef<const Value*> getArguments() const;
/*
* Load extra files indicated in extraFiles from the model package.
* Return true if any extra files were loaded
* and false otherwise
*/
bool loadExtraFiles(
ExtraFilesMap& extraFiles,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader);
const TreeSpec& getOutputSpec() const;
const TreeSpec& getInputSpec() const;
void setExecutorType(ExecutorType type, const std::string& platformArch = "");
ExecutorType getExecutorType() const {
return executorType_;
}
void setEnableStaticDispatchKernels(bool enabled) {
runtimeConfigs_.enableStaticCPUKernels = enabled;
}
void setEnableStaticMemoryPlanning(bool enabled) {
runtimeConfigs_.enableStaticMemoryPlanning = enabled;
}
template <typename T>
std::vector<T*> getDelegates() {
std::vector<T*> delegates;
for (const auto& delegate : executor_->getDelegates()) {
if (auto* d = dynamic_cast<T*>(delegate)) {
delegates.push_back(d);
}
}
return delegates;
}
// Manually initialize the executor when config.deferInitialization is True.
//
// initialize() must be called after
// - weights are fully loaded
// - executor is selected via setExecutorType()
// ModelRunner is not ready to serve with run() until initialized.
//
// Note that pytorchStreamReader is required to load lowered modules.
//
// If initialization fails at any point, it will throw an exception. The caller
// should catch the exception and call initialize() again with valid weights
// and executor type.
//
// ModelRunner can only be initialized once; it is a no-op to call this
// function when ModelRunner is already ready to serve.
void initialize(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader = nullptr);
protected:
#ifdef ModelRunnerTest_TEST_FRIENDS
ModelRunnerTest_TEST_FRIENDS;
#endif
virtual std::unique_ptr<Graph> deserializeDelegateGraph() const = 0;
void initializeExecutor(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader);
std::string loadSerializedModel(
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) const;
std::string modelName_;
ExecutorType executorType_;
BaseRuntimeConfigs runtimeConfigs_;
Placement placement_;
// original non-delegated graph from torch.export()
std::shared_ptr<Graph> graph_;
// graph with delegated nodes after lowering/compilation
std::shared_ptr<Graph> delegateGraph_;
// key is weight name, value is archive path for weight
std::unordered_map<std::string, std::string> stateDictPath_;
// contains both tensor constants and CustomClassHolder (aka. torchbind
// object)
std::unordered_map<std::string, std::string> constantPaths_;
std::unique_ptr<Executor> executor_;
// recently loaded and not yet committed weights.
std::shared_ptr<Weights> newWeights_;
TreeSpec inputSpec_;
TreeSpec outputSpec_;
};
} // namespace torch::nativert
template <>
struct fmt::formatter<torch::nativert::ExecutorType> {
template <typename ParseContext>
constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}
template <typename FormatContext>
auto format(torch::nativert::ExecutorType et, FormatContext& ctx) const {
using namespace torch::nativert;
switch (et) {
case ExecutorType::INTERPRETER:
return format_to(ctx.out(), "INTERPRETER");
case ExecutorType::AOTINDUCTOR:
return format_to(ctx.out(), "AOTINDUCTOR");
case ExecutorType::MTIA:
return format_to(ctx.out(), "MTIA");
default:
return format_to(ctx.out(), "UNKNOWN");
}
}
};
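
A sketch of the intended in-place weight-update flow on a concrete ModelRunner (the runner, stream reader, and arguments below are placeholders for whatever the embedding application constructs; this helper is not part of the patch):

void refreshWeightsAndRun(
    torch::nativert::ModelRunnerBase& runner,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> packageReader,
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  // Load new weights into the shadow buffer; serving still uses the old ones.
  runner.loadNewWeights(packageReader);
  // Activate the shadow weights (also commits delegate weights).
  runner.commitNewWeights();
  // Serve with the refreshed weights.
  c10::IValue out = runner.run(args, kwargs);
  (void)out;
}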

View File

@ -0,0 +1,168 @@
#include "torch/csrc/nativert/executor/OpKernel.h"
#include <c10/util/Logging.h>
#include <fmt/ostream.h>
#include "torch/csrc/nativert/common/ConfigUtils.h"
#include "torch/csrc/nativert/common/Enumerate.h"
#include "torch/csrc/nativert/common/String.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#ifdef __SIGRID_USE_GPU__
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#endif
namespace torch::nativert {
c10::OperatorHandle getOperatorForTarget(
std::string_view target,
const Node* node) {
// target could come as either "torch.ops.aten.add.default" or
// "aten.add.default"
std::vector<std::string_view> atoms = split(target, '.');
size_t numAtoms = atoms.size();
if (numAtoms < 3) {
TORCH_CHECK(false, "Invalid target: ", target);
}
const std::string_view ns = atoms[numAtoms - 3];
const std::string_view opName = atoms[numAtoms - 2];
const std::string_view overloadName = atoms[numAtoms - 1];
const auto operatorName = fmt::format("{}::{}", ns, opName);
std::string normalizedOverloadName;
if (overloadName == "default") {
normalizedOverloadName = "";
} else {
normalizedOverloadName = overloadName;
}
auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
operatorName.c_str(), normalizedOverloadName.c_str());
return handle;
}
std::string readableArgs(
const c10::FunctionSchema& schema,
const std::vector<c10::IValue>& stack) {
auto schemaArgs = schema.arguments();
std::stringstream ss;
for (const auto& [i, arg] : enumerate(stack)) {
ss << "arg" << i << " " << schemaArgs[i].name() << ": " << arg.tagKind()
<< " ";
if (arg.isTensor()) {
auto t = arg.toTensor();
ss << t.dtype() << t.sizes() << t.device();
} else if (arg.isTensorList()) {
auto tl = arg.toTensorVector();
ss << "[";
for (const auto& t : tl) {
ss << t.dtype() << t.sizes() << t.device() << ", ";
}
ss << "]";
} else if (arg.isNone()) {
// pass
} else {
ss << arg;
}
ss << "\n";
}
return ss.str();
}
const bool OpKernel::blockingEnabled_ =
maybeGetEnv("TORCH_NATIVE_RUNTIME_CUDA_LAUNCH_BLOCKING").value_or("0") == "1";
void OpKernel::compute(ExecutionFrame& executionFrame) const {
VLOG(2) << "Executing: " << *node_;
computeInternal(executionFrame);
#ifdef __SIGRID_USE_GPU__
if (device_.has_value() && device_->is_cuda() && blockingEnabled_) {
AT_CUDA_CHECK(cudaDeviceSynchronize());
AT_CUDA_CHECK(cudaGetLastError());
}
#endif
VLOG(2) << "Completed: " << *node_;
}
Arguments prefillStackWithStaticArgs(
const Node* node,
const c10::FunctionSchema& schema) {
std::vector<c10::IValue> stackWithStaticArgs;
std::vector<Value*> dynamicArgs;
const auto& schemaArgs = schema.arguments();
stackWithStaticArgs.resize(schemaArgs.size());
dynamicArgs.resize(schemaArgs.size());
// initialize stackWithStaticArgs with static inputs
for (const auto& [idx, schemaArg] : enumerate(schemaArgs)) {
const auto& argName = schemaArg.name();
// Check if this is a dynamic input to the op.
const auto input = node->tryGetInput(argName);
if (input != nullptr) {
stackWithStaticArgs.at(idx) = c10::IValue();
dynamicArgs.at(idx) = input->value;
continue;
}
// Check if this is a statically known input to the op.
const auto attribute = node->tryGetAttribute(argName);
if (attribute != nullptr) {
stackWithStaticArgs.at(idx) = constantToIValue(attribute->value);
continue;
}
// Otherwise, this must have a default value
if (schemaArg.default_value().has_value()) {
stackWithStaticArgs.at(idx) = schemaArg.default_value().value();
continue;
}
TORCH_CHECK(
false,
"Cannot initialize argument ",
argName,
" for node ",
*node,
" with schema ",
schema);
}
return Arguments{std::move(stackWithStaticArgs), std::move(dynamicArgs)};
}
void fillDynamicInputs(
const ExecutionFrame& executionFrame,
const Arguments& arguments,
std::vector<c10::IValue>& stack) {
// fill the stack with dynamic values from execution frame,
// including tensor, tensors, symint, symints
for (auto [idx, value] : arguments.getDynamicArgs()) {
CHECK(idx < stack.size()) << "invalid idx";
CHECK(stack.at(idx).isNone()) << "stack[idx] shouldn't have been populated";
if (value->type() == Type::TensorList) {
// TODO: This is for passing List<Tensor> as an input to an op that takes a
// List<Optional<Tensor>>.
// this is awful, but if I don't cast it to a vector and back to a
// list, I get list covariance problems where List<Tensor> is not a
// subtype of List<Optional<Tensor>>, which pops up when trying to
// execute aten.index.Tensor. See the code in:
// https://fburl.com/code/t1poq3z3. Our lists should be covariant
// because they are static, but IValues don't know that :(
stack[idx] = executionFrame.getIValue(value->id()).toTensorList().vec();
} else if (value->type() == Type::None) {
stack[idx] = c10::IValue();
} else {
stack[idx] = executionFrame.getIValue(value->id());
}
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,156 @@
#pragma once
#include <torch/script.h>
#include "c10/core/Device.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
c10::OperatorHandle getOperatorForTarget(
std::string_view target,
const Node* node = nullptr);
class Arguments {
public:
Arguments(
std::vector<c10::IValue> stackWithStaticArgs,
std::vector<Value*> dynamicArgs)
: stackWithStaticArgs_(std::move(stackWithStaticArgs)),
dynamicArgs_(std::move(dynamicArgs)) {
for (size_t i = 0; i < dynamicArgs_.size(); i++) {
if (dynamicArgs_[i]) {
indices_.push_back(i);
}
}
}
/**
* Dynamic arguments are the inputs that were not baked into the graph
* during graph capture, i.e. all the tensor inputs to operators.
*
* This API returns a list of pairs consisting of the argument index
* and the corresponding Value pointer from the graph.
*/
auto getDynamicArgs() const {
std::vector<std::pair<size_t, Value*>> ret;
ret.reserve(indices_.size());
for (auto i : indices_) {
ret.emplace_back(i, dynamicArgs_[i]);
}
return ret;
}
// Argument i means the i-th input to the operator in the argument list.
// Will return nullptr if the argument is not dynamic.
Value* findDynamic(size_t i) const {
DCHECK(i < dynamicArgs_.size()) << "Invalid input index: " << i;
return dynamicArgs_[i];
}
// Argument i means the i-th input to the operator in the argument list.
// Will return None as IValue if the argument is not static.
const c10::IValue& getStatic(size_t i) const {
DCHECK(i < stackWithStaticArgs_.size()) << "Invalid input index: " << i;
return stackWithStaticArgs_[i];
}
/**
* Static arguments are the inputs that were specialized to a fixed value
* during the graph capture phase. For example, scalar and device inputs
* are considered static arguments.
*/
const std::vector<c10::IValue>& getStackWithStaticArgs() const {
return stackWithStaticArgs_;
}
private:
// stack pre-populated with attributes, aka static arguments
const std::vector<c10::IValue> stackWithStaticArgs_;
// Argument can only be asTensor, asTensors, asSymInt, asSymInts
const std::vector<Value*> dynamicArgs_;
std::vector<size_t> indices_;
};
void fillDynamicInputs(
const ExecutionFrame& executionFrame,
const Arguments& arguments,
std::vector<c10::IValue>& stack);
Arguments prefillStackWithStaticArgs(
const Node* node,
const c10::FunctionSchema& schema);
std::string readableArgs(
const c10::FunctionSchema& schema,
const std::vector<c10::IValue>& stack);
// Abstract interface representing a kernel, which is responsible for executing
// a single Node.
class OpKernel {
public:
explicit OpKernel(
const Node* node,
std::optional<c10::Device> device = std::nullopt)
: node_(node), device_(device) {
VLOG(1) << "Initializing kernel for node: " << *node_;
}
enum class Kind : uint8_t {
kPrimKernel,
kStaticDispatchKernel,
kInterpreterFallbackKernel,
};
const Node* node() const {
return node_;
}
void compute(ExecutionFrame& executionFrame) const;
Kind kind() const {
return kind_;
}
bool hasPrimKernel() const {
return kind() == Kind::kPrimKernel;
}
bool hasStaticDispatch() const {
return kind() == Kind::kStaticDispatchKernel;
}
size_t numInputs() const {
return node_->inputs().size();
}
size_t numOutputs() const {
return node_->outputs().size();
}
// Input is readonly
[[nodiscard]] virtual const c10::IValue& input(
uint32_t i,
ExecutionFrame& executionFrame) const {
CHECK(i < numInputs()) << "Invalid input index: " << i;
return executionFrame.getIValue(node_->inputs()[i].value->id());
}
// Output is readwrite
c10::IValue& output(uint32_t i, ExecutionFrame& executionFrame) const {
CHECK(i < numOutputs()) << "Invalid output index: " << i;
return executionFrame.getIValue(node_->outputs()[i]->id(), true);
}
virtual ~OpKernel() = default;
protected:
virtual void computeInternal(ExecutionFrame& executionFrame) const = 0;
const Node* node_;
std::optional<c10::Device> device_;
const static bool blockingEnabled_;
Kind kind_ = Kind::kInterpreterFallbackKernel;
};
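// Illustrative subclass sketch (assumed usage, not part of this header): a
// concrete kernel overrides computeInternal() and reads/writes values through
// the input()/output() accessors.
//
//   class MyReluKernel : public OpKernel {
//    public:
//     explicit MyReluKernel(const Node* node) : OpKernel(node) {
//       kind_ = Kind::kStaticDispatchKernel;
//     }
//
//    protected:
//     void computeInternal(ExecutionFrame& frame) const override {
//       output(0, frame) = at::relu(input(0, frame).toTensor());
//     }
//   };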
} // namespace torch::nativert

View File

@ -0,0 +1,245 @@
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
#include "torch/csrc/nativert/common/concurrentqueue.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
namespace {
#define LIKELY(x) (__builtin_expect((x), 1))
#define UNLIKELY(x) (__builtin_expect((x), 0))
#define WITH_LOCK(m, block) \
{ \
std::unique_lock<decltype(m)> lk_(m); \
block \
}
} // namespace
namespace torch::nativert {
ThreadPoolExecutor::ThreadPoolExecutor()
: work_(std::make_unique<moodycamel::ConcurrentQueue<Work>>()) {}
ThreadPoolExecutor::~ThreadPoolExecutor() {
stop();
}
C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
thread_local moodycamel::ProducerToken ptok(*work_);
return ptok;
}
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
thread_local moodycamel::ConsumerToken ctok(*work_);
return ctok;
}
void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
session->addWork();
unit->run(this, session);
}
void ThreadPoolExecutor::start(int32_t numThreads) {
stopped_ = false;
for (int32_t i = 0; i < numThreads; ++i) {
threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this));
}
}
void ThreadPoolExecutor::loop() {
while (true) {
Work unit;
sem_->acquire();
if (stopped_) {
return;
}
while (!work_->try_dequeue(ctok(), unit)) {
};
unit();
}
}
void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
session->addWork();
work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session));
sem_->release();
}
void ThreadPoolExecutor::add(
SessionState* session,
std::vector<WorkUnit*>::const_iterator&& begin,
const std::vector<WorkUnit*>::const_iterator&& end) {
const auto count = end - begin;
switch (count) {
case 0: {
return;
}
case 1: {
return add(session, *begin);
}
}
session->addWork(count);
std::vector<Work> runnables;
runnables.reserve(count);
for (; begin != end; ++begin) {
runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session));
}
work_->enqueue_bulk(ptok(), runnables.begin(), count);
sem_->release(count);
}
void ThreadPoolExecutor::stop() {
stopped_ = true;
sem_->release(threads_.size());
std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
threads_.clear();
{
// reset sem
auto tmp = std::make_unique<Semaphore>();
sem_.swap(tmp);
}
{
// flush queue
auto tmp = moodycamel::ConcurrentQueue<Work>();
work_->swap(tmp);
}
}
void ThreadPoolExecutor::run(
SessionState& session,
const std::vector<WorkUnit*>& roots) {
// case where thread ptok exists but work_ was swapped
if (auto& tok = ptok(); UNLIKELY(!tok.valid())) {
moodycamel::ProducerToken tmp(*work_);
tok.swap(tmp);
}
const auto rootCount = roots.size();
if (UNLIKELY(rootCount == 0)) {
return;
} else if (LIKELY(rootCount > 1)) {
add(&session, roots.begin() + 1, roots.end());
}
execute_inline(&session, roots[0]);
session.wait();
}
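// Runs this unit's kernel, then walks the continuation chain: the first
// successor whose producer count drops to zero is executed inline on the
// current thread, while any additional ready successors are handed back to the
// executor for other threads to pick up.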
void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
thread_local std::vector<WorkUnit*> newWorkUnits;
thread_local c10::InferenceMode mode;
WorkUnit* unit = this;
while (true) {
unit->kernel->compute(session->frame());
for (auto* user : unit->users) {
if (session->decrementProducers(user->node)) {
newWorkUnits.push_back(user);
}
}
switch (newWorkUnits.size()) {
case 0: {
return session->removeWork();
}
case 1: {
break;
}
case 2: {
executor->add(session, newWorkUnits[1]);
break;
}
default: {
executor->add(session, newWorkUnits.begin() + 1, newWorkUnits.end());
break;
}
}
unit = newWorkUnits[0];
newWorkUnits.clear();
}
}
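// The constructor below builds the static dataflow schedule: every node other
// than prim.Input/prim.Output gets a WorkUnit, each unit records its ready
// successors in `users`, producer counts are tallied per node, and the nodes
// with zero producers become the root units kicked off by execute().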
ParallelGraphExecutor::ParallelGraphExecutor(
const Graph& graph,
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
const ExecutorConfig& executorConfig)
: GraphExecutorBase(graph, std::move(nodeKernels), executorConfig),
workUnits_(
graph.nodes().size() - 2 /* no need for prim.Input or prim.Output */),
graph_(graph) {
auto& nodes = graph_.nodes();
auto input = &*nodes.begin();
auto output = &*nodes.rbegin();
{
// get rid of prim.Input and prim.Output kernels
// since we won't be needing them
nodeKernels_.erase(nodeKernels_.begin());
nodeKernels_.pop_back();
}
size_t idx = 0;
for (const auto& node : nodes) {
if (&node == input || &node == output) {
continue;
}
auto& workUnit =
nodeToWorkUnit_.insert_or_assign(&node, &workUnits_[idx]).first->second;
workUnit->node = &node;
workUnit->kernel = nodeKernels_[idx++].get();
producers_.insert({&node, 0});
}
for (auto& unit : workUnits_) {
for (const auto* dep : unit.node->users()) {
if (dep != output) {
unit.users.push_back(nodeToWorkUnit_[dep]);
producers_[dep] += 1;
}
}
}
for (auto& [node, p] : producers_) {
if (p == 0) {
inputWorkUnits_.push_back(nodeToWorkUnit_[node]);
}
}
executor_.start(executorConfig.maxParallelOps);
}
std::vector<c10::IValue> ParallelGraphExecutor::execute(
ExecutionFrame& executionFrame,
std::vector<c10::IValue> inputs) {
fillUserInputs(executionFrame, std::move(inputs));
return executeWithPrefilledFrame(executionFrame);
}
std::vector<c10::IValue> ParallelGraphExecutor::executeWithPrefilledFrame(
ExecutionFrame& executionFrame) {
auto session = SessionState(executionFrame, producers_);
executor_.run(session, inputWorkUnits_);
return executionFrame.getUserOutputs();
}
} // namespace torch::nativert

View File

@ -0,0 +1,94 @@
#pragma once
#include "torch/csrc/nativert/common/Semaphore.h"
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
#include "torch/csrc/nativert/executor/SessionState.h"
namespace moodycamel {
struct ProducerToken;
struct ConsumerToken;
struct ConcurrentQueueDefaultTraits;
template <typename T, typename Traits>
class ConcurrentQueue;
} // namespace moodycamel
namespace torch::nativert {
class ThreadPoolExecutor;
typedef std::function<void()> Work;
struct WorkUnit {
const Node* node;
OpKernel* kernel;
std::vector<WorkUnit*> users;
void run(ThreadPoolExecutor* executor, SessionState* sessionState);
};
class ThreadPoolExecutor {
public:
explicit ThreadPoolExecutor();
~ThreadPoolExecutor();
ThreadPoolExecutor(const ThreadPoolExecutor&) = delete;
ThreadPoolExecutor& operator=(ThreadPoolExecutor const&) = delete;
ThreadPoolExecutor(ThreadPoolExecutor&&) = delete;
ThreadPoolExecutor& operator=(ThreadPoolExecutor&&) = delete;
void run(SessionState& session, const std::vector<WorkUnit*>& roots);
void start(int32_t numThreads);
void stop();
// execute unit on the current thread
// NOTE: children can still be offloaded to other threads
C10_ALWAYS_INLINE void execute_inline(SessionState* session, WorkUnit* unit);
void add(SessionState* session, WorkUnit* unit);
void add(
SessionState* session,
std::vector<WorkUnit*>::const_iterator&& begin,
const std::vector<WorkUnit*>::const_iterator&& end);
C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok();
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok();
private:
void loop();
std::atomic_bool stopped_{false};
std::unique_ptr<Semaphore> sem_{std::make_unique<Semaphore>()};
std::unique_ptr<moodycamel::ConcurrentQueue<
Work,
moodycamel::ConcurrentQueueDefaultTraits>>
work_;
std::vector<std::thread> threads_;
};
class ParallelGraphExecutor : public GraphExecutorBase {
public:
ParallelGraphExecutor(
const Graph& graph,
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
const ExecutorConfig& executorConfig);
std::vector<c10::IValue> execute(
ExecutionFrame& frame,
std::vector<c10::IValue> inputs) override;
std::vector<c10::IValue> executeWithPrefilledFrame(
ExecutionFrame& frame) override;
private:
ThreadPoolExecutor executor_;
std::vector<WorkUnit*> inputWorkUnits_;
c10::FastMap<const Node*, WorkUnit*> nodeToWorkUnit_;
std::vector<WorkUnit> workUnits_;
const Graph& graph_;
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>> producers_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,91 @@
#include "torch/csrc/nativert/executor/Placement.h"
#include <map>
namespace torch::nativert {
std::ostream& operator<<(std::ostream& os, const Placement& placement) {
std::map<std::string, c10::Device> keys;
for (const auto& pair : placement.deviceMap_) {
keys.insert({pair.first.str(), pair.first});
}
bool first = true;
auto checkComma = [&]() {
if (!first) {
os << ",";
}
first = false;
};
os << "";
for (const auto& pair : keys) {
checkComma();
const auto& key = pair.second;
const auto& value = placement.deviceMap_.at(key);
os << pair.first << "|" << value.str();
}
if (placement.defaultDevice_.has_value()) {
checkComma();
os << "|" << placement.defaultDevice_.value().str();
}
return os;
}
c10::Device normalizeDevice(const c10::Device& device) {
// A CPU device never carries an index;
// a CUDA device must always carry one (defaulting to 0).
if (device.is_cpu()) {
return c10::Device(c10::DeviceType::CPU);
} else if (device.is_cuda()) {
return c10::Device(
c10::DeviceType::CUDA, device.has_index() ? device.index() : 0);
} else {
TORCH_CHECK(false, "Unsupported device type", device);
}
}
bool isSameDevice(const c10::Device& a, const c10::Device& b) {
if (a.is_cpu()) {
return b.is_cpu();
}
if (a.is_cuda()) {
if (b.is_cuda()) {
auto aIndex = a.has_index() ? a.index() : 0;
auto bIndex = b.has_index() ? b.index() : 0;
return aIndex == bIndex;
} else {
return false;
}
}
TORCH_CHECK(false, "Unsupported device type", a, " and ", b);
return false;
}
Placement::Placement(std::optional<c10::Device> defaultDevice)
: Placement({}, defaultDevice) {}
Placement::Placement(
const std::unordered_map<c10::Device, c10::Device>& deviceMap,
std::optional<c10::Device> defaultDevice) {
for (const auto& [srcDevice, dstDevice] : deviceMap) {
deviceMap_.emplace(normalizeDevice(srcDevice), normalizeDevice(dstDevice));
}
if (defaultDevice.has_value()) {
defaultDevice_ = normalizeDevice(defaultDevice.value());
}
}
c10::Device Placement::getMappedDevice(const c10::Device& srcDevice) const {
auto it = deviceMap_.find(normalizeDevice(srcDevice));
if (it != deviceMap_.end()) {
return it->second;
}
if (defaultDevice_.has_value()) {
return defaultDevice_.value();
}
return srcDevice;
}
} // namespace torch::nativert

View File

@ -0,0 +1,29 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Logging.h>
#include <optional>
#include <unordered_map>
namespace torch::nativert {
c10::Device normalizeDevice(const c10::Device& device);
bool isSameDevice(const c10::Device& device1, const c10::Device& device2);
struct TORCH_API Placement {
Placement() = default;
explicit Placement(std::optional<c10::Device> defaultDevice);
explicit Placement(
const std::unordered_map<c10::Device, c10::Device>& deviceMap,
std::optional<c10::Device> defaultDevice = std::nullopt);
c10::Device getMappedDevice(const c10::Device& srcDevice) const;
friend std::ostream& operator<<(std::ostream& os, const Placement& obj);
protected:
std::unordered_map<c10::Device, c10::Device> deviceMap_;
std::optional<c10::Device> defaultDevice_;
};
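// Illustrative usage (devices chosen arbitrarily for the example):
//
//   Placement placement(
//       {{c10::Device("cuda:0"), c10::Device("cuda:1")}},
//       /*defaultDevice=*/c10::Device("cpu"));
//   placement.getMappedDevice(c10::Device("cuda:0"));  // -> cuda:1
//   placement.getMappedDevice(c10::Device("cuda:3"));  // -> cpu (the default)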
} // namespace torch::nativert

View File

@ -0,0 +1,35 @@
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/ExecutionPlanner.h"
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
namespace torch::nativert {
std::vector<c10::IValue> SerialGraphExecutor::execute(
ExecutionFrame& executionFrame,
std::vector<c10::IValue> inputs) {
fillUserInputs(executionFrame, std::move(inputs));
return executeWithPrefilledFrame(executionFrame);
}
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
ExecutionFrame& executionFrame) {
// Execute kernels for all nodes except prim.Input and prim.Output
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
nodeKernels_[nodeIdx]->compute(executionFrame);
// don't free intermediate values when static memory planning is enabled
if (!executorConfig_.enableStaticMemoryPlanning) {
// Free the intermediate values that are no longer used
for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
executionFrame.releaseValue(valueKey);
}
}
}
return executionFrame.getUserOutputs();
}
} // namespace torch::nativert

View File

@ -0,0 +1,23 @@
#pragma once
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
namespace torch::nativert {
class SerialGraphExecutor : public GraphExecutorBase {
public:
SerialGraphExecutor(
const Graph& graph,
std::vector<std::unique_ptr<OpKernel>> nodeKernels,
const ExecutorConfig& executorConfig)
: GraphExecutorBase(graph, std::move(nodeKernels), executorConfig) {}
std::vector<c10::IValue> execute(
ExecutionFrame& frame,
std::vector<c10::IValue> inputs) override;
std::vector<c10::IValue> executeWithPrefilledFrame(
ExecutionFrame& frame) override;
};
} // namespace torch::nativert

View File

@ -0,0 +1,77 @@
#pragma once
#include <atomic>
#include <c10/macros/Macros.h>
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
template <typename T, typename __atomic_base = std::atomic<T>>
struct __copyable_atomic : public __atomic_base {
public:
__copyable_atomic() = default;
__copyable_atomic(const T& t) noexcept(__atomic_base::is_always_lock_free)
: __atomic_base(t) {}
__copyable_atomic(const __copyable_atomic& other) noexcept(
__atomic_base::is_always_lock_free)
: __atomic_base(other.load()) {}
__copyable_atomic& operator=(const __copyable_atomic& other) noexcept(
__atomic_base::is_always_lock_free) {
this->store(other.load());
return *this;
}
};
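// SessionState tracks one in-flight execution of the parallel executor:
//  - addWork()/removeWork() maintain a count of outstanding WorkUnits, and
//    wait() blocks the caller until that count drops back to zero;
//  - producers_ holds, per node, the number of not-yet-finished producers, and
//    decrementProducers() returns true once a node becomes ready to run.
// The __copyable_atomic wrapper above exists so these counters can live as
// copyable map values while still being updated atomically.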
class SessionState {
public:
explicit SessionState(
ExecutionFrame& frame,
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>>
producers = {})
: producers_(std::move(producers)), frame_(frame) {}
C10_ALWAYS_INLINE void wait() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&]() {
return workOutstanding_.load(std::memory_order_seq_cst) == 0;
});
}
C10_ALWAYS_INLINE void addWork(uint32_t ct = 1) {
workOutstanding_.fetch_add(ct, std::memory_order_seq_cst);
}
C10_ALWAYS_INLINE void removeWork() {
if (workOutstanding_.fetch_sub(1, std::memory_order_seq_cst) == 1) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.notify_one();
}
}
C10_ALWAYS_INLINE ExecutionFrame& frame() {
return frame_;
}
C10_ALWAYS_INLINE /* producersRemaining == 0 */ bool decrementProducers(
const Node* node) {
return producers_.at(node).fetch_sub(1, std::memory_order_seq_cst) == 1;
}
C10_ALWAYS_INLINE void setProducers(const Node* node, uint32_t v = 1) {
producers_[node] += v;
}
private:
std::atomic_uint_fast32_t workOutstanding_{0};
c10::FastMap<const Node*, __copyable_atomic<std::uint_fast32_t>> producers_;
std::condition_variable cv_;
std::mutex mutex_;
ExecutionFrame& frame_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,394 @@
#include "torch/csrc/nativert/executor/Weights.h"
#include <utility>
#include "c10/util/Logging.h"
#include <torch/csrc/jit/serialization/import.h> // @manual=//caffe2:torch-cpp-cpu
#include <torch/csrc/jit/serialization/import_read.h> // @manual=//caffe2:torch-cpp-cpu
#include "caffe2/serialize/inline_container.h" // @manual=//caffe2:torch-cpp-cpu
#include "torch/csrc/nativert/common/String.h"
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
namespace torch::nativert {
WeightVersion Weights::globalVersion_ = 0;
Weights::Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict,
const Placement& placement)
: graph_(graph),
weightsMeta_(graph->weightsMeta()),
placement_(placement),
version_(globalVersion_++) {
if (stateDict.has_value()) {
loadStateDict(stateDict.value());
}
}
Weights::Weights(
const Graph* graph,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const std::unordered_map<std::string, std::string>& stateDictPaths,
std::string_view stateDictPathPrefix,
const std::unordered_map<std::string, std::string>& constantPaths,
std::string_view constantPathPrefix,
const Placement& placement,
std::function<bool(const std::string&)> skipSizeCheck,
std::function<bool(const std::string&)> skipDtypeCheck)
: graph_(graph),
weightsMeta_(graph->weightsMeta()),
placement_(placement),
version_(globalVersion_++),
skipSizeCheck_(std::move(skipSizeCheck)),
skipDtypeCheck_(std::move(skipDtypeCheck)) {
auto loadAndInsert =
[&](const std::string& tensorName,
std::string_view pathPrefix,
const std::unordered_map<std::string, std::string>& tensorPaths,
bool isUsed) {
auto pathIt = tensorPaths.find(tensorName);
CHECK(pathIt != tensorPaths.end())
<< "Couldn't find " << tensorName << " in tensorPaths";
const std::string tensorPath = std::string{pathPrefix} + pathIt->second;
VLOG(1) << "Loading weight from: " << tensorPath;
CHECK(pytorchStreamReader->hasRecord(tensorPath))
<< tensorPath << " not found";
auto [tensorData, tensorDataSize] =
pytorchStreamReader->getRecord(tensorPath);
// TODO: We now have two copies of metadata for weights, one in
// model definition /models/<model_name>.json, another in
// /extra/xl_weights/<model_name>_model_param_config.json
// Currently, we only use the metadata from model definition.
std::optional<TensorMeta> tensorMeta;
if (weightsMeta_.find(tensorName) != weightsMeta_.end()) {
tensorMeta = weightsMeta_.at(tensorName);
} else {
TORCH_CHECK(false, "Tensor meta not found for: ", tensorName);
}
if (tensorDataSize == 0 && tensorMeta->numel() > 0) {
VLOG(1) << "Tensor " << tensorName
<< " does not have data and create on Meta device";
allValues_[tensorName] = torch::empty_strided(
tensorMeta->sizes(),
tensorMeta->strides(),
tensorMeta->asTensorOptions().device(torch::kMeta));
return;
}
if (!isUsed) {
VLOG(1) << "Tensor " << tensorName << " is not used during inference";
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
allValues_[tensorName] =
at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
return;
}
size_t bytesPerEntry =
c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize();
auto device = tensorData.device();
auto storage = c10::Storage(
c10::Storage::use_byte_size_t(),
at::detail::computeStorageNbytes(
tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry),
std::move(tensorData), // ownership is transferred
nullptr,
false);
const auto tensorOptions = at::TensorOptions(device)
.dtype(tensorMeta->dtype())
.requires_grad(false);
auto tensor =
at::empty({0}, tensorOptions)
.set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
if (!isSameDevice(targetDevice, tensor.device())) {
tensor = tensor.to(targetDevice);
}
allValues_[tensorName] = tensor;
};
auto loadAndInsertParamsBuffers = [&](const std::string& tensorName,
bool isUsed) {
return loadAndInsert(
tensorName, stateDictPathPrefix, stateDictPaths, isUsed);
};
size_t weightIndex = 0;
bool isUsed;
const auto& weightValues = graph->weightValues();
for (const auto& tensorName : graph->signature().parameters()) {
isUsed = (weightValues[weightIndex]->users().size() > 0);
if (!isUsed) {
unusedWeights_.insert(tensorName);
}
loadAndInsertParamsBuffers(tensorName, isUsed);
weightIndex++;
}
for (const auto& tensorName : graph->signature().buffers()) {
isUsed = (weightValues[weightIndex]->users().size() > 0);
if (!isUsed) {
unusedWeights_.insert(tensorName);
}
loadAndInsertParamsBuffers(tensorName, isUsed);
weightIndex++;
}
// Load tensor constants and custom object constants; both are stored in the
// same directory in the archive, i.e. "extra/constants/". Tensor constants
// are prefixed with "tensor_", custom objects with "custom_obj_".
auto loadConstants = [&](const std::vector<std::string>& constants) {
for (const auto& constantName : constants) {
auto pathIt = constantPaths.find(constantName);
CHECK(pathIt != constantPaths.end())
<< "Couldn't find " << constantName << " in constantPaths";
auto& fileName = pathIt->second;
if (starts_with(
fileName, archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX)) {
// tensor constants
isUsed = (weightValues[weightIndex]->users().size() > 0);
if (!isUsed) {
unusedWeights_.insert(constantName);
}
loadAndInsert(constantName, constantPathPrefix, constantPaths, isUsed);
weightIndex++;
} else {
TORCH_CHECK(false, "Unknown constant path: ", fileName);
}
}
};
loadConstants(graph->signature().nonPersistentBuffers());
loadConstants(graph->signature().tensorConstants());
// custom object constants
for (const auto& customObjName : graph->signature().customObjs()) {
auto pathIt = constantPaths.find(customObjName);
CHECK(pathIt != constantPaths.end())
<< "Couldn't find " << customObjName << " in constantPaths";
auto& fileName = pathIt->second;
if (!starts_with(fileName, archive_spec::CUSTOM_OBJ_FILENAME_PREFIX)) {
TORCH_CHECK(false, "Unknown constant path: ", fileName);
}
std::string customObjPath = std::string{constantPathPrefix} + fileName;
LOG(INFO) << "Loading custom object from: " << customObjPath;
CHECK(pytorchStreamReader->hasRecord(customObjPath))
<< customObjPath << " not found";
const auto& [customObjData, customObjDataSize] =
pytorchStreamReader->getRecord(customObjPath);
const char* customObjDataPtr =
reinterpret_cast<const char*>(customObjData.get());
std::string customObjBytes(
customObjDataPtr, customObjDataPtr + customObjDataSize);
c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes);
CHECK(customObj.isCustomClass());
CHECK(!customObj.isNone());
customObjs_[customObjName] = std::move(customObj);
customObjsPaths_[customObjPath] = customObjName;
}
}
std::unordered_map<std::string, at::Tensor> Weights::parameters() const {
std::unordered_map<std::string, at::Tensor> result;
for (const auto& name : graph_->signature().parameters()) {
result.emplace(name, allValues_.at(name));
}
return result;
}
std::unordered_map<std::string, at::Tensor> Weights::buffers() const {
std::unordered_map<std::string, at::Tensor> result;
for (const auto& name : graph_->signature().buffers()) {
result.emplace(name, allValues_.at(name));
}
return result;
}
std::unordered_map<std::string, at::Tensor> Weights::attributes() const {
return allValues_;
}
at::Tensor Weights::at(const std::string& name) const {
auto it = allValues_.find(name);
if (it != allValues_.end()) {
return it->second;
}
TORCH_CHECK(false, name, " not found in Weights ", toString());
}
at::Tensor& Weights::at(const std::string& name) {
auto it = allValues_.find(name);
if (it != allValues_.end()) {
return it->second;
}
TORCH_CHECK(false, name, " not found in Weights ", toString());
}
bool Weights::contains(const std::string& name) const {
return allValues_.find(name) != allValues_.end();
}
c10::IValue Weights::getCustomObj(const std::string& name) const {
auto it = customObjs_.find(name);
if (it != customObjs_.end()) {
return it->second;
}
TORCH_CHECK(false, "Custom objects ", name, " not found in Weights");
}
c10::IValue Weights::getCustomObjByFileName(const std::string& name) const {
auto it = customObjsPaths_.find(name);
TORCH_CHECK(
it != customObjsPaths_.end(),
"Custom objects with file name ",
name,
" not found in Weights");
const std::string obj_name = it->second;
return getCustomObj(obj_name);
}
void Weights::loadStateDict(
const std::unordered_map<std::string, c10::IValue>& stateDict) {
auto validateAndInsert = [&](const std::string& name) {
auto stateDictIt = stateDict.find(name);
CHECK(stateDictIt != stateDict.end())
<< "Couldn't find " << name << " in stateDict";
// Verify that the tensor matches the tensorMeta
auto it = weightsMeta_.find(name);
CHECK(it != weightsMeta_.end())
<< "Couldn't find " << name << " in weightsMeta";
auto targetDevice = placement_.getMappedDevice(it->second.device());
auto tensor = stateDictIt->second.toTensor().to(targetDevice);
CHECK(tensor.sizes() == it->second.sizes());
CHECK(tensor.dtype() == it->second.dtype());
allValues_.emplace(name, tensor);
};
for (const auto& name : graph_->signature().parameters()) {
validateAndInsert(name);
}
for (const auto& name : graph_->signature().buffers()) {
validateAndInsert(name);
}
// TensorConstants_ not filled !!
}
void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
const {
auto& weightMeta = weightsMeta_.at(name);
CHECK(
weightMeta.sizes() == newValue.sizes() ||
(skipSizeCheck_ && skipSizeCheck_(name)) ||
unusedWeights_.find(name) != unusedWeights_.end())
<< "Mismatched sizes for " << name << ": " << weightMeta.sizes() << " vs "
<< newValue.sizes();
CHECK(
weightMeta.dtype() == newValue.dtype() ||
(skipDtypeCheck_ && skipDtypeCheck_(name)) ||
unusedWeights_.find(name) != unusedWeights_.end())
<< "Mismatched dtype for " << name << ": " << weightMeta.dtype() << " vs "
<< newValue.dtype();
auto targetDevice = placement_.getMappedDevice(weightMeta.device());
if (targetDevice.is_cpu() && targetDevice.has_index()) {
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
}
CHECK(isSameDevice(targetDevice, newValue.device()))
<< "Mismatched device for " << name << ": " << targetDevice << " vs "
<< newValue.device();
}
void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
if (allValues_.find(name) != allValues_.end()) {
validateValue(name, newValue);
} else {
LOG(WARNING) << name << " is not found in the registered weights";
}
allValues_[name] = newValue;
}
void Weights::updateValue(const std::string& name, const at::Tensor& newValue) {
auto it = allValues_.find(name);
CHECK(it != allValues_.end())
<< name << " not found in Weights " << toString();
validateValue(name, newValue);
it->second.copy_(newValue);
}
void Weights::updateValues(
const std::unordered_map<std::string, at::Tensor>& newValues) {
for (auto& [name, newValue] : newValues) {
updateValue(name, newValue);
}
}
std::string Weights::toString() const {
std::stringstream ss;
ss << "[";
for (const auto& [name, _] : allValues_) {
ss << name << ", ";
}
ss << "]";
ss << "[";
for (const auto& [name, _] : customObjs_) {
ss << name << ", ";
}
ss << "]";
return ss.str();
}
void Weights::validateAllWeightsLoaded() {
auto checkNames = [&](const std::vector<std::string>& names) {
for (const auto& name : names) {
if (unusedWeights_.find(name) != unusedWeights_.end()) {
continue;
}
auto it = allValues_.find(name);
CHECK(it != allValues_.end()) << "Missing weight: " << name;
CHECK(it->second.defined()) << "Weight not defined: " << name;
if (it->second.device().is_meta()) {
LOG(WARNING) << "Weight is on meta device: " << name;
}
}
};
checkNames(graph_->signature().parameters());
checkNames(graph_->signature().buffers());
checkNames(graph_->signature().nonPersistentBuffers());
checkNames(graph_->signature().tensorConstants());
}
void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) {
foldedConstsMap_[std::string{name}] = std::move(tensor);
}
const std::unordered_map<std::string, c10::IValue>& Weights::getFoldedConsts()
const {
return foldedConstsMap_;
}
} // namespace torch::nativert

View File

@ -0,0 +1,136 @@
#pragma once
#include <unordered_map>
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/executor/Placement.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
class Graph;
class Value;
using WeightVersion = int;
class Weights {
public:
explicit Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict = std::nullopt,
const Placement& placement = Placement());
// Arguments
// - pytorchStreamReader: the reader for the model archive
// - stateDictPaths: a map from parameter/buffer/constant name to file path in
// the archive
// - stateDictPathPrefix: a prefix that will be prepended to paths in
// stateDictPaths
// - constantPaths: a map from constant name to file path in the archive
// - constantPathPrefix: a prefix that will be prepended to paths in
// constantPaths
// - placement: the device placement of the weights; defaults to following the
// original device in the weight's metadata
explicit Weights(
const Graph* graph,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader,
const std::unordered_map<std::string, std::string>& stateDictPaths,
std::string_view stateDictPathPrefix,
const std::unordered_map<std::string, std::string>& constantPaths,
std::string_view constantPathPrefix,
const Placement& placement = Placement(),
std::function<bool(const std::string&)> skipSizeCheck = {},
std::function<bool(const std::string&)> skipDtypeCheck = {});
at::Tensor at(const std::string& name) const;
at::Tensor& at(const std::string& name);
bool contains(const std::string& name) const;
c10::IValue getCustomObj(const std::string& name) const;
c10::IValue getCustomObjByFileName(const std::string& name) const;
std::unordered_map<std::string, at::Tensor> parameters() const;
std::unordered_map<std::string, at::Tensor> buffers() const;
std::unordered_map<std::string, at::Tensor> attributes() const;
void loadStateDict(
const std::unordered_map<std::string, c10::IValue>& stateDict);
/*
* Replace the value stored at the weight with name "name".
*/
void setValue(const std::string& name, const at::Tensor& newValue);
/*
* Update the value stored at the weight with name "name".
* This is done in-place.
*/
void updateValue(const std::string& name, const at::Tensor& newValue);
void updateValues(
const std::unordered_map<std::string, at::Tensor>& newValues);
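// Illustrative difference between the two update paths (assumed usage):
//   weights.setValue("fc.weight", t);     // rebinds the stored tensor
//   weights.updateValue("fc.weight", t);  // copies into the existing storage
//                                         // in-place via copy_()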
void validateValue(const std::string& name, const at::Tensor& newValue) const;
void validateAllWeightsLoaded();
const std::unordered_map<std::string, c10::IValue>& getFoldedConsts() const;
C10_ALWAYS_INLINE const c10::FastMap<ValueId, c10::IValue>&
getConstFoldedValues() const {
return constFoldedValues_;
}
C10_ALWAYS_INLINE void setConstFoldedValue(const ValueId v, c10::IValue iv) {
constFoldedValues_.insert_or_assign(v, std::move(iv));
}
std::string toString() const;
WeightVersion version() const {
return version_;
}
private:
friend class Executor;
friend class ExecutionFrame;
void updateFoldedConst(std::string_view name, c10::IValue tensor);
const Graph* graph_;
const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
Placement placement_;
// keys are parameter/buffer/constant names, not graph input names!
std::unordered_map<std::string, at::Tensor> allValues_;
std::unordered_map<std::string, c10::IValue> customObjs_;
// Maps a file name to an arbitrary key in customObjs_ that holds the loaded
// content of the file (a CustomClassHolder).
// This is used in AOTIDelegateExecutor.
std::unordered_map<std::string, std::string> customObjsPaths_;
// The lifecycle of folded consts should be tied to the weights from which
// they were derived. The ordering of the constants should be consistent with
// the output order of the const graph.
std::vector<c10::IValue> foldedConsts_;
std::unordered_map<std::string, c10::IValue> foldedConstsMap_;
c10::FastMap<ValueId, c10::IValue> constFoldedValues_;
// unique version number for this instance of weight
const WeightVersion version_;
// every instance of Weight has a unique version number
static WeightVersion globalVersion_;
std::function<bool(const std::string&)> skipSizeCheck_ = {};
std::function<bool(const std::string&)> skipDtypeCheck_ = {};
// save the names of unused weights
std::unordered_set<std::string> unusedWeights_;
};
} // namespace torch::nativert

File diff suppressed because it is too large

View File

@ -0,0 +1,695 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "torch/csrc/nativert/common/IntrusiveList.h"
#include "torch/csrc/nativert/executor/Placement.h"
#include "torch/csrc/nativert/graph/GraphSignature.h"
#include "torch/csrc/nativert/graph/TensorMeta.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
using NodeIndex = size_t;
class Value;
class Type {
public:
enum Kind {
None,
Tensor,
TensorList,
OptionalTensorList,
SymInt,
SymIntList,
SymBool,
SymFloat,
CustomObj,
};
friend std::ostream& operator<<(std::ostream& out, const Type& ty);
/* implicit */ Type(Kind kind) : kind_(kind) {}
explicit Type(Kind kind, const std::string& classFqn)
: kind_(kind), classFqn_(classFqn) {
CHECK(kind == Kind::CustomObj);
CHECK(!classFqn.empty());
}
Kind kind() const {
return kind_;
}
std::string classFqn() const {
CHECK(kind_ == Kind::CustomObj)
<< "Only CustomObj type can have classFqn, got " << kind_;
return classFqn_;
}
private:
const Kind kind_;
const std::string classFqn_;
};
bool operator==(const Type& left, const Type& right);
// These are all the constant types that are allowed as attributes on Nodes.
struct None {};
// None always equals itself
inline bool operator==(const None&, const None&) {
return true;
}
using Constant = std::variant<
None,
int64_t,
std::vector<int64_t>,
double,
std::vector<double>,
std::string,
c10::ScalarType,
c10::MemoryFormat,
c10::Layout,
c10::Device,
bool,
std::vector<bool>,
std::vector<std::string>,
std::unique_ptr<Graph>>;
c10::IValue constantToIValue(const Constant& constant);
class Node;
/**
* Represents a single symbolic value (tensor/symint/list of them). Values are
* inputs and outputs of Nodes.
*/
using ValueId = int;
class Value {
public:
explicit Value(ValueId id, std::string name, Type t, Node* producer)
: name_(std::move(name)), id_(id), type_(t), producer_(producer) {
TORCH_CHECK_EQ(name_, this->name());
}
// Each Value should be uniquely created and managed by a Graph. It's an
// anti-pattern to copy/move Value instances anyway.
Value(Value&&) = delete;
Value& operator=(Value&&) = delete;
Value(const Value&) = delete;
Value& operator=(Value&) = delete;
Type type() const {
return type_;
}
ValueId id() const {
return id_;
}
std::string_view name() const {
return name_;
}
const Node* producer(bool resolve_folded = false) const {
return !resolve_folded && isFolded() ? nullptr : producer_;
}
Node* producer() {
return producer_;
}
void addUser(Node* node);
void eraseUser(Node* node);
void eraseAllUsers() {
users_.clear();
}
// Fatals if the value is not a TensorList
std::vector<const Value*> getListElements() const;
const auto& users() const {
return users_;
}
auto& users() {
return users_;
}
void setId(ValueId newId) {
// This should only be used inside the renumberValues pass
id_ = newId;
}
void setIsFolded() {
isFolded_ = true;
}
bool isFolded() const {
return isFolded_;
}
private:
friend std::ostream& operator<<(std::ostream& out, const Value& v);
std::string name_;
bool isFolded_{false};
ValueId id_;
Type type_;
Node* producer_;
// All nodes which have this value as input.
// Note that this is a vector to avoid nondeterminism in iteration, but
// probably should be an unordered set given usage patterns. If this becomes a
// perf problem we should revise.
std::vector<Node*> users_;
};
struct NamedArgument {
std::string name;
Value* value;
};
struct Attribute {
std::string name;
Constant value;
};
class Graph;
/**
* Node represents a single unit of execution, typically a PyTorch operator.
*/
class Node : public IntrusiveListHook {
public:
Node(
Graph* owningGraph,
std::string target,
std::vector<NamedArgument> inputs,
std::unordered_map<std::string, std::string> metadata);
std::string_view target() const {
return target_;
}
void setTarget(std::string_view target) {
target_ = target;
}
const auto& inputs() const {
return inputs_;
}
auto& inputs() {
return inputs_;
}
// NOTE: this invalidates spans given out by inputs()
Value* addInput(NamedArgument input);
void addInputs(const std::vector<NamedArgument>& inputs);
// NOTE: this invalidates spans given out by attributes()
void addAttribute(Attribute attr);
// NOTE: this is ONLY for graph's constant inputs and NOT the common case
void addOutput(void);
Value* addOutput(Type type);
// NOTE: this invalidates spans given out by outputs()
Value* addOutput(std::string_view name, Type type);
size_t numInputs() const {
return inputs_.size();
}
size_t numOutputs() const {
return outputs_.size();
}
// Return the next node in the Graph's node ordering.
// NOTE: Calling next on the last node (prim.Output) returns nullptr.
Node* next();
const Node* next() const;
// Return the previous node in the Graph's node ordering.
// NOTE: Calling prev on the first node (prim.Input) returns nullptr.
Node* prev();
const Node* prev() const;
bool isBefore(const Node* n) const;
std::vector<Node*> producers() const;
std::vector<Node*> users() const;
// Returns nullptr if `name` is not an input
const NamedArgument* tryGetInput(std::string_view name) const;
// Fatals if `name` is not an input
const NamedArgument& getInput(std::string_view name) const;
const auto& attributes() const {
return attributes_;
}
// Returns nullptr if `name` is not an attribute
const Attribute* tryGetAttribute(std::string_view name) const;
// Fatals if `name` is not an attribute
const Attribute& getAttribute(std::string_view name) const;
const auto& outputs() const {
return outputs_;
}
void applyDevicePlacement(const Placement& placement);
std::optional<std::string_view> getMetadata(std::string_view key) const {
return metadata_.find(std::string{key}) != metadata_.end()
? std::optional(std::string_view{metadata_.at(std::string{key})})
: std::nullopt;
}
Graph* owningGraph() {
return owningGraph_;
}
const Graph* owningGraph() const {
return owningGraph_;
}
void destroy();
const std::unordered_map<std::string, std::string>& metadata() const {
return metadata_;
}
std::string toString() const {
std::stringstream ss;
ss << *this;
return ss.str();
}
void updateInputName(std::string_view oldName, std::string_view newName) {
for (auto& input : inputs_) {
if (input.name == oldName) {
input.name = newName;
}
}
}
void updateAttributeName(std::string_view oldName, std::string_view newName) {
for (auto& attr : attributes_) {
if (attr.name == oldName) {
attr.name = newName;
}
}
}
private:
friend std::ostream& operator<<(std::ostream& out, const Node& n);
Graph* owningGraph_;
// Target used to retrieve the actual thing to execute.
// If an aten operator, we expect this to be fully qualified, including an
// overload name, e.g. "aten.unsqueeze.default"
std::string target_;
// *Symbolic* inputs to this node. NOTE: this does not match the ATen operator
// schema inputs directly. It only represents things that actually participate
// in dataflow, like tensors/symints and lists thereof.
//
// The "name" of the NamedArgument refers to the name of the parameter.
std::vector<NamedArgument> inputs_;
// Constant inputs to the node. The "name" of the Attribute refers to the
// name of the parameter.
std::vector<Attribute> attributes_;
std::vector<Value*> outputs_;
// Extra bits of info added to the node. Contents that are guaranteed will
// eventually be moved to a first-class field on the thrift struct.
// Current contents:
// "stack_trace" => original Python source traceback
std::unordered_map<std::string, std::string> metadata_;
};
/**
* A representation of a model's computation graph. This representation is
* designed to facilitate transformation/analysis.
*
* Ownership semantics:
* - Graph owns Nodes and Values
* - Nodes own their constant attributes (which we treat as value types)
* - Nodes have non-owning pointers back to the graph.
*
* NOTE: this class is marked noncopyable/nonmovable and only can be
* heap-allocated via `createGraph()`. This is to ensure stability of
* back-pointers held by Nodes/Values.
*/
class Graph {
public:
static std::unique_ptr<Graph> createGraph() {
return std::unique_ptr<Graph>(new Graph());
}
Graph(const Graph&) = delete;
Graph& operator=(const Graph&) = delete;
Graph(Graph&&) = delete;
Graph& operator=(Graph&&) = delete;
~Graph() = default;
// NOTE: this invalidates spans given out by inputs()
Value* addInput(std::string_view name, Type type);
// NOTE: this is ONLY for graph's constant inputs and NOT the common case
void addInput(void);
// NOTE: this invalidates spans given out by outputs()
Value* addOutput(Value* v);
void addConstantOutput(Constant& c);
// Create and insert a node at insertionPoint_
Node* insertNode(
std::string target,
std::vector<NamedArgument> inputs = {},
std::unordered_map<std::string, std::string> metadata = {});
// Returns the inserted node.
Node* insertBefore(Node* toInsert, Node* insertionPoint);
// Returns the inserted node.
Node* insertAfter(Node* toInsert, Node* insertionPoint);
// Insert at the insertionPoint. Returns the inserted node.
Node* insert(Node* toInsert);
// Create a node without inserting it into the execution graph.
Node* createNode(
std::string target,
std::vector<NamedArgument> inputs = {},
std::unordered_map<std::string, std::string> metadata = {});
Value* createConstantSymIntValue(int value);
Node* createListPack(std::vector<Value*> inputs, Type inputType);
Node* createOptionalListPack(std::vector<Value*> inputs);
size_t numValues() const {
return values_.size();
}
// throws on missing name
Value* getValue(std::string_view name) const;
// returns nullptr on missing name
Value* tryGetValue(std::string_view name) const;
const std::unordered_map<ValueId, int> getConstantSymIntValues() const {
return constantSymIntValues_;
}
Value*
addValue(const std::optional<std::string>& name, Type type, Node* producer);
void removeValue(Value* value);
void replaceAllUses(Value* old, Value* replacement);
void replaceAllUsesAfterNode(Value* old, Value* replacement, Node* afterThis);
void removeNode(Node* node);
void applyDevicePlacement(const Placement& placement);
std::string getUniqueValueName();
ValueId getNextValueId() {
return uniqueValueId_++;
}
// NOTE: this range can be invalidated by mutations to the graph.
const auto& inputs() const {
return inputNode_->outputs();
}
c10::ArrayRef<const Value*> userInputs() const {
size_t offset = signature().inputsToWeights().size() +
signature().inputsToCustomObjs().size();
return {inputs().data() + offset, inputs().data() + inputs().size()};
}
c10::ArrayRef<const Value*> weightValues() const {
return {
inputs().data(),
inputs().data() + signature().inputsToWeights().size()};
}
// Return a bidirectional range over `const Value*`
// NOTE: this range can be invalidated by mutations to the graph.
auto outputs() const {
std::vector<const Value*> ret;
ret.reserve(outputNode_->inputs().size());
for (const auto& namedArg : outputNode_->inputs()) {
ret.push_back(namedArg.value);
}
return ret;
}
// Return a bidirectional range over `Value*`
// NOTE: this range can be invalidated by mutations to the graph.
auto outputs() {
std::vector<Value*> ret;
ret.reserve(outputNode_->inputs().size());
for (const auto& namedArg : outputNode_->inputs()) {
ret.push_back(namedArg.value);
}
return ret;
}
const auto& userOutputs() const {
return userOutputs_;
}
// Return a list over `const Node&`.
// NOTE: this can be invalidated by mutations to the graph.
const auto& nodes() const {
return nodes_;
}
auto& nodes() {
return nodes_;
}
// Return a forward range over `const Value*`.
// NOTE: this range can be invalidated by mutations to the graph.
auto values() const {
std::vector<const Value*> ret;
ret.reserve(values_.size());
for (const auto& [_, value] : values_) {
ret.push_back(value.get());
}
return ret;
}
Node* inputNode() {
return inputNode_;
}
Node* outputNode() {
return outputNode_;
}
const Node* outputNode() const {
return outputNode_;
}
// Assert various graph invariants
void lint() const;
void cleanupDeadNodes();
void finalize();
Node* insertionPoint() {
// This should never happen, since the last-most insertion point is the
// prim.Outputs node, not end().
CHECK(insertBefore_ != nodes_.end());
auto& node = *insertBefore_;
return &node;
}
void setInsertionPoint(Node* n) {
CHECK(n != inputNode_) << "can't insert before prim.Input";
insertBefore_ = nodes_.iterator_to(*n);
}
void setInsertionPointAfter(Node* n) {
CHECK(n != outputNode_) << "can't insert after prim.Output";
auto it = nodes_.iterator_to(*n);
++it;
insertBefore_ = it;
}
// Return the next node in the Graph's node ordering.
// NOTE: Calling on the last node (prim.Output) returns nullptr.
Node* nodeAfter(Node* n);
const Node* nodeAfter(const Node* n) const;
// Return the previous node in the Graph's node ordering.
// NOTE: Calling on the first node (prim.Input) returns nullptr.
Node* nodeBefore(Node* n);
const Node* nodeBefore(const Node* n) const;
// Clone each node from subgraph (except prim.Input/prim.Output) into current
// graph.
// @param subgraph: the subgraph to be cloned
// @param inputs: values from the target graph that will serve as the
// subgraph's inputs
// @param valueMap: a map from the cloned subgraph's values to the target
// graph's values
std::vector<Value*> insertGraph(
const Graph& subgraph,
std::vector<Value*> inputs,
std::unordered_map<const Value*, Value*>& valueMap);
const GraphSignature& signature() const {
return signature_;
}
void setSignature(GraphSignature signature) {
signature_ = std::move(signature);
}
void setWeightsMeta(
const std::unordered_map<std::string, torch::_export::TensorMeta>&
tensorsMeta) {
for (auto [name, tensorMeta] : tensorsMeta) {
weightsMeta_.emplace(name, TensorMeta{tensorMeta});
}
}
const std::unordered_map<std::string, TensorMeta>& weightsMeta() const {
return weightsMeta_;
}
std::vector<TensorMeta> userInputsMeta() const {
std::vector<TensorMeta> userInputsMeta;
for (auto inputName : signature_.userInputs()) {
userInputsMeta.push_back(tensorValuesMeta_.at(inputName));
}
return userInputsMeta;
}
void setTensorValuesMeta(
const std::unordered_map<std::string, torch::_export::TensorMeta>&
tensorsMeta) {
for (auto [name, tensorMeta] : tensorsMeta) {
tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta});
}
}
const std::unordered_map<std::string, TensorMeta>& tensorValuesMeta() const {
return tensorValuesMeta_;
}
std::string toString() const {
std::stringstream ss;
ss << *this;
return ss.str();
}
/* Reassigns IDs to every Value in this Graph so that they are contiguous from
* 0..(numValues()-1). Should be used after values are removed
*/
void renumberValues();
private:
Graph();
friend std::ostream& operator<<(std::ostream& out, const Graph& g);
GraphSignature signature_;
// keys are parameters, buffers, tensor_constants' names
std::unordered_map<std::string, TensorMeta> weightsMeta_;
// keys are tensor_values' names
std::unordered_map<std::string, TensorMeta> tensorValuesMeta_;
// Node lifetime is managed by nodesOwner_, but the actual ordering is
// maintained intrusively using nodes_.
// This is to facilitate quick insertion before/after a given Node*.
std::vector<std::unique_ptr<Node>> nodesOwner_;
IntrusiveList<Node> nodes_;
// The current insertion point. New nodes are inserted before this node.
// Defaults to prim.Output.
IntrusiveList<Node>::iterator insertBefore_;
// Graphs always start with an input and output node.
// "prim.input() -> Value[]" take no input, and produces some outputs. AKA
// "source“ of a graph.
Node* inputNode_; // target: prim.Input
// "prim.output(Value[]) -> None", take some inputs, but produce no output.
// AKA "sink" of a graph.
Node* outputNode_; // target: prim.Output
std::unordered_map<std::string, std::unique_ptr<Value>> values_;
// constantSymIntValues_ is a subset of values_
std::unordered_map<ValueId, int> constantSymIntValues_;
// Output values of the graph, which is a subset of values_.
std::vector<std::variant<Value*, Constant>> userOutputs_;
// Output constant values of the graph
std::vector<Constant> constantOutputs_;
size_t uniqueValueName_ = 0;
ValueId uniqueValueId_ = 0;
};
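// Illustrative graph construction (the target string and value names are
// assumptions chosen only to show the API shape):
//
//   auto g = Graph::createGraph();
//   Value* x = g->addInput("x", Type::Kind::Tensor);
//   Node* relu = g->insertNode("torch.ops.aten.relu.default", {{"self", x}});
//   Value* y = relu->addOutput("y", Type::Kind::Tensor);
//   g->addOutput(y);
//   g->finalize();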
/**
* Scoped utility class for setting temporary insertion points.
*
* Use like:
* {
* InsertingAfter guard(node)
* graph.insertNode(...) // this will be inserted after `node`.
* }
*/
class InsertingAfter {
public:
explicit InsertingAfter(Node* n)
: insertAfter_(n), prev_(n->owningGraph()->insertionPoint()) {
insertAfter_->owningGraph()->setInsertionPointAfter(insertAfter_);
}
~InsertingAfter() {
insertAfter_->owningGraph()->setInsertionPoint(prev_);
}
private:
Node* insertAfter_;
Node* prev_;
};
inline constexpr std::string_view kMemoryFormatPrefix = "MemoryFormat::";
inline constexpr std::string_view kLayoutPrefix = "Layout::";
inline constexpr std::string_view kDevicePrefix = "Device";
inline constexpr std::string_view kScalarTypePrefix = "ScalarType::";
/**
* Debug format serialization. The format here is intended to be human readable
* and easy to work with, and is intended for debugging and testing only.
* If you want stable serialization, use the thrift conversion utils.
*
* NOTE: node metadata currently not serialized
*/
std::string graphToString(const Graph& g, bool include_signature = false);
std::unique_ptr<Graph> stringToGraph(std::string_view source);
// Standalone functions to parse common constructs
// Parse something that looks like `Device{cuda:1}` to a thrift device.
c10::Device convertDevice(std::string_view symbol);
// We have separate functions for parsing atomic and list constants because
// there are restrictive rules about which constants can go in lists (i.e.
// it's not recursive).
Constant convertAtomicConstant(std::string_view symbol);
Constant convertListConstant(std::string_view symbol);
} // namespace torch::nativert

View File

@ -0,0 +1,162 @@
#include "torch/csrc/nativert/graph/GraphPasses.h"
#include <fmt/format.h>
#include "torch/csrc/nativert/common/String.h"
namespace torch::nativert {
namespace {
bool isScalar(const Constant& c) {
return std::holds_alternative<int64_t>(c) ||
std::holds_alternative<double>(c);
}
bool isScalar(const Value& v) {
return v.type() == Type::SymInt || v.type() == Type::SymFloat;
}
bool schemaTypeMatch(const c10::FunctionSchema& schema, const Node& node) {
for (const auto& input : node.inputs()) {
// The number of arguments is always O(10), so we can just do a linear scan.
for (const auto& schemaArg : schema.arguments()) {
if (schemaArg.name() == input.name) {
if (schemaArg.type() == c10::TensorType::get() && input.value &&
isScalar(*input.value)) {
return false;
}
break;
}
}
}
for (const auto& constant : node.attributes()) {
for (const auto& schemaArg : schema.arguments()) {
if (schemaArg.name() == constant.name) {
if (schemaArg.type() == c10::TensorType::get() &&
isScalar(constant.value)) {
return false;
}
break;
}
}
}
return true;
}
} // namespace
// PT2 intentionally broadcast things like aten.sub.Scalar
// to aten.sub.Tensor. https://github.com/pytorch/pytorch/issues/90923.
std::string selectScalarOverloadName(const Node& node) {
// Copied from torch/csrc/utils/python_arg_parser.cpp
// torch::should_allow_numbers_as_tensors() to workaround
// some linking issues.
static std::unordered_set<std::string> allowed = {
"add",
"add_",
"add_out",
"div",
"div_",
"div_out",
"divide",
"divide_",
"divide_out", // alias of div
"mul",
"mul_",
"mul_out",
"multiply",
"multiply_",
"multiply_out", // alias of mul
"sub",
"sub_",
"sub_out",
"subtract",
"subtract_",
"subtract_out", // alias of sub
"true_divide",
"true_divide_",
"true_divide_out",
"to",
"_to_copy",
"copy_",
"copy",
"floor_divide",
"floor_divide_",
"floor_divide_out",
"_conj"};
std::vector<std::string_view> atoms = split(node.target(), '.');
TORCH_CHECK_GE(atoms.size(), 3);
std::string ns = std::string{atoms[atoms.size() - 3]};
std::string opName = std::string{atoms[atoms.size() - 2]};
std::string overloadName = std::string{atoms[atoms.size() - 1]};
if (overloadName != "Tensor" && overloadName != "Tensor_Tensor") {
return overloadName;
}
if (allowed.find(std::string{opName}) == allowed.end()) {
return overloadName;
}
auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
fmt::format("{}::{}", ns, opName.c_str()).c_str(), overloadName.c_str());
if (schemaTypeMatch(op.schema(), node)) {
return overloadName;
}
for (const auto& variant : {"Scalar", "Scalar_Tensor", "Tensor_Scalar"}) {
if (auto schema = c10::Dispatcher::singleton().findSchema(
{fmt::format("{}::{}", ns, opName.c_str()).c_str(), variant})) {
if (schemaTypeMatch(schema->schema(), node)) {
return variant;
}
}
}
return overloadName;
}
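// Rewrites node targets in place so that ops which were broadcast to a Tensor
// overload during export are dispatched to a matching Scalar overload when
// their inputs are in fact scalars, e.g. (illustrative)
//   torch.ops.aten.add.Tensor with a scalar "other" -> torch.ops.aten.add.Scalar
// Subgraphs held as node attributes are processed recursively.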
void selectScalarOverload(Graph* graph) {
for (auto& node : graph->nodes()) {
for (auto& attr : node.attributes()) {
if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
selectScalarOverload(
std::get<std::unique_ptr<Graph>>(attr.value).get());
}
}
auto target = node.target();
std::vector<std::string_view> atoms = split(target, '.');
size_t numAtoms = atoms.size();
if (numAtoms != 5) {
continue;
}
const std::string_view ns = atoms[numAtoms - 3];
const std::string_view opName = atoms[numAtoms - 2];
if (atoms[0] != "torch" || atoms[1] != "ops" || ns != "aten") {
continue;
}
auto overloadName = selectScalarOverloadName(node);
if (overloadName != atoms[numAtoms - 1]) {
node.setTarget(
fmt::format("torch.ops.{}.{}.{}", ns, opName, overloadName));
} else if (ns == "aten" && opName == "sub" && overloadName == "Tensor") {
// Special case for aten.sub.Tensor.
if (auto i = node.tryGetInput("self")) {
if (isScalar(*i->value)) {
node.updateInputName("self", "other");
node.updateInputName("other", "self");
node.setTarget("torch.ops.aten.rsub.Scalar");
}
}
if (auto a = node.tryGetAttribute("self")) {
if (isScalar(a->value)) {
node.updateAttributeName("self", "other");
node.updateInputName("other", "self");
node.setTarget("torch.ops.aten.rsub.Scalar");
}
}
}
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,11 @@
#pragma once
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
void selectScalarOverload(Graph* graph);
std::string selectScalarOverloadName(const Node& node);
} // namespace torch::nativert

View File

@ -0,0 +1,450 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <nlohmann/json.hpp>
#include "c10/util/Exception.h"
#include "torch/csrc/nativert/graph/GraphSignature.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
namespace {
bool isSymbolicOutput(torch::_export::Argument::Tag t) {
switch (t) {
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
return true;
default:
return false;
}
}
std::pair<std::string, std::string> getSpecDetails(
const torch::_export::InputSpec& inputSpec) {
// Retrieve the argument name and spec tag name
std::string argName;
std::string tagName;
switch (inputSpec.tag()) {
case torch::_export::InputSpec::Tag::PARAMETER:
argName = inputSpec.get_parameter().get_arg().get_name();
tagName = "PARAMETER";
break;
case torch::_export::InputSpec::Tag::BUFFER:
argName = inputSpec.get_buffer().get_arg().get_name();
tagName = "BUFFER";
break;
case torch::_export::InputSpec::Tag::TENSOR_CONSTANT:
argName = inputSpec.get_tensor_constant().get_arg().get_name();
tagName = "TENSOR_CONSTANT";
break;
case torch::_export::InputSpec::Tag::CUSTOM_OBJ:
tagName = "CUSTOM_OBJ";
argName = inputSpec.get_custom_obj().get_arg().get_name();
break;
case torch::_export::InputSpec::Tag::USER_INPUT:
tagName = "USER_INPUT";
if (inputSpec.get_user_input().get_arg().tag() ==
torch::_export::Argument::Tag::AS_TENSOR) {
argName =
inputSpec.get_user_input().get_arg().get_as_tensor().get_name();
} else if (
inputSpec.get_user_input().get_arg().tag() ==
torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
argName =
inputSpec.get_user_input().get_arg().get_as_custom_obj().get_name();
} else {
throw std::runtime_error("Unsupported USER_INPUT argument type.");
}
break;
case torch::_export::InputSpec::Tag::CONSTANT_INPUT:
argName = inputSpec.get_constant_input().get_name();
tagName = "CONSTANT_INPUT";
break;
case torch::_export::InputSpec::Tag::TOKEN:
throw std::runtime_error("Token inputs not implemented yet");
default:
throw std::runtime_error("Unknown InputSpec tag encountered.");
}
return std::make_pair(argName, tagName);
}
void checkInputOrders(
const std::vector<torch::_export::InputSpec>& inputSpecs) {
// Map each tag to its index in the expected order
std::unordered_map<torch::_export::InputSpec::Tag, size_t> tagOrderMap = {
{torch::_export::InputSpec::Tag::TOKEN, 0},
{torch::_export::InputSpec::Tag::PARAMETER, 1},
{torch::_export::InputSpec::Tag::BUFFER, 2},
{torch::_export::InputSpec::Tag::TENSOR_CONSTANT, 3},
{torch::_export::InputSpec::Tag::CUSTOM_OBJ, 4}};
size_t currentOrderIndex = 0;
bool seenNonPersistentBuffer = false;
for (const auto& inputSpec : inputSpecs) {
if (inputSpec.tag() == torch::_export::InputSpec::Tag::USER_INPUT ||
inputSpec.tag() == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
continue;
}
auto it = tagOrderMap.find(inputSpec.tag());
if (it == tagOrderMap.end()) {
throw std::runtime_error("Unknown InputSpec tag encountered.");
}
size_t tagIndex = it->second;
if (tagIndex < currentOrderIndex) {
auto [argName, tagName] = getSpecDetails(inputSpec);
throw std::runtime_error(fmt::format(
"Input arg {} with InputSpec {} is out of order!", argName, tagName));
}
currentOrderIndex = tagIndex;
// Additional check for buffers
if (inputSpec.tag() == torch::_export::InputSpec::Tag::BUFFER) {
if (!inputSpec.get_buffer().get_persistent()) {
seenNonPersistentBuffer = true;
} else if (
inputSpec.get_buffer().get_persistent() && seenNonPersistentBuffer) {
throw std::runtime_error(
"Persistent buffer found after a non-persistent buffer. "
"Persistent buffers must come before non-persistent buffers.");
}
}
}
}
void checkInputNames(
const std::set<std::string>& sigNames,
const std::set<std::string>& graphNames) {
if (sigNames == graphNames) {
return;
}
std::string errorMsg =
"Error: Value name difference detected between graph signature and graph nodes:\n";
errorMsg += "Signature value names:\n";
errorMsg += fmt::format("[{}]\n", fmt::join(sigNames, ", "));
errorMsg += "Graph node names:\n";
errorMsg += fmt::format("[{}]", fmt::join(graphNames, ", "));
LOG(FATAL) << errorMsg;
};
void checkOutputNames(
const std::set<std::optional<std::string>>& sigNames,
const std::set<std::string>& graphNames) {
std::vector<std::string> validNames;
for (const auto& nameOpt : sigNames) {
if (nameOpt.has_value()) {
validNames.push_back(*nameOpt);
}
}
for (const auto& name : validNames) {
if (graphNames.find(name) == graphNames.end()) {
std::string errorMsg =
"Error: Value name difference detected between graph signature and graph nodes:\n";
errorMsg += "Signature value names:\n";
errorMsg += fmt::format("[{}]\n", fmt::join(validNames, ", "));
errorMsg += "Graph node names:\n";
errorMsg += fmt::format("[{}]", fmt::join(graphNames, ", "));
LOG(FATAL) << errorMsg;
}
}
};
void replaceInMap(
std::unordered_map<std::string, std::string>& map,
std::string_view old,
std::string_view replacement) {
auto it = map.find(std::string{old});
if (it == map.end()) {
return;
}
std::string value = it->second;
map.erase(it);
map.emplace(replacement, value);
}
} // namespace
GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) {
checkInputOrders(storage.get_input_specs());
for (const torch::_export::InputSpec& inputSpec : storage.get_input_specs()) {
switch (inputSpec.tag()) {
case torch::_export::InputSpec::Tag::USER_INPUT: {
if (inputSpec.get_user_input().get_arg().tag() ==
torch::_export::Argument::Tag::AS_TENSOR) {
userInputs_.push_back(
inputSpec.get_user_input().get_arg().get_as_tensor().get_name());
} else if (
inputSpec.get_user_input().get_arg().tag() ==
torch::_export::Argument::Tag::AS_CUSTOM_OBJ) {
userInputs_.push_back(inputSpec.get_user_input()
.get_arg()
.get_as_custom_obj()
.get_name());
} else {
// TODO: handle other types
LOG(FATAL) << "Non tensor inputs nyi";
}
break;
}
case torch::_export::InputSpec::Tag::PARAMETER: {
parameters_.push_back(inputSpec.get_parameter().get_parameter_name());
const auto& inputName = inputSpec.get_parameter().get_arg().get_name();
const auto& weightName = inputSpec.get_parameter().get_parameter_name();
inputsToParameters_.emplace(inputName, weightName);
inputsToWeights_.emplace_back(inputName, weightName);
break;
}
case torch::_export::InputSpec::Tag::BUFFER: {
const bool isPersistent = inputSpec.get_buffer().get_persistent();
const std::string& bufferName =
inputSpec.get_buffer().get_buffer_name();
if (isPersistent) {
buffers_.push_back(bufferName);
} else {
nonPersistentBuffers_.push_back(bufferName);
}
const auto& inputName = inputSpec.get_buffer().get_arg().get_name();
const auto& weightName = inputSpec.get_buffer().get_buffer_name();
inputsToBuffers_.emplace(inputName, weightName);
inputsToWeights_.emplace_back(inputName, weightName);
break;
}
case torch::_export::InputSpec::Tag::TENSOR_CONSTANT: {
tensorConstants_.push_back(
inputSpec.get_tensor_constant().get_tensor_constant_name());
const auto& inputName =
inputSpec.get_tensor_constant().get_arg().get_name();
const auto& weightName =
inputSpec.get_tensor_constant().get_tensor_constant_name();
inputsToTensorConstants_.emplace(inputName, weightName);
inputsToWeights_.emplace_back(inputName, weightName);
break;
}
case torch::_export::InputSpec::Tag::CUSTOM_OBJ: {
customObjs_.push_back(inputSpec.get_custom_obj().get_custom_obj_name());
inputsToCustomObjs_.insert(
{inputSpec.get_custom_obj().get_arg().get_name(),
inputSpec.get_custom_obj().get_custom_obj_name()});
break;
}
case torch::_export::InputSpec::Tag::CONSTANT_INPUT: {
constantInputs_.push_back(inputSpec.get_constant_input().get_name());
break;
}
case torch::_export::InputSpec::Tag::TOKEN: {
throw std::runtime_error("Token inputs not implemented yet");
}
default:
LOG(FATAL) << "Got empty thrift argument";
break;
}
}
std::string lossOutput;
for (const torch::_export::OutputSpec& outputSpec :
storage.get_output_specs()) {
switch (outputSpec.tag()) {
case torch::_export::OutputSpec::Tag::LOSS_OUTPUT:
lossOutput_ = outputSpec.get_loss_output().get_arg().get_name();
break;
case torch::_export::OutputSpec::Tag::USER_OUTPUT:
if (isSymbolicOutput(outputSpec.get_user_output().get_arg().tag())) {
switch (outputSpec.get_user_output().get_arg().tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
userOutputs_.emplace_back(outputSpec.get_user_output()
.get_arg()
.get_as_tensor()
.get_name());
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
userOutputs_.emplace_back(outputSpec.get_user_output()
.get_arg()
.get_as_sym_int()
.get_as_name());
break;
}
default: {
LOG(FATAL) << "Unsupported symbolic user output type ";
}
}
} else {
// for constant outputs, we don't have a name
userOutputs_.emplace_back(std::nullopt);
}
break;
case torch::_export::OutputSpec::Tag::BUFFER_MUTATION:
buffersToMutate_.insert(
{outputSpec.get_buffer_mutation().get_arg().get_name(),
outputSpec.get_buffer_mutation().get_buffer_name()});
break;
case torch::_export::OutputSpec::Tag::GRADIENT_TO_PARAMETER:
gradientsToParameters_.insert(
{outputSpec.get_gradient_to_parameter().get_arg().get_name(),
outputSpec.get_gradient_to_parameter().get_parameter_name()});
break;
case torch::_export::OutputSpec::Tag::GRADIENT_TO_USER_INPUT:
gradientsToUserInputs_.insert(
{outputSpec.get_gradient_to_user_input().get_arg().get_name(),
outputSpec.get_gradient_to_user_input().get_user_input_name()});
break;
case torch::_export::OutputSpec::Tag::USER_INPUT_MUTATION:
userInputsToMutate_.insert(
{outputSpec.get_user_input_mutation().get_arg().get_name(),
outputSpec.get_user_input_mutation().get_user_input_name()});
break;
case torch::_export::OutputSpec::Tag::TOKEN: {
throw std::runtime_error("Token outputs not implemented yet");
}
default:
LOG(FATAL) << "Got empty thrift argument";
}
}
auto keys_of = [&](const std::unordered_map<std::string, std::string>& dict) {
std::vector<std::string_view> keys;
keys.reserve(dict.size());
for (const auto& [key, _] : dict) {
keys.emplace_back(key);
}
return keys;
};
auto get_valid = [&](const std::vector<std::optional<std::string>>& outputs) {
std::vector<std::string> validOutputs;
for (const auto& output : outputs) {
if (output.has_value()) {
validOutputs.push_back(*output);
} else {
validOutputs.emplace_back("Constant");
}
}
return validOutputs;
};
VLOG(1) << fmt::format("[{}]", fmt::join(userInputs_, ", "));
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(inputsToParameters_), ", "));
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(inputsToBuffers_), ", "));
VLOG(1) << fmt::format(
"[{}]", fmt::join(keys_of(inputsToTensorConstants_), ", "));
VLOG(1) << fmt::format("[{}]", fmt::join(get_valid(userOutputs_), ", "));
VLOG(1) << fmt::format("[{}]", fmt::join(keys_of(buffersToMutate_), ", "));
VLOG(1) << fmt::format(
"[{}]", fmt::join(keys_of(gradientsToParameters_), ", "));
VLOG(1) << fmt::format(
"[{}]", fmt::join(keys_of(gradientsToUserInputs_), ", "));
}
std::set<std::string> GraphSignature::inputNames() const {
std::set<std::string> ret;
for (const auto& name : userInputs()) {
ret.insert(name);
}
for (const auto& [inputName, _] : inputsToParameters()) {
ret.insert(inputName);
}
for (const auto& [inputName, _] : inputsToBuffers()) {
ret.insert(inputName);
}
for (const auto& [inputName, _] : inputsToTensorConstants()) {
ret.insert(inputName);
}
for (const auto& [inputName, _] : inputsToCustomObjs()) {
ret.insert(inputName);
}
return ret;
}
std::set<std::optional<std::string>> GraphSignature::outputNames() const {
std::set<std::optional<std::string>> ret;
for (const auto& name : userOutputs()) {
ret.insert(name);
}
for (const auto& [outputName, _] : buffersToMutate()) {
ret.insert(outputName);
}
for (const auto& [outputName, _] : userInputsToMutate()) {
ret.insert(outputName);
}
if (hasBackward()) {
if (!gradientsToParameters().empty()) {
for (const auto& [outputName, _] : gradientsToParameters()) {
ret.insert(outputName);
}
}
if (!gradientsToUserInputs().empty()) {
for (const auto& [outputName, _] : gradientsToUserInputs()) {
ret.insert(outputName);
}
}
if (!lossOutput().empty()) {
ret.insert(lossOutput());
}
}
return ret;
}
void GraphSignature::lint(
const std::set<std::string>& graphInputs,
const std::set<std::string>& graphOutputs) const {
checkInputNames(inputNames(), graphInputs);
checkOutputNames(outputNames(), graphOutputs);
}
void GraphSignature::replaceAllUses(
std::string_view old,
std::string_view replacement) {
if (old == replacement) {
return;
}
for (auto& name : userOutputs_) {
if (name == old) {
name = replacement;
}
}
replaceInMap(buffersToMutate_, old, replacement);
if (hasBackward()) {
replaceInMap(gradientsToParameters_, old, replacement);
replaceInMap(gradientsToUserInputs_, old, replacement);
if (old == lossOutput_) {
lossOutput_ = replacement;
}
}
}
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig) {
out << "inputsToParameters: {\n";
for (const auto& [inputName, paramName] : sig.inputsToParameters()) {
out << "\t" << inputName << " : " << paramName << std::endl;
}
out << "}\n";
out << "inputsToBuffers: {\n";
for (const auto& [inputName, bufferName] : sig.inputsToBuffers()) {
out << "\t" << inputName << " : " << bufferName << std::endl;
}
out << "}\n";
out << "inputsToTensorConstants: {\n";
for (const auto& [inputName, tensorConstantName] :
sig.inputsToTensorConstants()) {
out << "\t" << inputName << " : " << tensorConstantName << std::endl;
}
out << "}\n";
return out;
}
} // namespace torch::nativert
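A self-contained sketch (plain C++; the Tag enum, Spec struct, and wellOrdered function are illustrative stand-ins for the torch::_export types, not part of this file) of the ordering rule that checkInputOrders enforces: non-user inputs must appear grouped as token, parameter, buffer, tensor_constant, custom_obj, and persistent buffers must precede non-persistent ones.
#include <cassert>
#include <cstddef>
#include <vector>
enum class Tag { Token, Parameter, Buffer, TensorConstant, CustomObj, UserInput };
struct Spec {
  Tag tag;
  bool persistent = true; // only meaningful for Buffer
};
// Returns true iff the specs follow the same ordering rule as checkInputOrders.
bool wellOrdered(const std::vector<Spec>& specs) {
  std::size_t currentRank = 0;
  bool seenNonPersistentBuffer = false;
  for (const auto& s : specs) {
    if (s.tag == Tag::UserInput) {
      continue; // user/constant inputs are skipped by the check
    }
    const auto rank = static_cast<std::size_t>(s.tag);
    if (rank < currentRank) {
      return false; // tag appears after a later-ranked tag
    }
    currentRank = rank;
    if (s.tag == Tag::Buffer) {
      if (!s.persistent) {
        seenNonPersistentBuffer = true;
      } else if (seenNonPersistentBuffer) {
        return false; // persistent buffer after a non-persistent one
      }
    }
  }
  return true;
}
int main() {
  assert(wellOrdered({{Tag::Parameter}, {Tag::Buffer, true}, {Tag::Buffer, false}, {Tag::UserInput}}));
  assert(!wellOrdered({{Tag::Buffer, false}, {Tag::Buffer, true}}));
  return 0;
}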

View File

@ -0,0 +1,141 @@
#pragma once
#include <set>
#include <string>
#include <c10/util/Logging.h>
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
class Graph;
// In memory representation of graph signature.
class GraphSignature {
public:
GraphSignature() = default;
explicit GraphSignature(const torch::_export::GraphSignature& storage);
const auto& lossOutput() const {
return lossOutput_;
}
const auto& gradientsToParameters() const {
return gradientsToParameters_;
}
const auto& gradientsToUserInputs() const {
return gradientsToUserInputs_;
}
const auto& inputsToParameters() const {
return inputsToParameters_;
}
const auto& inputsToBuffers() const {
return inputsToBuffers_;
}
const auto& inputsToTensorConstants() const {
return inputsToTensorConstants_;
}
const auto& inputsToCustomObjs() const {
return inputsToCustomObjs_;
}
const auto& userInputsVec() const {
return userInputs_;
}
const auto& parameters() const {
return parameters_;
}
const auto& buffers() const {
return buffers_;
}
const auto& nonPersistentBuffers() const {
return nonPersistentBuffers_;
}
const auto& tensorConstants() const {
return tensorConstants_;
}
const auto& customObjs() const {
return customObjs_;
}
const auto& userInputs() const {
return userInputs_;
}
const auto& constantInputs() const {
return constantInputs_;
}
const auto& userOutputs() const {
return userOutputs_;
}
const auto& buffersToMutate() const {
return buffersToMutate_;
}
const auto& userInputsToMutate() const {
return userInputsToMutate_;
}
bool hasBackward() const {
return !(
lossOutput_.empty() && gradientsToParameters_.empty() &&
gradientsToUserInputs_.empty() && buffersToMutate_.empty());
}
// Mapping of FQNs to weights with stable iteration order.
const auto& inputsToWeights() const {
return inputsToWeights_;
}
void lint(
const std::set<std::string>& graphInputs,
const std::set<std::string>& graphOutputs) const;
void replaceAllUses(std::string_view old, std::string_view replacement);
torch::_export::GraphSignature serialize() const;
private:
std::set<std::string> inputNames() const;
std::set<std::optional<std::string>> outputNames() const;
std::unordered_map<std::string, std::string> gradientsToParameters_;
std::unordered_map<std::string, std::string> gradientsToUserInputs_;
std::unordered_map<std::string, std::string> inputsToParameters_;
std::unordered_map<std::string, std::string> inputsToBuffers_;
std::unordered_map<std::string, std::string> buffersToMutate_;
std::unordered_map<std::string, std::string> inputsToTensorConstants_;
std::unordered_map<std::string, std::string> inputsToCustomObjs_;
std::unordered_map<std::string, std::string> userInputsToMutate_;
// map union of inputsToParameters_, inputsToBuffers_ and
// inputsToTensorConstants_
std::vector<std::pair<std::string, std::string>> inputsToWeights_;
std::vector<std::string> parameters_;
std::vector<std::string> buffers_;
std::vector<std::string> tensorConstants_;
std::vector<std::string> customObjs_;
std::vector<std::string> nonPersistentBuffers_;
std::vector<std::string> userInputs_;
std::vector<std::string> constantInputs_;
std::vector<std::optional<std::string>> userOutputs_;
std::string lossOutput_;
};
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig);
} // namespace torch::nativert
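A hedged usage sketch of inputsToWeights(): it yields (graph input name, weight FQN) pairs in stable declaration order (parameters, then buffers, then tensor constants), which is what a weight loader would key on. bindWeights and loadTensorByFqn are hypothetical, not APIs from this header.
#include <functional>
#include <string>
#include <unordered_map>
#include <ATen/core/Tensor.h>
#include "torch/csrc/nativert/graph/GraphSignature.h"
// loadTensorByFqn is a caller-supplied callback (e.g. a checkpoint reader).
void bindWeights(
    const torch::nativert::GraphSignature& sig,
    const std::function<at::Tensor(const std::string&)>& loadTensorByFqn,
    std::unordered_map<std::string, at::Tensor>& boundByInputName) {
  for (const auto& [inputName, weightFqn] : sig.inputsToWeights()) {
    boundByInputName.emplace(inputName, loadTensorByFqn(weightFqn));
  }
}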

View File

@ -0,0 +1,525 @@
#include "torch/csrc/nativert/graph/Serialization.h"
#include <fmt/format.h>
#include <fmt/ostream.h>
#include <fmt/ranges.h>
namespace torch::nativert {
namespace {
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& thriftGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata);
Value* symbolicToValue(
const torch::_export::Argument& arg,
Graph& graph,
Node* insertBefore) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
return graph.getValue(arg.get_as_tensor().get_name());
case torch::_export::Argument::Tag::AS_TENSORS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_tensors()) {
listValue.push_back(graph.getValue(listEl.get_name()));
}
auto listPack = graph.createListPack(std::move(listValue), Type::Tensor);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_optional_tensors()) {
switch (listEl.tag()) {
case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
listValue.push_back(
graph.getValue(listEl.get_as_tensor().get_name()));
break;
}
case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
listValue.push_back(
graph.addValue(std::nullopt, Type::None, nullptr));
break;
}
default:
throw std::runtime_error(fmt::format(
"Unknown OptionalTensorArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack = graph.createOptionalListPack(std::move(listValue));
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
return graph.getValue(arg.get_as_sym_int().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_sym_ints()) {
switch (listEl.tag()) {
case torch::_export::SymIntArgument::Tag::AS_NAME: {
listValue.push_back(graph.getValue(listEl.get_as_name()));
break;
}
case torch::_export::SymIntArgument::Tag::AS_INT: {
// These are concrete int values in the SymIntList, e.g. [s0, 8].
// We convert them into constant Values in the graph; these values
// don't have a producer node.
int value = listEl.get_as_int();
Value* symintValue = graph.createConstantSymIntValue(value);
listValue.push_back(symintValue);
break;
}
default:
throw std::runtime_error(fmt::format(
"Unknown SymIntArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack = graph.createListPack(std::move(listValue), Type::SymInt);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
return graph.getValue(arg.get_as_custom_obj().get_name());
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
return graph.getValue(arg.get_as_sym_bool().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
return graph.getValue(arg.get_as_sym_float().get_as_name());
}
default:
throw std::runtime_error(fmt::format(
"This function should only be called for symbolics, got \'{}\' instead",
torch::_export::printEnum(arg.tag())));
}
}
std::pair<
std::vector<torch::_export::InputSpec>,
std::vector<torch::_export::Argument>>
enforceInputOrder(
const std::vector<torch::_export::InputSpec>& inputSpecs,
const std::vector<torch::_export::Argument>& graphInputs) {
// Enforce the order of inputSpecs and graphInputs to be the following:
// 1. token
// 2. parameter
// 3. persistent buffer, non-persistent buffer
// 4. tensor_constant
// 5. custom_obj
// 6. user_input/constant_input
std::vector<torch::_export::InputSpec> reorderedInputSpecs;
std::vector<torch::_export::Argument> reorderedGraphInputs;
std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
torch::_export::InputSpec::Tag::TOKEN,
torch::_export::InputSpec::Tag::PARAMETER,
torch::_export::InputSpec::Tag::BUFFER,
torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
torch::_export::InputSpec::Tag::CUSTOM_OBJ};
auto reorder = [&](auto condition) {
for (size_t i = 0; i < inputSpecs.size(); ++i) {
if (condition(inputSpecs[i])) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
};
for (const auto& tag : desiredOrder) {
if (tag == torch::_export::InputSpec::Tag::BUFFER) {
// Add persistent buffers first, then non-persistent
reorder([&](const auto& spec) {
return spec.tag() == tag && spec.get_buffer().get_persistent();
});
reorder([&](const auto& spec) {
return spec.tag() == tag && !spec.get_buffer().get_persistent();
});
} else {
reorder([&](const auto& spec) { return spec.tag() == tag; });
}
}
// Append USER_INPUT and CONSTANT_INPUT without reordering
for (size_t i = 0; i < inputSpecs.size(); ++i) {
auto tag = inputSpecs[i].tag();
if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
}
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& jsonGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata) {
auto graphInputs = jsonGraph.get_inputs();
auto graph = Graph::createGraph();
if (signature) {
// Enforce the ordering of the signature's input specs and the graph inputs.
auto inputSpecs = signature->get_input_specs();
auto [reorderedInputSpecs, reorderedGraphInputs] =
enforceInputOrder(inputSpecs, graphInputs);
graphInputs = std::move(reorderedGraphInputs);
auto reorderedSignature = *signature;
reorderedSignature.set_input_specs(reorderedInputSpecs);
graph->setSignature(GraphSignature{reorderedSignature});
}
for (const auto& input : graphInputs) {
if (isSymbolic(input)) {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = input.get_as_tensor();
const auto& name = asTensor.get_name();
graph->addInput(name, Type::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
const auto& asCustomObj = input.get_as_custom_obj();
const std::string& name = asCustomObj.get_name();
const std::string& classFqn = asCustomObj.get_class_fqn();
graph->addInput(name, Type(Type::CustomObj, classFqn));
break;
}
default:
throw std::runtime_error(fmt::format(
"Unsupported graph input type {}",
static_cast<int>(input.tag())));
}
} else {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_INT:
case torch::_export::Argument::Tag::AS_FLOAT:
case torch::_export::Argument::Tag::AS_STRING:
case torch::_export::Argument::Tag::AS_BOOL:
case torch::_export::Argument::Tag::AS_NONE: {
// Constant graph inputs are specialized into the graph; here we simply
// add a placeholder (null Value) for the graph input.
graph->addInput();
break;
}
default:
throw std::runtime_error(fmt::format(
"Unsupported graph input type {}",
static_cast<int>(input.tag())));
}
}
}
for (const auto& jsonNode : jsonGraph.get_nodes()) {
auto node = graph->insertNode(
jsonNode.get_target(),
{},
loadNodeMetadata ? jsonNode.get_metadata()
: std::unordered_map<std::string, std::string>());
std::vector<NamedArgument> args;
std::vector<Attribute> attributes;
for (const auto& input : jsonNode.get_inputs()) {
// We handle constants and symbolic inputs differently.
const auto& arg = input.get_arg();
if (isSymbolic(arg)) {
// Symbolic values are made part of the inputs to the node
node->addInput(NamedArgument{
input.get_name(), symbolicToValue(input.get_arg(), *graph, node)});
} else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
node->addInput(NamedArgument{
input.get_name(), graph->addValue(std::nullopt, Type::None, node)});
} else {
node->addAttribute(Attribute{
input.get_name(),
constantToValue(input.get_arg(), loadNodeMetadata)});
// Constant values are added as "attributes" to the node.
}
}
std::vector<Value*> outputs;
std::vector<Value*> listUnpacksToCreate;
for (const auto& output : jsonNode.get_outputs()) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_NONE: {
node->addOutput(Type::None);
break;
}
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto name = output.get_as_tensor().get_name();
node->addOutput(name, Type::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_TENSORS: {
auto outputValue =
node->addOutput(graph->getUniqueValueName(), Type::TensorList);
Node* listUnpack =
graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
for (const auto& arg : output.get_as_tensors()) {
listUnpack->addOutput(arg.get_name(), Type::Tensor);
}
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto name = output.get_as_sym_int().get_as_name();
node->addOutput(name, Type::SymInt);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
throw std::runtime_error(
"SymInts NYI. We currently doesn't have op that produces SymInts as output");
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
const auto name = output.get_as_sym_bool().get_as_name();
node->addOutput(name, Type::SymBool);
break;
}
case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
throw std::runtime_error(
"SymBools NYI. We currently doesn't have op that produces SymBools as output");
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
const auto name = output.get_as_sym_float().get_as_name();
node->addOutput(name, Type::SymFloat);
break;
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
throw std::runtime_error(
"SymFloats NYI. We currently doesn't have op that produces SymFloats as output");
}
default:
throw std::runtime_error(fmt::format(
"disallowed output type {}",
torch::_export::printEnum(output.tag())));
}
}
}
for (const auto& output : jsonGraph.get_outputs()) {
// handle symbolic outputs and constant outputs differently
if (isSymbolic(output)) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = output.get_as_tensor();
const auto& name = asTensor.get_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto& asSymInt = output.get_as_sym_int();
TORCH_CHECK(
asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
const auto& name = asSymInt.get_as_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
default:
throw std::runtime_error(fmt::format(
"Unsupported graph output type: {}",
static_cast<size_t>(output.tag())));
}
} else {
Constant constValue = constantToValue(output, loadNodeMetadata);
graph->addConstantOutput(constValue);
}
}
auto jsonTensorValue = jsonGraph.get_tensor_values();
if (!signature) {
// For subgraphs we just need to derive a graph signature that only
// contains user inputs and outputs, because we don't need to handle any
// special semantics for them, e.g. mutation or gradients.
torch::_export::GraphSignature sig;
std::vector<torch::_export::InputSpec> inputSpecs;
for (const auto& input : graph->inputs()) {
torch::_export::Argument arg;
if (input->type().kind() == Type::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{input->name()});
arg.set_as_tensor(std::move(targ));
} else {
throw std::runtime_error(fmt::format(
"Unsupported subgraph input type {}",
fmt::streamed(input->type())));
}
torch::_export::UserInputSpec userInputSpec;
userInputSpec.set_arg(std::move(arg));
torch::_export::InputSpec inputSpec;
inputSpec.set_user_input(std::move(userInputSpec));
inputSpecs.push_back(std::move(inputSpec));
}
sig.set_input_specs(std::move(inputSpecs));
std::vector<torch::_export::OutputSpec> outputSpecs;
for (const auto& output : graph->outputs()) {
torch::_export::Argument arg;
if (output->type().kind() == Type::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{output->name()});
arg.set_as_tensor(std::move(targ));
} else {
throw std::runtime_error(fmt::format(
"Unsupported subgraph input type {}",
fmt::streamed(output->type())));
}
torch::_export::UserOutputSpec userOutputSpec;
userOutputSpec.set_arg(std::move(arg));
torch::_export::OutputSpec outputSpec;
outputSpec.set_user_output(std::move(userOutputSpec));
outputSpecs.push_back(std::move(outputSpec));
}
sig.set_output_specs(std::move(outputSpecs));
graph->setSignature(GraphSignature{sig});
}
// weightsTensorMeta is indexed by the weight's name, not the graph input's name
std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
for (const auto& [inputName, weightName] :
graph->signature().inputsToWeights()) {
auto value = graph->getValue(inputName);
if (value->type().kind() == Type::CustomObj) {
// skip setting meta for non-tensor inputs
continue;
}
auto it = jsonTensorValue.find(inputName);
CHECK(it != jsonTensorValue.end())
<< "Missing tensor metadata for " << inputName
<< "in thriftGraph.tensorValue";
weightsTensorMeta[weightName] = it->second;
}
graph->setWeightsMeta(weightsTensorMeta);
graph->setTensorValuesMeta(jsonTensorValue);
graph->finalize();
graph->lint();
return graph;
}
} // namespace
bool isSymbolic(const torch::_export::Argument& arg) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
return true;
default:
return false;
}
}
Constant constantToValue(
const torch::_export::Argument& jsonArg,
bool loadNodeMetadata) {
switch (jsonArg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
throw std::runtime_error("Tensor is symbolic, shouldn't reach here");
case torch::_export::Argument::Tag::AS_TENSORS:
throw std::runtime_error("Tensor[] is symbolic, shouldn't reach here");
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
throw std::runtime_error("Tensor?[] is symbolic, shouldn't reach here");
case torch::_export::Argument::Tag::AS_NONE:
return None();
case torch::_export::Argument::Tag::AS_INT:
return jsonArg.get_as_int();
case torch::_export::Argument::Tag::AS_INTS: {
std::vector<int64_t> ret;
for (const auto& arg : jsonArg.get_as_ints()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_FLOAT:
return jsonArg.get_as_float().get();
case torch::_export::Argument::Tag::AS_FLOATS: {
std::vector<double> ret;
for (const auto& arg : jsonArg.get_as_floats()) {
ret.push_back(arg.get());
}
return ret;
}
case torch::_export::Argument::Tag::AS_STRING:
return jsonArg.get_as_string();
case torch::_export::Argument::Tag::AS_STRINGS: {
std::vector<std::string> ret;
for (const auto& arg : jsonArg.get_as_strings()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
throw std::runtime_error(
"Symint/Symbool Values are symbolic, shouldn't reach here");
case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
return convertJsonScalarType(jsonArg.get_as_scalar_type());
case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
return convertJsonMemoryFormat(jsonArg.get_as_memory_format());
case torch::_export::Argument::Tag::AS_LAYOUT:
return convertJsonLayout(jsonArg.get_as_layout());
case torch::_export::Argument::Tag::AS_DEVICE:
return convertJsonDevice(jsonArg.get_as_device());
case torch::_export::Argument::Tag::AS_BOOL:
return jsonArg.get_as_bool();
case torch::_export::Argument::Tag::AS_BOOLS: {
std::vector<bool> ret;
for (const auto& arg : jsonArg.get_as_bools()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_GRAPH: {
return jsonToSubgraph(
*jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
throw std::runtime_error("custom obj is dynamic, shouldn't reach here");
case torch::_export::Argument::Tag::AS_OPERATOR:
return jsonArg.get_as_operator();
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
throw std::runtime_error("sym float is not yet implemented");
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
throw std::runtime_error("sym floats is not yet implemented");
}
default:
throw std::runtime_error("Got unknown thrift argument");
}
}
std::unique_ptr<Graph> jsonToGraph(
const torch::_export::GraphModule& jsonGraphModule,
bool loadNodeMetadata) {
auto graph = jsonToSubgraph(
jsonGraphModule.get_graph(),
&jsonGraphModule.get_signature(),
loadNodeMetadata);
return graph;
}
} // namespace torch::nativert

View File

@ -0,0 +1,34 @@
#pragma once
#include "torch/csrc/nativert/graph/Graph.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
/**
* This file contains serialization utilities for Graph.
*
* There are two serialized representations we care about:
* - Json: stable but hard to work with, not really human readable
* - Debug format: human-readable, not stable.
*
* All formats should be logically equivalent, so we should be able to convert
* between the in-memory graph, JSON, and the debug format interchangeably.
*/
// Json -> Graph
std::unique_ptr<Graph> jsonToGraph(
const torch::_export::GraphModule& jsonGraph,
bool loadNodeMetadata = true);
// Graph -> Json
std::pair<torch::_export::Graph, torch::_export::GraphSignature> graphToJson(
const Graph& graph);
bool isSymbolic(const torch::_export::Argument& arg);
Constant constantToValue(
const torch::_export::Argument& jsonArg,
bool loadNodeMetadata);
} // namespace torch::nativert
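A small usage sketch, assuming the caller has already deserialized the archive's program into a torch::_export::GraphModule by some other means; buildGraph is a hypothetical wrapper around the jsonToGraph declaration above.
#include <memory>
#include "torch/csrc/nativert/graph/Serialization.h"
std::unique_ptr<torch::nativert::Graph> buildGraph(
    const torch::_export::GraphModule& gm) {
  // Dropping node metadata (e.g. "stack_trace") keeps the in-memory graph smaller.
  return torch::nativert::jsonToGraph(gm, /*loadNodeMetadata=*/false);
}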

View File

@ -0,0 +1,136 @@
#include "torch/csrc/nativert/graph/TensorMeta.h"
#include <c10/util/Logging.h>
namespace torch::nativert {
c10::ScalarType convertJsonScalarType(
const torch::_export::ScalarType& scalarType) {
switch (scalarType) {
case torch::_export::ScalarType::UNKNOWN:
CHECK(false) << "scalar type is not properly set";
case torch::_export::ScalarType::BYTE:
return c10::ScalarType::Byte;
case torch::_export::ScalarType::CHAR:
return c10::ScalarType::Char;
case torch::_export::ScalarType::SHORT:
return c10::ScalarType::Short;
case torch::_export::ScalarType::INT:
return c10::ScalarType::Int;
case torch::_export::ScalarType::LONG:
return c10::ScalarType::Long;
case torch::_export::ScalarType::HALF:
return c10::ScalarType::Half;
case torch::_export::ScalarType::FLOAT:
return c10::ScalarType::Float;
case torch::_export::ScalarType::DOUBLE:
return c10::ScalarType::Double;
case torch::_export::ScalarType::COMPLEXHALF:
return c10::ScalarType::ComplexHalf;
case torch::_export::ScalarType::COMPLEXFLOAT:
return c10::ScalarType::ComplexFloat;
case torch::_export::ScalarType::COMPLEXDOUBLE:
return c10::ScalarType::ComplexDouble;
case torch::_export::ScalarType::BOOL:
return c10::ScalarType::Bool;
case torch::_export::ScalarType::BFLOAT16:
return c10::ScalarType::BFloat16;
default:
TORCH_CHECK(false, "unknown scalar type", static_cast<int>(scalarType));
}
}
c10::MemoryFormat convertJsonMemoryFormat(
const torch::_export::MemoryFormat& memoryFormat) {
switch (memoryFormat) {
case torch::_export::MemoryFormat::Unknown:
TORCH_CHECK(false, "got unknown scalar type");
case torch::_export::MemoryFormat::ContiguousFormat:
return c10::MemoryFormat::Contiguous;
case torch::_export::MemoryFormat::ChannelsLast:
return c10::MemoryFormat::ChannelsLast;
case torch::_export::MemoryFormat::ChannelsLast3d:
return c10::MemoryFormat::ChannelsLast3d;
case torch::_export::MemoryFormat::PreserveFormat:
return c10::MemoryFormat::Preserve;
default:
TORCH_CHECK(
false, "unknown memory format", static_cast<int>(memoryFormat));
}
}
c10::Layout convertJsonLayout(const torch::_export::Layout& layout) {
switch (layout) {
case torch::_export::Layout::Unknown:
TORCH_CHECK(false, "got unknown layout");
case torch::_export::Layout::SparseCoo:
// TODO: is this the right translation?
return c10::Layout::Sparse;
case torch::_export::Layout::SparseCsr:
return c10::Layout::SparseCsr;
case torch::_export::Layout::SparseCsc:
return c10::Layout::SparseCsc;
case torch::_export::Layout::SparseBsr:
return c10::Layout::SparseBsr;
case torch::_export::Layout::SparseBsc:
return c10::Layout::SparseBsc;
case torch::_export::Layout::_mkldnn:
return c10::Layout::Mkldnn;
case torch::_export::Layout::Strided:
return c10::Layout::Strided;
default:
TORCH_CHECK(false, "unknown layout", static_cast<int>(layout));
}
}
c10::Device convertJsonDevice(const torch::_export::Device& device) {
c10::Device d(device.get_type());
if (auto index = device.get_index()) {
d.set_index(*index);
}
return d;
}
TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta)
: device_(convertJsonDevice(tensorMeta.get_device())) {
dtype_ = convertJsonScalarType(tensorMeta.get_dtype());
layout_ = convertJsonLayout(tensorMeta.get_layout());
requiresGrad_ = tensorMeta.get_requires_grad();
if (tensorMeta.get_storage_offset().tag() ==
torch::_export::SymInt::Tag::AS_INT) {
storage_offset_ = tensorMeta.get_storage_offset().get_as_int();
} else {
CHECK(false) << "SymInt not supported yet";
}
for (const auto& size : tensorMeta.get_sizes()) {
if (size.tag() == torch::_export::SymInt::Tag::AS_INT) {
int64_t val = size.get_as_int();
sizes_.emplace_back(val);
numel_ *= val;
} else if (size.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
// TODO: it's still unclear how SymInt shapes should be used at runtime.
// One potential use case is verifying that input shapes match their
// constraints. This would require unpacking the serialized constraints,
// which is NYI.
//
// For the time being, we just set the symbolic dim to -1.
hasSymbolicShape_ = true;
sizes_.emplace_back(-1);
numel_ = -1;
}
}
for (const auto& stride : tensorMeta.get_strides()) {
if (stride.tag() == torch::_export::SymInt::Tag::AS_INT) {
strides_.emplace_back(stride.get_as_int());
} else if (stride.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
// TODO: it's still unclear how SymInt shapes should be used at runtime.
// Setting the symbolic stride to -1 for now.
hasSymbolicShape_ = true;
strides_.emplace_back(-1);
}
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,92 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Logging.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/ArrayRef.h>
#include "c10/core/Layout.h"
#include "torch/csrc/utils/generated_serialization_types.h" // @manual=//caffe2:torch-cpp-cpu
namespace torch::nativert {
c10::ScalarType convertJsonScalarType(
const torch::_export::ScalarType& scalarType);
c10::MemoryFormat convertJsonMemoryFormat(
const torch::_export::MemoryFormat& memoryFormat);
c10::Layout convertJsonLayout(const torch::_export::Layout& layout);
c10::Device convertJsonDevice(const torch::_export::Device& device);
class TensorMeta {
public:
explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta);
c10::IntArrayRef sizes() const {
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
return sizes_;
}
c10::IntArrayRef strides() const {
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
return strides_;
}
c10::Layout layout() const {
return layout_;
}
c10::ScalarType dtype() const {
return dtype_;
}
bool requires_grad() const {
return requiresGrad_;
}
int64_t storage_offset() const {
return storage_offset_;
}
int64_t dim() const {
return sizes_.size();
}
int64_t numel() const {
CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
return numel_;
}
c10::Device device() const {
return device_;
}
c10::TensorOptions asTensorOptions() const {
return c10::TensorOptions().dtype(dtype_).layout(layout_).requires_grad(
requiresGrad_);
}
// NYI
// c10::SymIntArrayRef sym_sizes() const {}
// c10::SymIntArrayRef sym_strides() const {}
// c10::SymInt sym_storage_offset() const {}
// c10::SymInt sym_numel() const {}
private:
bool hasSymbolicShape_ = false;
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
int64_t storage_offset_ = 0;
int64_t numel_ = 1;
c10::ScalarType dtype_;
c10::Layout layout_;
bool requiresGrad_;
c10::Device device_;
};
} // namespace torch::nativert
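A hedged sketch of one way TensorMeta could be consumed, assuming the meta carries a static (non-symbolic) shape; sizes() and strides() CHECK-fail otherwise. emptyFromMeta is hypothetical and only composes the accessors declared above with at::empty_strided.
#include <ATen/ATen.h>
#include "torch/csrc/nativert/graph/TensorMeta.h"
at::Tensor emptyFromMeta(const torch::nativert::TensorMeta& meta) {
  // Allocate an uninitialized tensor matching the serialized metadata.
  return at::empty_strided(
      meta.sizes(),
      meta.strides(),
      c10::TensorOptions()
          .dtype(meta.dtype())
          .layout(meta.layout())
          .device(meta.device()));
}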

View File

@ -0,0 +1,50 @@
#include "torch/csrc/nativert/kernels/AOTICallDelegateKernel.h"
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
AOTICallDelegateKernel::AOTICallDelegateKernel(
const Node* node,
AOTIDelegateExecutor& delegateExecutor)
: OpKernel(node), delegateExecutor_(delegateExecutor) {
// torch.higher_order_ops.aoti_call_delegate(lowered_module, original_gm,
// weight_args, input_args). However, the second
// input is currently serialized as None, so we only have 3 inputs and 1
// output.
// TODO(T213681594): Fix this; the None input should still be included in
// numInputs().
TORCH_CHECK_EQ(node->numInputs(), 3);
TORCH_CHECK_EQ(node->numOutputs(), 1);
// Weights are in node->inputs()[1], but they are not used in the forward call.
// Instead, weights are bound to AOTI via loadWeights().
const Value* input = node->inputs()[2].value;
const Value* output = node->outputs()[0];
CHECK(input->type() == Type::TensorList)
<< "torch.higher_order_ops.aoti_call_delegate input should be a TensorList, but got "
<< input->type();
CHECK(output->type() == Type::TensorList)
<< "torch.higher_order_ops.aoti_call_delegate output should be a TensorList, but got "
<< output->type();
inputValueId_ = input->id();
outputValueId_ = output->id();
}
void AOTICallDelegateKernel::computeInternal(
ExecutionFrame& executionFrame) const {
std::vector<at::Tensor> inputs =
executionFrame.getTensorVector(inputValueId_);
auto outputs = delegateExecutor_.run(inputs);
executionFrame.setIValue(outputValueId_, std::move(outputs));
}
} // namespace torch::nativert

View File

@ -0,0 +1,27 @@
#pragma once
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
class AOTIDelegateExecutor;
// Kernel for torch.higher_order_ops.aoti_call_delegate
class AOTICallDelegateKernel : public OpKernel {
public:
explicit AOTICallDelegateKernel(
const Node* node,
AOTIDelegateExecutor& delegateExecutor);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
AOTIDelegateExecutor& delegateExecutor_;
ValueId inputValueId_;
ValueId outputValueId_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,41 @@
#include "torch/csrc/nativert/kernels/AOTIKernel.h"
#include <c10/util/Logging.h>
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
AOTIKernel::AOTIKernel(const Node* node, AOTIDelegateExecutor& delegateExecutor)
: OpKernel(node), delegateExecutor_(delegateExecutor) {
// The schema is "call_aotinductor(str path, Tensor[] weights, Tensor[]
// inputs) -> Tensor[] outputs", expects 2 inputs and 1 output
TORCH_CHECK_EQ(node->numInputs(), 2);
TORCH_CHECK_EQ(node->numOutputs(), 1);
// Weights are in node->inputs()[0], but they are not used in the forward call.
// Instead, weights are bound to AOTI via loadWeights().
const Value* input = node->inputs()[1].value;
const Value* output = node->outputs()[0];
CHECK(input->type() == Type::TensorList)
<< "delegate.call_aotinductor input should be a TensorList";
CHECK(output->type() == Type::TensorList)
<< "delegate.call_aotinductor output should be a TensorList";
inputValueId_ = input->id();
outputValueId_ = output->id();
}
void AOTIKernel::computeInternal(ExecutionFrame& executionFrame) const {
std::vector<at::Tensor> inputs =
executionFrame.getTensorVector(inputValueId_);
auto outputs = delegateExecutor_.run(inputs);
executionFrame.setIValue(outputValueId_, std::move(outputs));
}
} // namespace torch::nativert

View File

@ -0,0 +1,26 @@
#pragma once
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
class AOTIDelegateExecutor;
// Kernel for torch.ops.delegate.call_aotinductor
// TODO: Deprecate this when we move to aoti_call_delegate HOP
class AOTIKernel : public OpKernel {
public:
explicit AOTIKernel(const Node* node, AOTIDelegateExecutor& delegateExecutor);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
AOTIDelegateExecutor& delegateExecutor_;
ValueId inputValueId_;
ValueId outputValueId_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,64 @@
#include "torch/csrc/nativert/kernels/AutoFunctionalizeKernel.h"
#include <fmt/format.h>
#include "torch/csrc/nativert/common/Enumerate.h"
namespace torch::nativert {
UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
: OpKernel(node),
op_(getOperatorForTarget(
std::get<std::string>(node->attributes()[0].value))),
schema_(op_.schema()),
arguments_(prefillStackWithStaticArgs(node, schema_)) {
for (const auto& [idx, schemaArg] : enumerate(schema_.arguments())) {
if (schemaArg.alias_info() != nullptr &&
schemaArg.alias_info()->isWrite()) {
mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
}
}
numOutputs_ = schema_.returns().size();
}
void UnsafeAutoFunctionalizeKernel::computeInternal(
ExecutionFrame& executionFrame) const {
// Make a copy of the stack
std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
fillDynamicInputs(executionFrame, arguments_, stack);
// Call the op with the prepared stack.
try {
op_.callBoxed(stack);
} catch (const std::exception& ex) {
// TODO: this eats the original exception type. ATen returns different
// exception types that correspond to different Python errors (e.g.
// IndexError, ValueError). If retaining this information is important
// to us, we'll have to change this up a little.
auto stackTrace = node_->getMetadata("stack_trace");
throw std::runtime_error(fmt::format(
"Original Python stacktrace:\n{}\n{}",
stackTrace ? *stackTrace : "<no stack trace>",
ex.what()));
}
const auto& outputValues = node_->outputs();
for (int i = 0; i < numOutputs_; ++i) {
executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
}
// Copy over mutating inputs to outputs
int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_;
for (int i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
executionFrame.setIValue(
outputValues[i]->id(),
executionFrame.getIValue(
mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
true /* allowNone */));
}
}
} // namespace torch::nativert
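A standalone sketch (plain C++; Slot and classifySlot are illustrative, not part of this file) of the output-slot layout this kernel assumes: the first slots carry the op's returns, the remaining slots carry the post-mutation values of the mutating inputs, and slot 0 is skipped when the op returns nothing.
#include <cassert>
#include <cstddef>
struct Slot {
  bool isReturn;
  std::size_t index; // index into the op's returns or into the mutating inputs
};
// Mirrors the index math in computeInternal above.
Slot classifySlot(std::size_t slot, std::size_t numReturns) {
  const std::size_t mutStart = (numReturns == 0) ? 1 : numReturns;
  if (slot < numReturns) {
    return {true, slot};
  }
  assert(slot >= mutStart);
  return {false, slot - mutStart};
}
int main() {
  // Two returns: slot 3 is the second mutated input.
  assert(!classifySlot(3, 2).isReturn && classifySlot(3, 2).index == 1);
  // No returns: slot 0 is unused and mutated inputs start at slot 1.
  assert(!classifySlot(1, 0).isReturn && classifySlot(1, 0).index == 0);
  return 0;
}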

View File

@ -0,0 +1,27 @@
#pragma once
#include <torch/script.h>
#include "c10/core/Device.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
class UnsafeAutoFunctionalizeKernel : public OpKernel {
public:
UnsafeAutoFunctionalizeKernel() = delete; // deleted default constructor
UnsafeAutoFunctionalizeKernel(const Node* node);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
c10::OperatorHandle op_;
c10::FunctionSchema schema_;
Arguments arguments_;
std::vector<Value*> mutatingInputArgs_;
int numOutputs_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,242 @@
#include "torch/csrc/nativert/kernels/C10Kernel.h"
#include <c10/util/Logging.h>
#include <fmt/ostream.h>
#include "torch/csrc/nativert/common/Enumerate.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#ifdef __SIGRID_USE_GPU__
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#endif
namespace torch::nativert {
C10Kernel::C10Kernel(const Node* node, c10::Device device)
: OpKernel(node, device),
op_(getOperatorForTarget(node->target(), node)),
schema_(op_.schema()),
arguments_(prefillStackWithStaticArgs(node, schema_)) {}
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
// Make a copy of the stack
std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
fillDynamicInputs(executionFrame, arguments_, stack);
// Call the op with the prepared stack.
try {
op_.callBoxed(stack);
} catch (const std::exception& ex) {
auto stackTrace = node_->getMetadata("stack_trace");
throw std::runtime_error(fmt::format(
"Exception while executing node: {}\n"
"with args:\n{}\n"
"{}\n"
"Original Python stacktrace:\n{}",
fmt::streamed(*node_),
readableArgs(schema_, stack),
ex.what(),
stackTrace ? *stackTrace : "<no stack trace>"));
}
// Write out results
// TODO: we store intermediates in a single table (symint and tensor alike).
// This can theoretically lead to name collisions, although based on how
// these are named I don't think it will ever happen in practice. We need to
// enforce it though.
const auto& outputValues = node_->outputs();
TORCH_CHECK_EQ(outputValues.size(), stack.size())
<< "Output size mismatch for " << node_->toString();
for (auto&& [i, actualOutput] : enumerate(stack)) {
executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput));
}
}
namespace {
std::unordered_map<std::string, c10::IValue> getSymInputs(
const ExecutionFrame& executionFrame,
const Node& node) {
std::unordered_map<std::string, c10::IValue> inputs;
for (const auto& input : node.inputs()) {
const auto& val = executionFrame.getIValue(input.value->id());
if (val.isInt() || val.isDouble() || val.isBool()) {
inputs[input.name] = val;
} else {
throw std::runtime_error("unsupported type for symbolic input");
}
}
for (const auto& attribute : node.attributes()) {
if (std::holds_alternative<int64_t>(attribute.value)) {
inputs[attribute.name] = std::get<int64_t>(attribute.value);
} else if (std::holds_alternative<double>(attribute.value)) {
inputs[attribute.name] = std::get<double>(attribute.value);
} else if (std::holds_alternative<bool>(attribute.value)) {
inputs[attribute.name] = std::get<bool>(attribute.value);
} else {
throw std::runtime_error("unsupported type for symbolic input");
}
}
return inputs;
}
template <typename T>
void computeScalarBinaryOp(
ExecutionFrame& executionFrame,
const Node& node,
std::enable_if_t<true, T> a,
std::enable_if_t<true, T> b) {
std::string_view target = node.target();
T out;
if (target == "_operator.add") {
out = a + b;
} else if (target == "_operator.sub") {
out = a - b;
} else if (target == "_operator.mul") {
out = a * b;
} else if (target == "_operator.pow") {
out = std::pow(a, b);
} else {
throw std::runtime_error(
fmt::format("unsupported operator for symbolic values: {}", target));
}
executionFrame.setIValue(node.outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, b={}, out={}",
fmt::streamed(node),
a,
b,
out);
}
} // namespace
void ScalarBinaryOpKernel::computeInternal(
ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
const auto& a = inputs.at("a");
const auto& b = inputs.at("b");
auto coerceToDouble = [](const c10::IValue& x) -> double {
if (x.isInt()) {
return static_cast<double>(x.toInt());
} else if (x.isDouble()) {
return x.toDouble();
} else {
throw std::runtime_error("unsupported type for symbolic input");
}
};
if (a.isInt() && b.isInt()) {
computeScalarBinaryOp<int64_t>(
executionFrame, *node_, a.toInt(), b.toInt());
} else {
computeScalarBinaryOp<double>(
executionFrame, *node_, coerceToDouble(a), coerceToDouble(b));
}
}
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
int64_t a = inputs.at("a").toInt();
std::string_view target = node_->target();
if (target == "torch.sym_float") {
double out = static_cast<double>(a);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, out={}",
fmt::streamed(*node_),
a,
out);
return;
}
int64_t b = inputs.at("b").toInt();
int64_t out;
if (target == "_operator.floordiv") {
out = a / b;
} else if (target == "_operator.mod") {
out = a % b;
} else if (target == "torch.sym_max") {
out = std::max(a, b);
} else if (target == "torch.sym_min") {
out = std::min(a, b);
} else {
throw std::runtime_error(
fmt::format("unsupported operator for SymInt: {}", node_->target()));
}
executionFrame.setIValue(node_->outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, b={}, out={}",
fmt::streamed(*node_),
a,
b,
out);
}
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
bool out;
const std::string_view target = node_->target();
if (target == "torch.sym_not") {
bool a = inputs.at("a").toBool();
out = !a;
} else if (target == "_operator.ge") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a >= b;
} else if (target == "_operator.le") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a <= b;
} else if (target == "_operator.eq") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a == b;
} else if (target == "_operator.gt") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a > b;
} else if (target == "_operator.lt") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a < b;
} else if (target == "_operator.and_") {
bool a = inputs.at("a").toBool();
bool b = inputs.at("b").toBool();
out = a && b;
} else {
throw std::runtime_error(
fmt::format("unsupported operator for SymBool: {}", node_->target()));
}
executionFrame.setIValue(node_->outputs()[0]->id(), out);
}
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
const std::string_view target = node_->target();
if (target == "math.trunc") {
double x = inputs.at("x").toDouble();
int64_t out = trunc(x);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else if (target == "torch._sym_sqrt") {
double a = inputs.at("a").toDouble();
double out = std::sqrt(a);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else {
throw std::runtime_error(
fmt::format("unsupported operator for SymFloat: {}", node_->target()));
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,73 @@
#pragma once
#include <torch/script.h>
#include "c10/core/Device.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
// Implementation of Kernel for ATen operators
//
// This class exists to amortize per-kernel overhead by computing things during
// initialization instead of on every execution. Right now we are only
// amortizing schema resolution and static argument parsing,
// but in the future this could be extended to avoid operator dispatch and
// do better "Register" allocation (e.g. converting inputs/outputs directly
// into array accesses onto a set of registers, in concert with memory planning).
class C10Kernel : public OpKernel {
public:
C10Kernel() = delete; // deleted default constructor
C10Kernel(const Node* node, c10::Device device);
virtual ~C10Kernel() = default;
[[nodiscard]] const c10::IValue& input(
uint32_t i,
ExecutionFrame& executionFrame) const override {
if (Value* dynamicArg = arguments_.findDynamic(i)) {
return executionFrame.getIValue(dynamicArg->id());
} else {
return arguments_.getStatic(i);
}
}
void computeInternal(ExecutionFrame& executionFrame) const override;
private:
c10::OperatorHandle op_;
c10::FunctionSchema schema_;
Arguments arguments_;
};
class SymIntOpKernel : public OpKernel {
public:
explicit SymIntOpKernel(const Node* node) : OpKernel(node) {}
void computeInternal(ExecutionFrame& executionFrame) const override final;
};
class SymBoolOpKernel : public OpKernel {
public:
explicit SymBoolOpKernel(const Node* node) : OpKernel(node) {}
void computeInternal(ExecutionFrame& executionFrame) const override final;
};
class SymFloatOpKernel : public OpKernel {
public:
explicit SymFloatOpKernel(const Node* node) : OpKernel(node) {}
void computeInternal(ExecutionFrame& executionFrame) const override final;
};
// ScalarBinaryOpKernel does binary arithmetic operations on scalar values.
// Integers and floats are supported as input types. The output will be
// promoted to float if and only if there's at least one float input.
class ScalarBinaryOpKernel : public OpKernel {
public:
explicit ScalarBinaryOpKernel(const Node* node) : OpKernel(node) {}
void computeInternal(ExecutionFrame& executionFrame) const override final;
};
} // namespace torch::nativert
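A standalone sketch of the promotion rule stated above for ScalarBinaryOpKernel: the result stays an integer only when both operands are integers, otherwise both are coerced to double. Scalar and scalarAdd are illustrative, not part of this header.
#include <cassert>
#include <cstdint>
#include <variant>
using Scalar = std::variant<int64_t, double>;
Scalar scalarAdd(const Scalar& a, const Scalar& b) {
  // Integer arithmetic only when both operands are integers.
  if (std::holds_alternative<int64_t>(a) && std::holds_alternative<int64_t>(b)) {
    return std::get<int64_t>(a) + std::get<int64_t>(b);
  }
  // Otherwise promote both operands to double, as the kernel does.
  auto toDouble = [](const Scalar& x) {
    return std::holds_alternative<int64_t>(x)
        ? static_cast<double>(std::get<int64_t>(x))
        : std::get<double>(x);
  };
  return toDouble(a) + toDouble(b);
}
int main() {
  assert(std::holds_alternative<int64_t>(scalarAdd(int64_t{2}, int64_t{3})));
  assert(std::get<double>(scalarAdd(int64_t{2}, 0.5)) == 2.5);
  return 0;
}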

View File

@ -0,0 +1,48 @@
#include "torch/csrc/nativert/kernels/CallTorchBindKernel.h"
#include "torch/csrc/nativert/common/Enumerate.h"
namespace torch::nativert {
CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) {
const Value* customObjValue = node_->inputs()[0].value;
CHECK(customObjValue->type() == Type::CustomObj);
customClassName_ = customObjValue->type().classFqn();
customClassType_ = torch::jit::getCustomClass(customClassName_);
// sample schema
// torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1);
CHECK(node->attributes().size() == 1)
<< "Expects higher_order.call_torchbind to only have a single attribute, methodName";
const auto& attr = node->attributes()[0];
CHECK(std::holds_alternative<std::string>(attr.value))
<< "method should be a string";
methodName_ = std::get<std::string>(attr.value);
method_ = customClassType_->findMethod(methodName_);
CHECK(method_ != nullptr) << "method not found: " << methodName_;
}
void CallTorchBindKernel::computeInternal(
ExecutionFrame& executionFrame) const {
// prepare inputs
std::vector<c10::IValue> stack;
for (const auto& input : node_->inputs()) {
const auto& id = input.value->id();
stack.emplace_back(executionFrame.getIValue(id));
}
// call the method
method_->run(stack);
// set outputs
const auto& outputs = node_->outputs();
TORCH_CHECK_EQ(outputs.size(), stack.size());
for (auto&& [i, outputValue] : enumerate(stack)) {
executionFrame.setIValue(outputs[i]->id(), std::move(outputValue));
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,25 @@
#pragma once
#include <torch/script.h>
#include "c10/core/Device.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
class CallTorchBindKernel : public OpKernel {
public:
CallTorchBindKernel() = delete; // deleted default constructor
CallTorchBindKernel(const Node* node);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
std::string methodName_;
torch::jit::Function* method_;
std::string customClassName_;
at::ClassTypePtr customClassType_;
};
} // namespace torch::nativert

File diff suppressed because it is too large

View File

@ -0,0 +1,116 @@
#include "torch/csrc/nativert/kernels/HigherOrderKernel.h"
#include <fmt/format.h>
#include "torch/csrc/nativert/common/String.h"
#include "torch/csrc/nativert/executor/ExecutionFrame.h"
namespace torch::nativert {
HigherOrderKernel::HigherOrderKernel(
const Node* node,
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors)
: OpKernel(node), graphExecutors_(std::move(graphExecutors)) {
static constexpr std::string_view prefix = "torch.ops.higher_order.";
CHECK(starts_with(node->target(), prefix));
auto opName = node->target().substr(prefix.size());
if (opName == "cond") {
opType_ = OpType::COND;
// Checking torch.cond schema is as expected:
// torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args)
// -> Tensor[]
TORCH_CHECK_EQ(node_->attributes().size(), 2);
TORCH_CHECK_EQ(node_->inputs().size(), 2);
} else if (opName == "while_loop") {
opType_ = OpType::WHILE_LOOP;
// Checking torch.while_loop schema is as expected:
    // torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[]
    // additional) -> Tensor[]
TORCH_CHECK_EQ(node_->attributes().size(), 2);
TORCH_CHECK_EQ(node_->inputs().size(), 2);
} else if (opName == "run_const_graph") {
opType_ = OpType::RUN_CONST_GRAPH;
// Checking torch.run_const_graph schema is as expected:
// torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[]
TORCH_CHECK_GE(node_->attributes().size(), 1);
TORCH_CHECK_EQ(node_->inputs().size(), 1);
} else {
throw std::runtime_error(
fmt::format("Unknown higher order op: {}", opName));
}
}
void HigherOrderKernel::computeInternal(ExecutionFrame& executionFrame) const {
switch (opType_) {
case OpType::COND: {
auto inputs = executionFrame.getIValue(node_->inputs()[1].value->id())
.toList()
.vec();
std::vector<c10::IValue> outputs;
auto cond = executionFrame.getIValue(node_->inputs()[0].value->id());
size_t branchIdx = 0;
if (cond.isTensor()) {
branchIdx = cond.toTensor().item().toBool() ? 0 : 1;
} else if (cond.isBool()) {
branchIdx = cond.toBool() ? 0 : 1;
} else {
throw std::runtime_error("Unsupported type for cond predicate");
}
ExecutionFrame branchFrame(*std::get<std::unique_ptr<Graph>>(
node_->attributes()[branchIdx].value));
graphExecutors_[branchIdx]->execute(branchFrame, std::move(inputs));
auto ret = branchFrame.getUserOutputs();
for (size_t i = 0; i < ret.size(); i++) {
executionFrame.setIValue(node_->outputs()[i]->id(), std::move(ret[i]));
}
break;
}
case OpType::WHILE_LOOP: {
auto carriedVals =
executionFrame.getIValue(node_->inputs()[0].value->id())
.toList()
.vec();
      auto additionalVals =
executionFrame.getIValue(node_->inputs()[1].value->id())
.toList()
.vec();
size_t numCarriedVals = carriedVals.size();
ExecutionFrame condFrame(
*std::get<std::unique_ptr<Graph>>(node_->attributes()[0].value));
ExecutionFrame bodyFrame(
*std::get<std::unique_ptr<Graph>>(node_->attributes()[1].value));
while (true) {
auto inputs = carriedVals;
        inputs.insert(inputs.end(), additionalVals.begin(), additionalVals.end());
graphExecutors_[0]->execute(condFrame, inputs);
auto cond = condFrame.getUserOutputs();
if (cond.at(0).isTensor() && !cond[0].toTensor().item().toBool()) {
break;
}
if (cond.at(0).isBool() && !cond[0].toBool()) {
break;
}
graphExecutors_[1]->execute(bodyFrame, std::move(inputs));
auto out = bodyFrame.getUserOutputs();
TORCH_CHECK(out.size() == numCarriedVals);
carriedVals = std::move(out);
}
for (size_t i = 0; i < carriedVals.size(); i++) {
executionFrame.setIValue(
node_->outputs()[i]->id(), std::move(carriedVals[i]));
}
break;
}
case OpType::RUN_CONST_GRAPH: {
      // The run_const_graph op is a special case of a higher-order op that has
      // already been executed during weights loading, so at runtime it is a
      // no-op.
break;
}
default:
TORCH_CHECK(false, "Unknown higher order op");
}
}
} // namespace torch::nativert

View File

@ -0,0 +1,28 @@
#pragma once
#include "c10/core/Device.h"
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
#include "torch/csrc/nativert/graph/Graph.h"
namespace torch::nativert {
class HigherOrderKernel : public OpKernel {
enum class OpType {
UNKNOWN,
COND,
WHILE_LOOP,
RUN_CONST_GRAPH,
};
public:
HigherOrderKernel(
const Node* node,
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors);
void computeInternal(ExecutionFrame& executionFrame) const override final;
private:
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors_;
OpType opType_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,288 @@
#include "torch/csrc/nativert/kernels/KernelFactory.h"
#include <string_view>
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include "torch/csrc/nativert/executor/AOTIDelegateExecutor.h"
#include "torch/csrc/nativert/kernels/AOTICallDelegateKernel.h"
#include "torch/csrc/nativert/kernels/AOTIKernel.h"
#ifdef __SIGRID_USE_FBA__
#include "sigmoid/backend/MTIADelegateExecutor.h"
#include "sigmoid/backend/MTIAKernel.h"
#endif
#include "torch/csrc/nativert/common/String.h"
#include "torch/csrc/nativert/executor/DelegateExecutor.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
#include "torch/csrc/nativert/executor/ParallelGraphExecutor.h"
#include "torch/csrc/nativert/executor/SerialGraphExecutor.h"
#include "torch/csrc/nativert/graph/Graph.h"
#include "torch/csrc/nativert/graph/GraphPasses.h"
#include "torch/csrc/nativert/kernels/AutoFunctionalizeKernel.h"
#include "torch/csrc/nativert/kernels/C10Kernel.h"
#include "torch/csrc/nativert/kernels/CallTorchBindKernel.h"
#include "torch/csrc/nativert/kernels/HigherOrderKernel.h"
#include "torch/csrc/nativert/kernels/KernelRegistry.h"
namespace torch::nativert {
namespace {
c10::Device inferTargetDevice(
const Node& node,
const std::unordered_map<std::string, TensorMeta>& tensorValuesMeta,
const Placement& placement) {
if (node.target() == "prim.Input" || node.target() == "prim.Output") {
return c10::Device(c10::DeviceType::CPU);
}
std::vector<c10::Device> devices;
for (auto& output : node.outputs()) {
if (output->type() == Type::Tensor) {
auto it = tensorValuesMeta.find(std::string{output->name()});
if (it != tensorValuesMeta.end()) {
devices.emplace_back(it->second.device());
}
} else if (output->type() == Type::TensorList) {
for (const auto& el : output->getListElements()) {
auto it = tensorValuesMeta.find(std::string{el->name()});
if (it != tensorValuesMeta.end()) {
devices.emplace_back(it->second.device());
}
}
}
}
if (devices.empty()) {
return c10::Device(c10::DeviceType::CPU);
} else {
for (size_t i = 1; i < devices.size(); ++i) {
if (!isSameDevice(devices[0], devices[i])) {
LOG(WARNING) << "Node " << node
<< " has outputs on multiple devices: " << devices[0]
<< " and " << devices[i];
}
}
return placement.getMappedDevice(devices[0]);
}
}
} // namespace
inline constexpr std::string_view kSymIntOps[] = {
"_operator.floordiv",
"_operator.mod",
"torch.sym_int",
"torch.sym_float",
"torch.sym_ite",
"torch.sym_max",
"torch.sym_min",
};
inline constexpr std::string_view kSymBoolOps[] = {
"_operator.eq",
"_operator.ne",
"_operator.le",
"_operator.ge",
"_operator.lt",
"_operator.gt",
"_operator.and_",
"torch.sym_not",
};
inline constexpr std::string_view kSymFloatOps[] = {
"torch._sym_sqrt",
"math.trunc",
};
inline constexpr std::string_view kScalarBinaryOps[] = {
"_operator.mul",
"_operator.add",
"_operator.sub",
"_operator.pow",
};
namespace {
const std::string maybeRevisedStaticDispatchTarget(const Node& node) {
auto overloadName = selectScalarOverloadName(node);
if (!ends_with(node.target(), overloadName)) {
const std::string& newTarget =
std::string(node.target())
.replace(node.target().rfind('.'), std::string::npos, overloadName);
LOG(INFO) << fmt::format(
"Converting Tensor to {} for node: {} -> {}",
overloadName,
node.target(),
newTarget);
return newTarget;
}
return std::string(node.target());
}
} // namespace
ExecutionKernels KernelFactory::initializeNodeKernels(
const Graph& graph,
std::shared_ptr<Weights> weights,
const ExecutorConfig& executorConfig,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader) {
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
std::vector<ConstFoldingExecution> constFoldingExecutions;
std::unordered_set<std::string> opsWithoutStaticDispatch;
VLOG(1) << "PrimKernelRegistry: " << join(", ", PrimKernelRegistry()->Keys());
VLOG(1) << "StaticallyDispatchedCPUKernelRegistry: "
<< join(", ", StaticallyDispatchedCPUKernelRegistry()->Keys());
for (const auto& node : graph.nodes()) {
std::string originalTarget = std::string(node.target());
const std::string target =
(executorConfig.enableStaticCPUKernels &&
StaticallyDispatchedCPUKernelRegistry()->Has(originalTarget))
? maybeRevisedStaticDispatchTarget(node)
: std::move(originalTarget);
c10::Device targetDevice =
inferTargetDevice(node, graph.tensorValuesMeta(), placement);
if (PrimKernelRegistry()->Has(target)) {
nodeKernels.push_back(PrimKernelRegistry()->Create(target, &node));
} else if (
executorConfig.enableStaticCPUKernels &&
StaticallyDispatchedCPUKernelRegistry()->Has(target) &&
targetDevice.is_cpu()) {
nodeKernels.push_back(StaticallyDispatchedCPUKernelRegistry()->Create(
target, &node, targetDevice));
} else if (starts_with(
node.target(), "torch.ops.higher_order.call_torchbind")) {
nodeKernels.push_back(std::make_unique<CallTorchBindKernel>(&node));
} else if (
starts_with(
node.target(),
"torch.ops.higher_order.auto_functionalized") ||
starts_with( // TODO Remove this condition once the old
// pt2 archives are expired.
node.target(),
"torch._higher_order_ops.auto_functionalize.auto_functionalized")) {
nodeKernels.push_back(
std::make_unique<UnsafeAutoFunctionalizeKernel>(&node));
} else if (
std::find(
std::begin(kSymIntOps), std::end(kSymIntOps), node.target()) !=
std::end(kSymIntOps)) {
nodeKernels.push_back(std::make_unique<SymIntOpKernel>(&node));
} else if (
std::find(
std::begin(kSymBoolOps), std::end(kSymBoolOps), node.target()) !=
std::end(kSymBoolOps)) {
nodeKernels.push_back(std::make_unique<SymBoolOpKernel>(&node));
} else if (
std::find(
std::begin(kSymFloatOps), std::end(kSymFloatOps), node.target()) !=
std::end(kSymFloatOps)) {
nodeKernels.push_back(std::make_unique<SymFloatOpKernel>(&node));
} else if (
std::find(
std::begin(kScalarBinaryOps),
std::end(kScalarBinaryOps),
node.target()) != std::end(kScalarBinaryOps)) {
nodeKernels.push_back(std::make_unique<ScalarBinaryOpKernel>(&node));
} else if (starts_with(
node.target(), "torch.ops.delegate.call_aotinductor")) {
const auto pathAttr = node.tryGetAttribute("path");
CHECK(pathAttr != nullptr);
const Constant& pathValue = pathAttr->value;
CHECK(std::holds_alternative<std::string>(pathValue));
std::string path = std::get<std::string>(pathValue);
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
path, weights, targetDevice, executorConfig, pytorchStreamReader);
nodeKernels.push_back(
std::make_unique<AOTIKernel>(&node, *delegateExecutor));
delegateExecutors.push_back(std::move(delegateExecutor));
} else if (starts_with(
node.target(),
"torch.ops.higher_order.aoti_call_delegate")) {
      // The first attribute is serialized as the path to the AOTInductor artifact.
const auto pathAttr = node.attributes().begin();
const Constant& pathValue = pathAttr->value;
CHECK(std::holds_alternative<std::string>(pathValue));
std::string path = std::get<std::string>(pathValue);
auto delegateExecutor = std::make_unique<AOTIDelegateExecutor>(
path, weights, targetDevice, executorConfig, pytorchStreamReader);
nodeKernels.push_back(
std::make_unique<AOTICallDelegateKernel>(&node, *delegateExecutor));
delegateExecutors.push_back(std::move(delegateExecutor));
} else if (starts_with(node.target(), "torch.ops.delegate.call_mtia")) {
#ifdef __SIGRID_USE_FBA__
auto delegateExecutor = std::make_unique<MTIADelegateExecutor>(
&node, weights, executorConfig, pytorchStreamReader);
nodeKernels.push_back(
std::make_unique<MTIAKernel>(&node, *delegateExecutor));
delegateExecutors.push_back(std::move(delegateExecutor));
#endif
} else if (starts_with(node.target(), "torch.ops.higher_order")) {
std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors;
for (const auto& attr : node.attributes()) {
if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
auto executionKernels = initializeNodeKernels(
*subgraph, weights, executorConfig, placement);
CHECK(executionKernels.delegateExecutors.empty())
<< "HigherOrderKernel does not support delegates";
CHECK(executionKernels.constFoldingExecutions.size() == 0)
<< "HigherOrderKernel does not support const folding";
if (executorConfig.maxParallelOps > 1) {
graphExecutors.emplace_back(
std::unique_ptr<GraphExecutorBase>(new ParallelGraphExecutor(
*subgraph,
std::move(executionKernels.nodeKernels),
executorConfig)));
} else {
graphExecutors.emplace_back(
std::unique_ptr<GraphExecutorBase>(new SerialGraphExecutor(
*subgraph,
std::move(executionKernels.nodeKernels),
executorConfig)));
}
}
}
if (node.target() == "torch.ops.higher_order.run_const_graph") {
constFoldingExecutions.push_back(
ConstFoldingExecution{std::move(graphExecutors[0])});
}
nodeKernels.push_back(std::make_unique<HigherOrderKernel>(
&node, std::move(graphExecutors)));
} else if (starts_with(node.target(), "torch.ops")) {
nodeKernels.push_back(std::make_unique<C10Kernel>(&node, targetDevice));
opsWithoutStaticDispatch.insert(std::string(node.target()));
} else {
TORCH_CHECK(false, "Unsupported operator: ", target);
}
}
if (executorConfig.enableStaticCPUKernels) {
std::stringstream ss;
for (const auto& op : opsWithoutStaticDispatch) {
ss << op << ", ";
}
    LOG(WARNING) << "The following ops are missing statically dispatched kernels: "
<< ss.str();
}
return {
std::move(nodeKernels),
std::move(delegateExecutors),
std::move(constFoldingExecutions)};
}
} // namespace torch::nativert

View File

@ -0,0 +1,37 @@
#pragma once
#include <memory>
#include <torch/script.h>
#include "torch/csrc/nativert/executor/ExecutorConfig.h"
#include "torch/csrc/nativert/executor/GraphExecutorBase.h"
#include "torch/csrc/nativert/executor/OpKernel.h"
namespace torch::nativert {
class DelegateExecutor;
struct ConstFoldingExecution {
std::unique_ptr<GraphExecutorBase> executor;
};
struct ExecutionKernels {
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
std::vector<ConstFoldingExecution> constFoldingExecutions;
};
class KernelFactory {
public:
explicit KernelFactory() {}
ExecutionKernels initializeNodeKernels(
const Graph& graph,
std::shared_ptr<Weights> weights,
const ExecutorConfig& executorConfig,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
pytorchStreamReader = nullptr);
};
} // namespace torch::nativert

File diff suppressed because it is too large

View File

@ -0,0 +1,117 @@
#pragma once
#include <torch/script.h>
#include "torch/csrc/nativert/executor/OpKernel.h"
#include "torch/csrc/nativert/graph/Graph.h"
#include "torch/csrc/nativert/kernels/C10Kernel.h"
namespace torch::nativert {
#define KernelInput(id) input(id, executionFrame)
#define KernelOutput(id) output(id, executionFrame)
TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);
#define REGISTER_PRIM_KERNEL(name, id, ...) \
class OpKernel_##id : public OpKernel { \
public: \
OpKernel_##id(const Node* node) : OpKernel(node) { \
kind_ = OpKernel::Kind::kPrimKernel; \
} \
void computeInternal( \
ExecutionFrame& executionFrame) const override final { \
__VA_ARGS__; \
} \
}; \
C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id);
TORCH_DECLARE_REGISTRY(
StaticallyDispatchedCPUKernelRegistry,
OpKernel,
const Node*,
c10::Device);
#define REGISTER_CPU_KERNEL(name, id, ...) \
class OpKernel_##id : public C10Kernel { \
public: \
OpKernel_##id(const Node* node, c10::Device device) \
: C10Kernel(node, device) { \
kind_ = OpKernel::Kind::kStaticDispatchKernel; \
} \
void computeInternal( \
ExecutionFrame& executionFrame) const override final { \
__VA_ARGS__; \
} \
}; \
C10_REGISTER_TYPED_CLASS( \
StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id);
inline bool checkResizedDataPtr(at::Tensor& t) {
auto const prev_data_ptr = t.data_ptr();
t.resize_({0});
return prev_data_ptr == t.data_ptr();
}
inline void fastResizeToZero(at::Tensor& t) {
t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}
inline at::Tensor create_empty_from(const at::Tensor& t) {
return at::detail::empty_cpu(
{0},
c10::typeMetaToScalarType(t.dtype()),
t.layout(),
t.device(),
std::nullopt,
std::nullopt);
}
inline at::Tensor create_empty_from(
const at::Tensor& t,
c10::ScalarType dtype) {
return at::detail::empty_cpu(
{0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
}
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) {
return at::detail::empty_cpu(
{0},
c10::typeMetaToScalarType(t.dtype()),
t.layout(),
device,
std::nullopt,
std::nullopt);
}
inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) {
return at::detail::empty_cpu(
{0},
c10::typeMetaToScalarType(t.dtype()),
layout,
t.device(),
std::nullopt,
std::nullopt);
}
inline at::Tensor create_empty_from(
const at::Tensor& t,
c10::MemoryFormat memory_format) {
return at::detail::empty_cpu(
{0},
c10::typeMetaToScalarType(t.dtype()),
t.layout(),
t.device(),
std::nullopt,
memory_format);
}
inline at::Tensor create_empty_from(
const at::Tensor& t,
c10::ScalarType dtype,
c10::MemoryFormat memory_format) {
return at::detail::empty_cpu(
{0}, dtype, t.layout(), t.device(), std::nullopt, memory_format);
}
} // namespace torch::nativert

View File

@ -0,0 +1,14 @@
#include "torch/csrc/nativert/kernels/KernelRegistry.h"
namespace torch::nativert {
REGISTER_CPU_KERNEL("torch.ops.aten.slice.Tensor", aten_slice_Tensor, {
const auto& self = KernelInput(0).toTensor();
const auto& dim = KernelInput(1).toInt();
const auto& start = KernelInput(2).toOptional<int64_t>();
const auto& end = KernelInput(3).toOptional<int64_t>();
const auto& step = KernelInput(4).toInt();
KernelOutput(0) = at::native::slice(self, dim, start, end, step);
});
} // namespace torch::nativert

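The statically dispatched kernel above implements the aten.slice.Tensor overload (self, dim, start, end, step). As a quick sanity sketch, assuming an ordinary CPU tensor, the same call is reachable from Python through torch.ops:

import torch

x = torch.arange(10)
# aten.slice.Tensor(self, dim, start, end, step); equivalent to x[1:7:2] here.
y = torch.ops.aten.slice.Tensor(x, 0, 1, 7, 2)
assert torch.equal(y, x[1:7:2])
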
View File

@ -0,0 +1,68 @@
#pragma once
#include <array>
#include <string_view>
namespace torch::nativert {
namespace archive_spec {
#define FORALL_COSTANTS(_) \
_(ARCHIVE_ROOT_NAME, "package") \
/* Archive format */ \
_(ARCHIVE_FORMAT_PATH, "archive_format") \
_(ARCHIVE_FORMAT_VALUE, "pt2") \
/* Archive version */ \
_(ARCHIVE_VERSION_PATH, "archive_version") \
_(ARCHIVE_VERSION_VALUE, \
"0") /* Sep.4.2024: This is the initial version of the PT2 Archive Spec */ \
/* \
* ######## Note on updating ARCHIVE_VERSION_VALUE ######## \
* When there is a BC breaking change to the PT2 Archive Spec, \
* e.g. deleting a folder, or changing the naming convention of the \
* following fields it would require bumping the ARCHIVE_VERSION_VALUE \
* Archive reader would need corresponding changes to support loading both \
* the current and older versions of the PT2 Archive. \
*/ \
/* Model definitions */ \
_(MODELS_DIR, "models/") \
_(MODELS_FILENAME_FORMAT, "models/{}.json") /* {model_name} */ \
/* AOTInductor artifacts */ \
_(AOTINDUCTOR_DIR, "data/aotinductor/") \
/* MTIA artifacts */ \
_(MTIA_DIR, "data/mtia") \
/* weights, including parameters and buffers */ \
_(WEIGHTS_DIR, "data/weights/") \
_(WEIGHT_FILENAME_PREFIX, "weight_") \
/* constants, including tensor_constants, non-persistent buffers and script \
* objects */ \
_(CONSTANTS_DIR, "data/constants/") \
_(TENSOR_CONSTANT_FILENAME_PREFIX, "tensor_") \
_(CUSTOM_OBJ_FILENAME_PREFIX, "custom_obj_") \
/* example inputs */ \
_(SAMPLE_INPUTS_DIR, "data/sample_inputs/") \
_(SAMPLE_INPUTS_FILENAME_FORMAT, \
"data/sample_inputs/{}.pt") /* {model_name} */ \
/* extra folder */ \
_(EXTRA_DIR, "extra/") \
_(MODULE_INFO_PATH, "extra/module_info.json") \
    /* xl_model_weights: this folder stores per-feature weights for remote \
     * nets; data in this folder is consumed by Predictor and is not \
     * intended to be used by Sigmoid */ \
_(XL_MODEL_WEIGHTS_DIR, "xl_model_weights/") \
_(XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH, "xl_model_weights/model_param_config")
#define DEFINE_GLOBAL(NAME, VALUE) \
inline constexpr std::string_view NAME = VALUE;
#define DEFINE_ENTRY(NAME, VALUE) std::pair(#NAME, VALUE),
FORALL_COSTANTS(DEFINE_GLOBAL)
inline constexpr std::array kAllConstants{FORALL_COSTANTS(DEFINE_ENTRY)};
#undef DEFINE_ENTRY
#undef DEFINE_GLOBAL
#undef FORALL_COSTANTS
} // namespace archive_spec
} // namespace torch::nativert

View File

@ -0,0 +1,18 @@
#include <pybind11/pybind11.h>
#include "torch/csrc/nativert/package/pt2_archive_constants.h"
namespace torch::nativert {
void initPt2ArchiveConstantsPybind(pybind11::module& m) {
for (const auto& entry : torch::nativert::archive_spec::kAllConstants) {
m.attr(entry.first) = entry.second;
}
}
} // namespace torch::nativert
// TODO Remove this once we fully migrate to OSS build.
#ifdef FBCODE_CAFFE2
PYBIND11_MODULE(pt2_archive_constants_pybind, m) {
torch::nativert::initPt2ArchiveConstantsPybind(m);
}
#endif

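Once this binding is compiled, every entry of kAllConstants becomes an attribute of the module. A minimal sketch of reading a few of them from Python; the import path is an assumption taken from the packaging code later in this diff and may differ in other builds:

# Import path assumed from the packaging code in this diff; adjust if the
# binding is exposed under a different module name.
from torch._C.nativert.pt2_archive_constants import (
    ARCHIVE_FORMAT_PATH,
    ARCHIVE_FORMAT_VALUE,
    WEIGHTS_DIR,
)

print(ARCHIVE_FORMAT_PATH)   # "archive_format"
print(ARCHIVE_FORMAT_VALUE)  # "pt2"
print(WEIGHTS_DIR)           # "data/weights/"
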
View File

@ -0,0 +1,6 @@
#pragma once
#include <pybind11/pybind11.h>
namespace torch::nativert {
void initPt2ArchiveConstantsPybind(pybind11::module& m);
} // namespace torch::nativert

View File

@ -0,0 +1,225 @@
import ctypes
import io
import logging
import os
from typing import Any, Dict, Optional, Tuple
###############################################################################
#
# This file contains the code to package a model for Sigmoid in open source.
# Please do not introduce fbcode dependencies here. (e.g. aiplatform, fbgemm, thrift)
#
###############################################################################
import torch
from torch.export.experimental.package.pt2_archive import PT2ArchiveWriter
from torch._C.nativert.pt2_archive_constants import ( # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
CONSTANTS_DIR,
CUSTOM_OBJ_FILENAME_PREFIX,
MODELS_FILENAME_FORMAT,
SAMPLE_INPUTS_FILENAME_FORMAT,
TENSOR_CONSTANT_FILENAME_PREFIX,
WEIGHT_FILENAME_PREFIX,
WEIGHTS_DIR,
)
from torch._export.serde.schema import Model, Program
from torch._export.serde.serialize import (
_enable_graph_inputs_of_type_nn_module,
_to_json_bytes,
ExportedProgramSerializer,
)
from torch.export import ExportedProgram
from torch.utils import _pytree as pytree
logger = logging.getLogger(__name__)
def get_raw_tensor_bytes(value: torch.Tensor) -> bytes:
    # NOTE: don't chain .cpu() with .data_ptr(). If a DtoH copy needs to be
    # performed, the CPU copy needs to be kept alive while its underlying
    # memory is accessed.
if value.data_ptr():
cpu_tensor = value.cpu().contiguous()
# we store the raw bytes of tensor. Tensor metadata is stored separately
value_bytes = bytes(
ctypes.cast(
cpu_tensor.data_ptr(),
ctypes.POINTER(ctypes.c_ubyte * value.element_size() * value.numel()),
).contents
)
else:
# for empty tensor
value_bytes = bytes()
return value_bytes
def _package_state_dict(
exported_program: ExportedProgram,
zip_file: PT2ArchiveWriter,
) -> Dict[str, str]:
idx = zip_file.count_prefix(os.path.join(WEIGHTS_DIR, WEIGHT_FILENAME_PREFIX))
qual_name_to_id = {} # Map from tensor name to its name in xl_weights folder
for name, tensor in exported_program.state_dict.items():
if tensor.is_meta:
logger.error(
f"Skipping state_dict packing of {name} since it's a meta tensor"
)
continue
param_name = f"{WEIGHT_FILENAME_PREFIX}{idx}"
idx += 1
qual_name_to_id[name] = param_name
archive_path = os.path.join(WEIGHTS_DIR, param_name)
tensor_bytes = get_raw_tensor_bytes(tensor)
zip_file.write_bytes(archive_path, tensor_bytes)
return qual_name_to_id
def _package_constants(
exported_program: ExportedProgram,
zip_file: PT2ArchiveWriter,
) -> Dict[str, Any]:
tensor_idx = zip_file.count_prefix(
os.path.join(CONSTANTS_DIR, TENSOR_CONSTANT_FILENAME_PREFIX)
)
custom_obj_idx = zip_file.count_prefix(
os.path.join(CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX)
)
qual_name_to_id = {} # Map from constant name to its name in constants folder
for name, constant in exported_program.constants.items():
if isinstance(constant, torch.Tensor):
# Save the constant tensors the same way we save weights
tensor_name = f"{TENSOR_CONSTANT_FILENAME_PREFIX}{tensor_idx}"
tensor_idx += 1
qual_name_to_id[name] = tensor_name
archive_path = os.path.join(CONSTANTS_DIR, tensor_name)
tensor_bytes = get_raw_tensor_bytes(constant)
zip_file.write_bytes(archive_path, tensor_bytes)
elif isinstance(constant, torch._C.ScriptObject):
# CustomClassHolder objects implement their own pickle saving
# functions.
logger.info(f"saving script object {name}")
custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"
custom_obj_idx += 1
qual_name_to_id[name] = custom_obj_name
archive_path = os.path.join(CONSTANTS_DIR, custom_obj_name)
custom_obj_bytes = torch._C._pickle_save(constant)
zip_file.write_bytes(archive_path, custom_obj_bytes)
else:
raise RuntimeError(f"Serializing constant type {type(constant)} nyi")
return qual_name_to_id
# `sample_inputs` will be pytree-flattened into a python list and saved via `torch.save()`
# in the zip archive as "data/sample_inputs/<model_name>.pt".
# In C++, this can be loaded via `torch::pickle_load`.
# See sigmoid::ModelRunner::loadSampleInputs() for more details.
def _package_sample_inputs(
sample_args: Tuple[pytree.PyTree, ...],
sample_kwargs: Dict[str, pytree.PyTree],
zip_file: PT2ArchiveWriter,
model_name: str,
) -> str:
sample_inputs_path = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name)
buffer = io.BytesIO()
# Convert torch.nn.Parameter to torch.Tensor
# This is needed because torch::pickle_load() doesn't support torch.nn.Parameter
def get_tensor(x: torch.Tensor) -> torch.Tensor:
if isinstance(x, torch.nn.Parameter):
return x.data
else:
return x
# args must be a tuple, not a list
sample_args = tuple(pytree.tree_map(get_tensor, sample_args))
# kwargs must be a dict
sample_kwargs = pytree.tree_map(get_tensor, sample_kwargs)
torch.save((sample_args, sample_kwargs), buffer)
zip_file.write_bytes(sample_inputs_path, buffer.getvalue())
return sample_inputs_path
def package_model(
exported_program: ExportedProgram,
model_name: str,
zip_file: PT2ArchiveWriter,
delegates: Optional[Dict[str, ExportedProgram]] = None,
) -> None:
"""
    Save the exported program in a format compatible with the Sigmoid ModelRunner.
"""
def _make_program(
ep: ExportedProgram,
) -> Program:
with _enable_graph_inputs_of_type_nn_module(ep.example_inputs):
return Program(
methods={
"forward": ExportedProgramSerializer()
.serialize(ep)
.exported_program
},
)
if delegates is None:
delegates = {}
# Packaging for Weight
tensor_path_map = _package_state_dict(exported_program, zip_file)
# Packaging for Constants (tensor constants, custom class obj)
constant_path_map = _package_constants(exported_program, zip_file)
example_args, example_kwargs = exported_program.example_inputs
# Packaging for input samples
assert (
example_args is not None or example_kwargs is not None
), "PT2 Archive requires sample inputs to be present"
sample_inputs_path = _package_sample_inputs( # noqa
example_args, example_kwargs, zip_file, model_name
)
model_json = Model(
name=model_name,
tensorPaths=tensor_path_map,
program=_make_program(exported_program),
delegates={key: _make_program(ep) for key, ep in delegates.items()},
deviceAllocationMap={},
constantPaths=constant_path_map,
)
model_bytes: bytes = _to_json_bytes(model_json)
# Packaging for model
zip_file.write_bytes(MODELS_FILENAME_FORMAT.format(model_name), model_bytes)
# Include readable graph for debugging
zip_file.write_string(
f"models/debug/{model_name}_readable.txt", str(exported_program)
)
zip_file.write_string(
f"models/debug/{model_name}_device_annotation.txt",
exported_program.graph_module.print_readable(
print_output=False, include_device=True
),
)

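A short usage sketch for package_model, assuming an ExportedProgram produced by torch.export.export and using the PT2ArchiveWriter imported at the top of this file; the archive path and model name are illustrative:

import torch
from torch.export import export


class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1


ep = export(Add(), (torch.randn(3),))

# "add.pt2" and "add" are illustrative names; package_model and PT2ArchiveWriter
# are the symbols defined/imported in this file.
with PT2ArchiveWriter("add.pt2") as writer:
    package_model(ep, "add", writer)
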
View File

@ -0,0 +1,139 @@
# pyre-unsafe
import glob
import io
import logging
import os
import zipfile
from typing import BinaryIO, Union
import torch
from torch._C.nativert.pt2_archive_constants import ( # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
ARCHIVE_FORMAT_PATH,
ARCHIVE_FORMAT_VALUE,
ARCHIVE_VERSION_PATH,
ARCHIVE_VERSION_VALUE,
)
logger: logging.Logger = logging.getLogger(__name__)
def is_sigmoid_package(serialized_model: Union[bytes, str]) -> bool:
try:
zip_reader = zipfile.ZipFile(
io.BytesIO(serialized_model)
if isinstance(serialized_model, bytes)
else serialized_model
)
root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
if archive_format_path in zip_reader.namelist():
return zip_reader.read(archive_format_path) == b"pt2"
except Exception as ex:
logger.info(f"Model is not a sigmoid package: {ex}")
return False
class PT2ArchiveWriter:
def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
# pyre-ignore
self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)
        # NOTICE: the version here is different from archive_version.
        # This is the zip file format version used by PyTorchFileWriter, which writes to /.data/version.
        # archive_version is the version of the PT2 archive spec, which is written to /archive_version.
self.archive_file.set_min_version(6)
def __enter__(self):
return self
def __exit__(self, *args):
if not self.has_record(ARCHIVE_FORMAT_PATH):
self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)
if not self.has_record(ARCHIVE_VERSION_PATH):
self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)
self.close()
def has_record(self, name: str) -> bool:
return name in self.archive_file.get_all_written_records()
def count_prefix(self, prefix: str) -> int:
return sum(
1
for record in self.archive_file.get_all_written_records()
if record.startswith(prefix)
)
def write_bytes(self, name: str, data: bytes) -> None:
assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
self.archive_file.write_record(name, data, len(data))
def write_string(self, name: str, data: str) -> None:
assert isinstance(data, str), f"Expected string but got {type(data)}"
data_bytes = data.encode()
self.write_bytes(name, data_bytes)
def write_file(self, name: str, file_path: str) -> None:
"""
Copy a file into the archive.
name: The destination file inside the archive.
file_path: The source file on disk.
"""
assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
with open(file_path, "rb") as f:
file_bytes = f.read()
self.write_bytes(name, file_bytes)
def write_folder(self, archive_dir: str, folder_dir: str) -> None:
"""
Copy a folder into the archive.
archive_dir: The destination folder inside the archive.
folder_dir: The source folder on disk.
"""
assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
file_paths = filter(
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
)
for file_path in file_paths:
filename = os.path.relpath(file_path, folder_dir)
archive_path = os.path.join(archive_dir, filename)
self.write_file(archive_path, file_path)
def close(self) -> None:
self.archive_file.write_end_of_file()
class PT2ArchiveReader:
def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
# pyre-ignore
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)
assert (
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
), "Invalid archive format"
def __enter__(self):
return self
def __exit__(self, *args):
# torch._C.PyTorchFileReader doesn't have a close method
pass
def read_bytes(self, name: str) -> bytes:
return self.archive_file.get_record(name)
def read_string(self, name: str) -> str:
data = self.read_bytes(name)
return data.decode()
def archive_version(self) -> int:
try:
archive_version = self.read_string(ARCHIVE_VERSION_PATH)
except Exception:
            # If archive_version is not found, the archive predates explicit
            # versioning; in that case we treat it as version 0.
archive_version = 0
return int(archive_version)

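A brief round-trip sketch for the writer and reader defined above; the archive path is illustrative:

# Illustrative round trip using the classes defined in this file.
with PT2ArchiveWriter("example.pt2") as writer:
    writer.write_string("models/example.json", "{}")
    writer.write_bytes("data/weights/weight_0", b"\x00" * 16)
    # __exit__ stamps archive_format and archive_version if they were not
    # written explicitly, then closes the archive.

with PT2ArchiveReader("example.pt2") as reader:
    # The constructor already validates the archive_format record.
    print(reader.read_string("models/example.json"))
    print(reader.archive_version())
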
View File

@ -0,0 +1,139 @@
# pyre-unsafe
import glob
import io
import logging
import os
import zipfile
from typing import BinaryIO, Union
import torch
from torch._C.nativert.package.pt2_archive_constants_pybind import ( # @manual=//sigmoid/core/package:pt2_archive_constants_pybind
ARCHIVE_FORMAT_PATH,
ARCHIVE_FORMAT_VALUE,
ARCHIVE_VERSION_PATH,
ARCHIVE_VERSION_VALUE,
)
logger: logging.Logger = logging.getLogger(__name__)
def is_sigmoid_package(serialized_model: Union[bytes, str]) -> bool:
try:
zip_reader = zipfile.ZipFile(
io.BytesIO(serialized_model)
if isinstance(serialized_model, bytes)
else serialized_model
)
root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
if archive_format_path in zip_reader.namelist():
return zip_reader.read(archive_format_path) == b"pt2"
except Exception as ex:
logger.info(f"Model is not a sigmoid package: {ex}")
return False
class PT2ArchiveWriter:
def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
# pyre-ignore
self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)
        # NOTICE: the version here is different from archive_version.
        # This is the zip file format version used by PyTorchFileWriter, which writes to /.data/version.
        # archive_version is the version of the PT2 archive spec, which is written to /archive_version.
self.archive_file.set_min_version(6)
def __enter__(self):
return self
def __exit__(self, *args):
if not self.has_record(ARCHIVE_FORMAT_PATH):
self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)
if not self.has_record(ARCHIVE_VERSION_PATH):
self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)
self.close()
def has_record(self, name: str) -> bool:
return name in self.archive_file.get_all_written_records()
def count_prefix(self, prefix: str) -> int:
return sum(
1
for record in self.archive_file.get_all_written_records()
if record.startswith(prefix)
)
def write_bytes(self, name: str, data: bytes) -> None:
assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
self.archive_file.write_record(name, data, len(data))
def write_string(self, name: str, data: str) -> None:
assert isinstance(data, str), f"Expected string but got {type(data)}"
data_bytes = data.encode()
self.write_bytes(name, data_bytes)
def write_file(self, name: str, file_path: str) -> None:
"""
Copy a file into the archive.
name: The destination file inside the archive.
file_path: The source file on disk.
"""
assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
with open(file_path, "rb") as f:
file_bytes = f.read()
self.write_bytes(name, file_bytes)
def write_folder(self, archive_dir: str, folder_dir: str) -> None:
"""
Copy a folder into the archive.
archive_dir: The destination folder inside the archive.
folder_dir: The source folder on disk.
"""
assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
file_paths = filter(
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
)
for file_path in file_paths:
filename = os.path.relpath(file_path, folder_dir)
archive_path = os.path.join(archive_dir, filename)
self.write_file(archive_path, file_path)
def close(self) -> None:
self.archive_file.write_end_of_file()
class PT2ArchiveReader:
def __init__(self, archive_path_or_buffer: Union[str, BinaryIO]):
# pyre-ignore
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)
assert (
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
), "Invalid archive format"
def __enter__(self):
return self
def __exit__(self, *args):
# torch._C.PyTorchFileReader doesn't have a close method
pass
def read_bytes(self, name: str) -> bytes:
return self.archive_file.get_record(name)
def read_string(self, name: str) -> str:
data = self.read_bytes(name)
return data.decode()
def archive_version(self) -> int:
try:
archive_version = self.read_string(ARCHIVE_VERSION_PATH)
except Exception:
            # If archive_version is not found, the archive predates explicit
            # versioning; in that case we treat it as version 0.
archive_version = 0
return int(archive_version)