[nativert] add memory overlap debug assertion (#157290)

Summary: Better safe than sorry. This will throw if memory overlap is detected when using planned tensors while debug mode is enabled -- making our planning unit tests more robust.
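
In debug (non-NDEBUG) builds, after each node executes the LayoutManager resolves every live planned tensor to its [start, end] byte slice inside the managed buffer (going through the alias source when the value is an alias), sorts the slices, and verifies they are pairwise disjoint. A minimal standalone sketch of that disjointness sweep -- illustrative names only, not the real LayoutManager API:

#include <cstddef>
#include <set>
#include <stdexcept>
#include <utility>

// Sketch: given inclusive [start, end] byte intervals for the storages that
// are alive while a node executes, throw if any two planned slices overlap.
// std::pair sorts lexicographically, so iteration is ordered by start offset.
inline void assert_intervals_disjoint(
    const std::multiset<std::pair<size_t, size_t>>& intervals) {
  if (intervals.size() < 2) {
    return;  // overlap needs at least two live intervals
  }
  auto it = intervals.begin();
  size_t prev_end = it->second;
  for (++it; it != intervals.end(); ++it) {
    if (it->first <= prev_end) {
      throw std::runtime_error("planned tensor storages overlap");
    }
    prev_end = it->second;
  }
}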

Test Plan:
ci

Rollback Plan:

Differential Revision: D77327841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157290
Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
Author: dolpm
Date: 2025-07-14 19:12:41 +00:00
Committed by: PyTorch MergeBot
Parent: f87d117939
Commit: 725c327284

19 changed files with 604 additions and 76 deletions

View File

@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
${TORCH_ROOT}/torch/nativert/executor/Executor.cpp
${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp
${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp
${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp
${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp
${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp
${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
)
add_executable(test_nativert

View File

@ -0,0 +1,182 @@
#include <gtest/gtest.h>
#include <fmt/format.h>
#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/kernels/KernelFactory.h>
using namespace ::testing;
using namespace torch::nativert;
using AliasTestCase = std::tuple<
std::string /* value */,
AllocationLifetime,
bool /* is_alias */,
bool /* is_storage_associated_with_output */,
c10::FastSet<std::string> /* source(s) */>;
class AliasAnalyzerTests : public testing::Test {
void SetUp() override {}
void TearDown() override {
test_cases.clear();
model.clear();
}
public:
void setTestCases(std::vector<AliasTestCase> cases) {
test_cases = std::move(cases);
}
void setModel(std::string m) {
model = std::move(m);
}
void run() {
EXPECT_FALSE(test_cases.empty());
EXPECT_FALSE(model.empty());
ExecutorConfig cfg;
cfg.enableStaticCPUKernels = true;
auto graph = stringToGraph(model);
auto kernels = KernelFactory().initializeNodeKernels(
*graph, nullptr, cfg, {}, nullptr);
auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);
AliasAnalyzer analyzer(*graph, kernelSchemas);
for (
auto& [value, lifetime, is_alias, is_storage_associated_with_output, srcs] :
test_cases) {
LOG(INFO) << fmt::format(
"running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}",
value,
lifetime.start,
lifetime.end,
is_alias,
is_storage_associated_with_output,
srcs.empty() ? "{}"
: std::accumulate(
srcs.begin(),
srcs.end(),
std::string{},
[](std::string cur, const std::string& src) {
cur.append(",");
cur.append(src);
return cur;
}));
auto* v = graph->getValue(value);
EXPECT_EQ(analyzer.lifetime(v), lifetime);
EXPECT_EQ(analyzer.is_alias(v), is_alias);
EXPECT_EQ(
analyzer.is_storage_associated_with_output(v),
is_storage_associated_with_output);
const auto* resolved_srcs = analyzer.get_sources_of_alias(v);
if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) {
EXPECT_FALSE(srcs.empty());
EXPECT_EQ(resolved_srcs->size(), srcs.size());
for (const auto& resolved_src : *resolved_srcs) {
EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1);
}
EXPECT_TRUE(srcs.empty());
} else {
EXPECT_TRUE(srcs.empty());
}
}
}
private:
std::string model;
std::vector<AliasTestCase> test_cases;
};
TEST_F(AliasAnalyzerTests, TestNoAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%res = torch.ops.aten.clone.default(self=%out_t, memory_format=None)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 2), false, false, {}},
{"res", AllocationLifetime(2, 3), false, true, {}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestSimpleAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 3), false, true, {}},
{"res", AllocationLifetime(2, 3), true, false, {"out_t"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestDeepAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
%res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 4), false, true, {}},
{"a1", AllocationLifetime(2, 4), true, false, {"out_t"}},
{"res", AllocationLifetime(3, 4), true, false, {"out_t"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestPackedListUnpack) {
setModel(R"(
graph(%a, %b, %c, %d):
%input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d)
%x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list)
return (%x1, %x3))");
setTestCases({
{"a", AllocationLifetime(0, 2), false, false, {}},
{"x0", AllocationLifetime(2, 2), true, false, {"a"}},
{"b", AllocationLifetime(0, 3), false, true, {}},
{"x1", AllocationLifetime(2, 3), true, false, {"b"}},
{"c", AllocationLifetime(0, 2), false, false, {}},
{"x2", AllocationLifetime(2, 2), true, false, {"c"}},
{"d", AllocationLifetime(0, 3), false, true, {}},
{"x3", AllocationLifetime(2, 3), true, false, {"d"}},
});
run();
}
TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) {
setModel(R"(
graph(%y0, %y1):
%out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1)
%a1 = prim.VarStack(l0=%out_t, l1=%out_t2)
%res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
return (%res))");
setTestCases({
{"out_t", AllocationLifetime(1, 5), false, true, {}},
{"out_t2", AllocationLifetime(2, 5), false, true, {}},
{"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}},
{"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}},
});
run();
}

View File

@ -46,13 +46,14 @@ class ExecutionFrame {
}
template <typename CB>
auto withMemoryPlanner(CB&& cb) {
auto withManagedMemory(CB&& cb) {
if (!layoutManager_) {
return std::forward<CB>(cb)();
return std::forward<CB>(cb)(nullptr);
}
LayoutManagerGuard guard(*layoutManager_);
return std::forward<CB>(cb)();
return std::forward<CB>(cb)(
const_cast<const LayoutManager*>(layoutManager_.get()));
}
std::vector<c10::IValue> tryMoveUserOutputs();
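
ExecutionFrame::withManagedMemory (renamed from withMemoryPlanner) now hands the callback the frame's LayoutManager, or nullptr when static planning is disabled, so call sites can opt into the debug assertion. A runnable sketch of that callback shape using stand-in types (FakeLayoutManager and FakeExecutionFrame are illustrative, not the real nativert classes):

#include <cstddef>
#include <iostream>
#include <utility>

// Stand-in for LayoutManager: only exposes the debug check used below.
struct FakeLayoutManager {
  void assert_no_overlapping_storages(size_t node_idx) const {
    std::cout << "checked node " << node_idx << "\n";
  }
};

// Stand-in for ExecutionFrame: invokes the callback with its manager,
// or nullptr when memory planning is off.
struct FakeExecutionFrame {
  const FakeLayoutManager* manager = nullptr;
  template <typename CB>
  auto withManagedMemory(CB&& cb) {
    return std::forward<CB>(cb)(manager);
  }
};

int main() {
  FakeLayoutManager lm;
  FakeExecutionFrame frame{&lm};
  frame.withManagedMemory([&](const FakeLayoutManager* layout_manager) {
    for (size_t node_idx = 1; node_idx < 4; ++node_idx) {
      // ... run the node's kernel here ...
      if (layout_manager != nullptr) {
        layout_manager->assert_no_overlapping_storages(node_idx);
      }
    }
  });
}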

View File

@ -19,30 +19,31 @@ namespace torch::nativert {
Executor::Executor(
torch::nativert::ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
std::shared_ptr<Weights> weights,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const MakeProxyExecutorFn& makeProxyExecutorFunc)
const std::shared_ptr<Weights>& weights,
Placement placement,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader,
MakeProxyExecutorFn makeProxyExecutorFunc)
: executorConfig_(std::move(executorConfig)),
graph_(std::move(graph)),
placement_(placement),
placement_(std::move(placement)),
constantFolder_(
executorConfig_.runConstFolding
? std::optional<ConstantFolder>(*graph_)
: std::nullopt),
makeProxyExecutorFunc_(makeProxyExecutorFunc),
makeProxyExecutorFunc_(std::move(makeProxyExecutorFunc)),
executionFrames_(executorConfig_.maxNumConcurrentThreads),
clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
numExecutionFrames_(0),
lastClearedTimestamp_(getCurrentTimestampSeconds()) {
if (weights) {
initialize(std::move(weights), std::move(pytorchStreamReader));
initialize(weights, pytorchStreamReader);
}
}
void Executor::initialize(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
const std::shared_ptr<Weights>& weights,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader) {
auto start = std::chrono::high_resolution_clock::now();
@ -51,7 +52,7 @@ void Executor::initialize(
weights,
executorConfig_,
placement_,
std::move(pytorchStreamReader),
pytorchStreamReader,
makeProxyExecutorFunc_);
if (constantFolder_.has_value()) {
@ -113,13 +114,14 @@ void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
}
}
void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
void Executor::maybeRunConstantFolding(
const std::shared_ptr<Weights>& weights) {
for (auto& execution : constFoldingExecutions_) {
ExecutionFrame constFoldingFrame(execution.executor->graph());
std::vector<c10::IValue> inputs;
inputs.reserve(graph_->signature().inputsToWeights().size());
for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
inputs.push_back(weights->at(name));
inputs.emplace_back(weights->at(name));
}
auto outputs = execution.executor->execute(constFoldingFrame, inputs);
@ -130,7 +132,7 @@ void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
}
}
void Executor::processWeights(std::shared_ptr<Weights> weights) {
void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
maybeRunConstantFolding(weights);
if (constantFolder_.has_value()) {
constantFolder_->evaluate(*weights);
@ -352,10 +354,10 @@ std::vector<c10::IValue> Executor::execute(
}
ProfileMetrics Executor::benchmarkIndividualNodes(
std::vector<std::vector<c10::IValue>> inputsList,
const std::vector<std::vector<c10::IValue>>& inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns) {
CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
CHECK(!inputsList.empty()) << "Need at least one input to benchmark";
CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";
for (const auto& inputs : inputsList) {
@ -378,8 +380,9 @@ int64_t Executor::getCurrentTimestampSeconds() const {
std::vector<DelegateExecutor*> Executor::getDelegates() {
std::vector<DelegateExecutor*> delegates;
delegates.reserve(delegateExecutors_.size());
for (const auto& delegateExecutor : delegateExecutors_) {
delegates.push_back(delegateExecutor.get());
delegates.emplace_back(delegateExecutor.get());
}
return delegates;
}

View File

@ -79,11 +79,11 @@ class Executor {
Executor(
torch::nativert::ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
std::shared_ptr<Weights> weights,
const Placement& placement = Placement(),
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
const std::shared_ptr<Weights>& weights,
Placement placement = Placement(),
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader = nullptr,
const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);
MakeProxyExecutorFn makeProxyExecutorFunc = nullptr);
std::shared_ptr<Weights> getWeights() {
std::shared_ptr<Weights> ret;
@ -91,7 +91,7 @@ class Executor {
return ret;
}
void processWeights(std::shared_ptr<Weights> weights);
void processWeights(const std::shared_ptr<Weights>& weights);
void atomicSwapWeights(std::shared_ptr<Weights> weights);
// This API only returns the flattened UserOutputs,
@ -106,7 +106,7 @@ class Executor {
const ITreeSpec& inputTreeSpec);
ProfileMetrics benchmarkIndividualNodes(
std::vector<std::vector<c10::IValue>> inputsList,
const std::vector<std::vector<c10::IValue>>& inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns);
@ -141,8 +141,8 @@ class Executor {
c10::Synchronized<std::shared_ptr<Weights>> weights_;
void initialize(
std::shared_ptr<Weights> weights,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
const std::shared_ptr<Weights>& weights,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader);
ExecutorFramePtr getExecutorFrameFromPool();
@ -171,7 +171,7 @@ class Executor {
ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
};
void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
void validateInputs(const std::vector<c10::IValue>& inputs) const;
// Helper method to get current timestamp in seconds

View File

@ -32,7 +32,7 @@ void GraphExecutorBase::fillUserInputs(
ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
ExecutionFrame& executionFrame,
std::vector<std::vector<c10::IValue>> inputsList,
const std::vector<std::vector<c10::IValue>>& inputsList,
const uint32_t warmupRuns,
const uint32_t mainRuns) {
// TODO: add support for memory profiling
@ -112,7 +112,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
results.totalNodesCount = numNodes;
for (const auto& r : results.timePerNodeType) {
const std::string& target = r.first;
results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime;
}
return results;
}

View File

@ -51,7 +51,7 @@ class GraphExecutorBase {
ProfileMetrics benchmarkIndividualNodes(
ExecutionFrame& executionFrame,
std::vector<std::vector<c10::IValue>> inputs,
const std::vector<std::vector<c10::IValue>>& inputs,
const uint32_t warmup_runs,
const uint32_t main_runs);

View File

@ -22,11 +22,13 @@ ThreadPoolExecutor::~ThreadPoolExecutor() {
}
C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
// NOLINTNEXTLINE(misc-use-internal-linkage)
thread_local moodycamel::ProducerToken ptok(*work_);
return ptok;
}
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
// NOLINTNEXTLINE(misc-use-internal-linkage)
thread_local moodycamel::ConsumerToken ctok(*work_);
return ctok;
}
@ -39,7 +41,7 @@ void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
void ThreadPoolExecutor::start(int32_t numThreads) {
stopped_ = false;
for (int32_t i = 0; i < numThreads; ++i) {
threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this));
threads_.emplace_back(&ThreadPoolExecutor::loop, this);
}
}
@ -62,16 +64,17 @@ void ThreadPoolExecutor::loop() {
void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
session->addWork();
work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session));
work_->enqueue(ptok(), [unit, this, session] { unit->run(this, session); });
sem_->release();
}
void ThreadPoolExecutor::add(
SessionState* session,
std::vector<WorkUnit*>::const_iterator&& begin,
const std::vector<WorkUnit*>::const_iterator&& end) {
std::vector<WorkUnit*>::const_iterator begin,
const std::vector<WorkUnit*>::const_iterator& end) {
const auto count = end - begin;
// NOLINTNEXTLINE(bugprone-switch-missing-default-case)
switch (count) {
case 0: {
return;
@ -86,16 +89,17 @@ void ThreadPoolExecutor::add(
std::vector<Work> runnables;
runnables.reserve(count);
for (; begin != end; ++begin) {
runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session));
runnables.emplace_back(
[capture0 = *begin, this, session] { capture0->run(this, session); });
}
work_->enqueue_bulk(ptok(), runnables.begin(), count);
sem_->release(count);
sem_->release(static_cast<int32_t>(count));
}
void ThreadPoolExecutor::stop() {
stopped_ = true;
sem_->release(threads_.size());
sem_->release(static_cast<int32_t>(threads_.size()));
std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
threads_.clear();
@ -136,10 +140,10 @@ void ThreadPoolExecutor::run(
}
void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
thread_local std::vector<WorkUnit*> newWorkUnits;
thread_local c10::InferenceMode mode;
/* thread_local */ std::vector<WorkUnit*> newWorkUnits;
/* thread_local */ c10::InferenceMode mode;
WorkUnit* unit = this;
/* thread_local */ WorkUnit* unit = this;
while (true) {
unit->kernel->compute(session->frame());
@ -219,7 +223,7 @@ ParallelGraphExecutor::ParallelGraphExecutor(
}
}
executor_.start(executorConfig.maxParallelOps);
executor_.start(static_cast<int32_t>(executorConfig.maxParallelOps));
}
std::vector<c10::IValue> ParallelGraphExecutor::execute(

View File

@ -46,8 +46,8 @@ class ThreadPoolExecutor {
void add(SessionState* session, WorkUnit* unit);
void add(
SessionState* session,
std::vector<WorkUnit*>::const_iterator&& begin,
const std::vector<WorkUnit*>::const_iterator&& end);
std::vector<WorkUnit*>::const_iterator begin,
const std::vector<WorkUnit*>::const_iterator& end);
C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok();
C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok();

View File

@ -14,11 +14,17 @@ std::vector<c10::IValue> SerialGraphExecutor::execute(
std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
ExecutionFrame& executionFrame) {
executionFrame.withMemoryPlanner([&]() {
executionFrame.withManagedMemory([&](const LayoutManager* layout_manager) {
// Execute kernels for all nodes except prim.Input and prim.Output
for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
nodeKernels_[nodeIdx]->compute(executionFrame);
#ifndef NDEBUG
if (layout_manager != nullptr) {
layout_manager->assert_no_overlapping_storages(nodeIdx);
}
#endif
// don't free intermediate values when static memory planning is enabled
if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
// Free the intermediate values that are not used anymore

View File

@ -23,18 +23,32 @@ AliasAnalyzer::AliasAnalyzer(
maybe_update_aliases_from_schema(node, schemas);
}
maybe_extend_lifetimes(graph);
// squash_deep_aliases will populate aliases_
// with a mapping from each alias to its backing
// source (i.e., the value that owns the underlying
// dataptr for said alias)
squash_deep_aliases(graph);
// set all non-aliasing outputs. outputs
// that are aliased will be set later when
// lifetimes are extended
for (const auto* output : graph.outputs()) {
if (!is_alias(output)) {
values_associated_with_outputs_.insert(output);
values_associated_with_outputs_.emplace(output);
}
}
maybe_extend_lifetimes(graph);
log_state();
}
alive_values_at_time_.resize(graph.nodes().size());
for (const auto& [v, lifetime] : lifetimes_) {
for (const auto t : c10::irange(lifetime.start, lifetime.end + 1)) {
alive_values_at_time_[t].emplace_back(v);
}
}
}
bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack(
const Node& node,
@ -63,7 +77,7 @@ bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack(
create_or_update_lifetime(input, i);
create_or_update_lifetime(output, i);
aliases_[output].insert(input);
aliases_[output].emplace(input);
}
return true;
@ -96,7 +110,7 @@ void AliasAnalyzer::maybe_update_aliases_from_schema(
VLOG(1) << node.target()
<< " may contain input/output alias: " << input->id() << " -> "
<< output->id();
aliases_[output].insert(input);
aliases_[output].emplace(input);
}
}
}
@ -109,6 +123,56 @@ void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) {
}
}
void AliasAnalyzer::squash_deep_aliases(const Graph& graph) {
for (auto& node : graph.nodes()) {
for (const auto& output : node.outputs()) {
auto aliasIt = aliases_.find(output);
if (aliasIt == aliases_.end()) {
continue;
}
c10::FastSet<const Value*> filtered_srcs;
auto& srcs = aliasIt->second;
for (const auto* src : srcs) {
// check if this source is an alias itself,
// making 'output' a deep alias (i.e.,
// an alias of an alias)
// we want aliases_[x] to return the value from which x
// inherits its dataptr.
// as such, we only keep sources that own their storage
// (i.e., those that are not themselves aliases).
// in practice, there can only be 1 value that meets this
// criterion (at a time), but there are some cases where
// this is ambiguous (e.g., where the spec doesn't exist,
// or we're dealing with variadics)
auto srcAliasIt = aliases_.find(src);
if (srcAliasIt == aliases_.end()) {
filtered_srcs.emplace(src);
continue;
}
// since we are going from the beginning of the graph
// to the end of the graph we can assume that these
// aliases, which have already been visited, have already
// been squashed.
auto& srcs_of_src = srcAliasIt->second;
for (const auto* src_of_src : srcs_of_src) {
// if the source of the source is not an alias
// (i.e., it has ownership over its data ptr)
// then we want to add it as a source of 'output'
if (aliases_.find(src_of_src) == aliases_.end()) {
filtered_srcs.emplace(src_of_src);
}
}
}
srcs = std::move(filtered_srcs);
}
}
}
void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) {
c10::FastSet<const Value*> extended;
@ -129,10 +193,11 @@ void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) {
VLOG(1) << "extended EOL of value " << src->id() << " to " << eol;
extended.insert(src);
extended.emplace(src);
if (eol == graph.nodes().size() - 1 /* aliases output */) {
values_associated_with_outputs_.insert(src);
if (aliases_.find(src) == aliases_.end() &&
eol == graph.nodes().size() - 1 /* aliases output */) {
values_associated_with_outputs_.emplace(src);
}
}
}

View File

@ -14,26 +14,38 @@ class AliasAnalyzer {
const Graph& graph,
const c10::FastMap<std::string /* target */, FunctionSchema>& schemas);
C10_ALWAYS_INLINE const AllocationLifetime& lifetime(
const c10::FastSet<const Value*>* get_sources_of_alias(
const Value* value) const {
const auto it = aliases_.find(value);
if (it == aliases_.end()) {
return nullptr;
}
return &it->second;
}
const AllocationLifetime& lifetime(const Value* value) const {
return lifetimes_.at(value);
}
C10_ALWAYS_INLINE bool is_alias(const Value* value) const {
bool is_alias(const Value* value) const {
return aliases_.find(value) != aliases_.end();
}
C10_ALWAYS_INLINE bool is_storage_associated_with_output(
const Value* value) const {
bool is_storage_associated_with_output(const Value* value) const {
return values_associated_with_outputs_.find(value) !=
values_associated_with_outputs_.end();
}
C10_ALWAYS_INLINE const c10::FastSet<const Value*>&
values_associated_with_output_storage() const {
const c10::FastSet<const Value*>& values_associated_with_output_storage()
const {
return values_associated_with_outputs_;
}
const std::vector<const Value*>& alive_values_at_time(size_t time) const {
TORCH_CHECK_LT(time, alive_values_at_time_.size());
return alive_values_at_time_[time];
}
private:
// listunpack operations that take a list that has
// been created with a listpack operation should
@ -72,14 +84,35 @@ class AliasAnalyzer {
// even if they aren't explicitly considered outputs)
void maybe_extend_lifetimes(const Graph& graph);
// in the event that we have aliases-of-aliases
// we want to make sure that the 'sources'
// are propagated
//
// e.g.,
// %x0 = ...
// %x1 = some_aliasing_op(x0)
// %x2 = some_aliasing_op(x1)
//
// we want aliases_[x2] = x0
// instead of aliases_[x2] = x1
//
// the result is that aliases_ will contain a
// mapping from each alias to its backing
// source (i.e., the value that owns its
// associated dataptr)
void squash_deep_aliases(const Graph& graph);
void log_state() const;
// mapping from alias to the set of values that it aliases
// mapping from alias to its source
c10::FastMap<const Value*, c10::FastSet<const Value*>> aliases_;
c10::FastMap<const Value*, AllocationLifetime> lifetimes_;
// non-aliasing outputs or non-aliasing intermediates that are aliased by
// outputs
c10::FastSet<const Value*> values_associated_with_outputs_;
// alive_values_at_time_[i] = values that are "alive" during the
// computation of node i
std::vector<std::vector<const Value*>> alive_values_at_time_;
};
} // namespace torch::nativert
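
The squash_deep_aliases contract above can be exercised in isolation: given the direct alias edges and the definition order of values, each alias's source set is rewritten to contain only storage owners. A small sketch on plain strings (not the real Value*/c10::FastMap types), assuming values are visited in definition order so earlier aliases are squashed before later ones:

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using AliasMap =
    std::unordered_map<std::string, std::unordered_set<std::string>>;

// Rewrite every alias's sources to the values that actually own storage.
inline void squash_deep_aliases(
    AliasMap& aliases, const std::vector<std::string>& definition_order) {
  for (const auto& value : definition_order) {
    auto it = aliases.find(value);
    if (it == aliases.end()) {
      continue;  // not an alias; this value owns its data ptr
    }
    std::unordered_set<std::string> roots;
    for (const auto& src : it->second) {
      auto src_it = aliases.find(src);
      if (src_it == aliases.end()) {
        roots.insert(src);  // src owns its storage
      } else {
        // src is itself an alias; it was visited earlier, so its
        // sources are already storage owners.
        roots.insert(src_it->second.begin(), src_it->second.end());
      }
    }
    it->second = std::move(roots);
  }
}

// With the header's example -- %x1 aliasing %x0 and %x2 aliasing %x1 --
// aliases == {{"x1", {"x0"}}, {"x2", {"x1"}}} squashes so that
// aliases["x2"] == {"x0"}.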

View File

@ -4,6 +4,7 @@
#include <c10/core/CPUAllocator.h>
#include <c10/util/Enumerate.h>
#include <c10/util/FbcodeMaps.h>
namespace torch::nativert {
@ -147,6 +148,9 @@ void LayoutManager::populate_tensor_values() {
planned_tensors_max_nbytes_local_.resize(value_ids.size());
for (const auto&& [i, v] : c10::enumerate(value_ids)) {
#ifndef NDEBUG
value_to_vector_idx_map_[v] = i;
#endif
planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor();
}
@ -157,6 +161,165 @@ void LayoutManager::populate_tensor_values() {
}
}
#ifndef NDEBUG
void LayoutManager::assert_no_overlapping_storages(
size_t graph_node_idx) const {
if (state_ != LayoutManagerState::Running) {
return;
}
/*
for each value
(either an input or output)
ensure that the associated storage
slice lies within the allocated slice
if it is managed (or if it is an alias,
we can use the slice allocated to its source)
---
also ensure that the current index lies
within the lifetime of this value
*/
const auto& alias_analyzer = planner_.get_alias_analyzer();
// get the 'active' values during the execution of nodes[graph_node_idx]
const auto& alive_values =
alias_analyzer.alive_values_at_time(graph_node_idx);
// make sure active memory intervals are non-overlapping
// by sorting them by start, and ensuring
// cur.start > prev.end for each
//
// by default, the pairs are compared lexicographically.
// ref: https://cplusplus.com/reference/utility/pair/operators/
//
// in our case, this means that leftmost (on the number line) intervals will
// come first, and if the start point of two intervals is the same, they will
// be sorted by their relative widths (in increasing order)
//
// e.g., the ordering for the following usage intervals
//
// |######1######|
//     |######2######|
//   |######3#####|
//
// would be [1,3,2]
std::multiset<std::pair<size_t, size_t>> intervals;
planner_.with_plan([&](const LayoutPlan& plan) {
// prevent recomputation from occurring
c10::FastSet<ValueId> checked_values;
// check that some arbitrary storage (defined by the allocation start and
// the size in bytes) lies within the slice allocated for value_id during
// planning.
//
// if the checks pass, add the inclusive interval
// [alloc_start, alloc_end] to the set of intervals
auto check_allocation_bounds =
[&](ValueId value_id, size_t alloc_start, size_t alloc_end) -> void {
if (!checked_values.emplace(value_id).second /* already checked */) {
return;
}
auto& alloc = plan.allocations[value_to_vector_idx_map_.at(value_id)];
TORCH_CHECK_GE(alloc_start, alloc.offset);
TORCH_CHECK_LT(alloc_end, alloc.offset + alloc.size);
intervals.emplace(alloc_start, alloc_end);
};
// get the inclusive storage interval for some value (i.e.,
// [buffer_storage_start_offset, buffer_storage_start_offset +
// storage_nbytes - 1]) that represents the sub-slice of the runtime-managed
// buffer allocated to this tensor
auto try_get_interval =
[&](ValueId value_id) -> std::optional<std::pair<size_t, size_t>> {
const auto& iv = parent_frame_.getIValue(value_id);
if (!iv.isTensor()) {
return std::nullopt;
}
const auto& storage_impl = iv.toTensor().storage().unsafeGetStorageImpl();
const auto storage_nbytes = storage_impl->nbytes();
if (const auto start = layout_buffer_.get_offset_from_ptr(
storage_impl->data_ptr().get());
start.has_value()) {
return std::make_pair(*start, *start + storage_nbytes - 1);
}
return std::nullopt;
};
for (auto v : alive_values) {
// sanity check lifetimes to ensure this
// value ~should~ be alive at this point
const auto& lt = alias_analyzer.lifetime(v);
TORCH_CHECK_GE(graph_node_idx, lt.start);
TORCH_CHECK_LE(graph_node_idx, lt.end);
const auto interval = try_get_interval(v->id());
if (C10_UNLIKELY(!interval.has_value())) {
continue;
}
auto& [v_start, v_end] = *interval;
// it's possible that v is an alias, in which case
// we want to try to get the source (i.e., the value)
// that actually owns the storage
//
// NOTE: it's possible the source is ambiguous, hence
// why get_sources_of_alias returns a set (although it's usually a
// singleton set)
if (const auto* srcs_of_v = alias_analyzer.get_sources_of_alias(v);
srcs_of_v != nullptr /* v is an alias */) {
// 1. v's interval is a sub-interval of ~a~ source's interval and we
// want to add the source's interval to the set of intervals
// 2. v possibly got re-alloc'd / is not actually aliasing anything
// and we want to add v's interval to the set of intervals
bool found_viable_source = false;
for (const auto* src_of_v : *srcs_of_v) {
const auto src_interval = try_get_interval(src_of_v->id());
if (C10_UNLIKELY(!src_interval.has_value())) {
continue;
}
auto& [src_of_v_start, src_of_v_end] = *src_interval;
if (v_start >= src_of_v_start && v_end <= src_of_v_end) {
check_allocation_bounds(
src_of_v->id(), src_of_v_start, src_of_v_end);
found_viable_source = true;
break;
}
}
if (!found_viable_source) {
check_allocation_bounds(v->id(), v_start, v_end);
}
} else /* if v isn't an alias */ {
check_allocation_bounds(v->id(), v_start, v_end);
}
}
});
// if we have fewer than two active intervals,
// it isn't possible to have overlap...
if (intervals.size() < 2) {
return;
}
// ensure that no 'active' buffer intervals are overlapping
auto it = intervals.begin();
size_t prev_end = it->second;
while (++it != intervals.end()) {
TORCH_CHECK_LT(prev_end, it->first /* cur_start */);
prev_end = it->second;
}
}
#endif
void LayoutManager::try_update_historical_max_nbytes() {
for (const auto i : c10::irange(planned_tensors_.size())) {
auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes());
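
The lexicographic ordering that assert_no_overlapping_storages relies on is just the default std::pair comparison; a tiny standalone check with made-up offsets matching the [1, 3, 2] diagram in the comment above:

#include <cstddef>
#include <iostream>
#include <set>
#include <utility>

int main() {
  // Hypothetical [start, end] byte intervals: 1 starts leftmost, 3 next, 2 last.
  std::multiset<std::pair<size_t, size_t>> intervals;
  intervals.emplace(0, 14);  // interval 1
  intervals.emplace(4, 18);  // interval 2
  intervals.emplace(2, 15);  // interval 3

  // std::pair compares lexicographically, so iteration yields
  // (0, 14), (2, 15), (4, 18) -- i.e., the ordering [1, 3, 2].
  for (const auto& [start, end] : intervals) {
    std::cout << "[" << start << ", " << end << "]\n";
  }
}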

View File

@ -24,6 +24,20 @@ struct ContiguousLayoutBuffer {
ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) =
delete;
std::optional<size_t> get_offset_from_ptr(void* offset_ptr) const {
void* raw_ptr = data_ptr_.get();
if (!raw_ptr || !offset_ptr) {
return std::nullopt;
}
auto offset = reinterpret_cast<uint8_t*>(offset_ptr) -
reinterpret_cast<uint8_t*>(raw_ptr);
return offset < 0 || static_cast<size_t>(offset) >= size_
? std::nullopt
: std::optional(offset);
}
void* get_ptr_with_offset(size_t offset) {
void* raw_ptr = data_ptr_.get();
TORCH_CHECK_NOTNULL(raw_ptr);
@ -148,10 +162,32 @@ class LayoutManager {
torch::nativert::LayoutManagerSettings settings = {});
~LayoutManager() = default;
// this is a debugging function. it will slow things down SIGNIFICANTLY
// so please ensure this isn't called unless you really need it
//
// it checks a few things in between node executions...
//
// 1. ensures all 'alive' values are within the bounds of their lifetimes
// - this is the definition of a sanity check, since the live-sets are built
// from the lifetimes; if this fails, something is very very wrong
// 2. ensures that all planned values are within the bounds of their
// allocated storage buffer slices
// - if the value is an alias, ensure the alias is within the bounds
// of the source value
// 3. ensures that all planned value data-ptrs are non-overlapping
#ifndef NDEBUG
void assert_no_overlapping_storages(
size_t
graph_node_idx /* the graph node that is currently being computed */)
const;
#endif
private:
friend class LayoutManagerGuard;
void allocate();
void deallocate_and_plan();
private:
#ifdef LayoutPlannerTests_TEST_FRIENDS
LayoutPlannerTests_TEST_FRIENDS;
#endif
@ -178,6 +214,9 @@ class LayoutManager {
std::vector<const at::Tensor*> planned_tensors_;
std::vector<size_t> planned_tensors_max_nbytes_local_;
#ifndef NDEBUG
c10::FastMap<ValueId, size_t> value_to_vector_idx_map_;
#endif
ContiguousLayoutBuffer layout_buffer_;
ContiguousStorageImplBuffer storage_impl_buffer_;

View File

@ -16,9 +16,18 @@ LayoutPlanner::LayoutPlanner(
const c10::FastMap<std::string /* target */, FunctionSchema>& kernelSchemas,
const std::vector<bool>& persistentValues,
const torch::nativert::LayoutPlannerSettings& settings)
: managed_values_(graph.values().size()), settings_(settings) {
auto value_to_allocation_spec = c10::FastMap<const Value*, AllocationSpec>{};
: managed_values_(graph.values().size()),
#ifndef NDEBUG
alias_analyzer_(graph, kernelSchemas),
#endif
settings_(settings) {
#ifndef NDEBUG
auto& alias_analyzer = alias_analyzer_;
#else
auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas);
#endif
auto value_to_allocation_spec = c10::FastMap<const Value*, AllocationSpec>{};
std::set<const Value*> input_values_set_;
for (const auto* nv : graph.userInputs()) {

View File

@ -8,6 +8,7 @@
#include <c10/util/FbcodeMaps.h>
#include <c10/util/LeftRight.h>
#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>
#include <torch/nativert/executor/memory/LayoutPlannerAlgorithm.h>
#include <torch/nativert/executor/memory/LayoutPlannerSettings.h>
@ -61,7 +62,17 @@ class LayoutPlanner {
const std::vector<ValueId>& get_planned_values() const;
const std::vector<ValueId>& get_unplanned_values() const;
C10_ALWAYS_INLINE bool is_managed(ValueId id) {
#ifndef NDEBUG
const AliasAnalyzer& get_alias_analyzer() const {
return alias_analyzer_;
}
#endif
size_t num_values() const {
return managed_values_.size();
}
bool is_managed(ValueId id) {
TORCH_CHECK_LT(static_cast<size_t>(id), managed_values_.size());
return managed_values_[id];
}
@ -120,6 +131,9 @@ class LayoutPlanner {
LayoutPlannerAlgorithm* algorithm_;
c10::LeftRight<LayoutPlan> plan_;
#ifndef NDEBUG
AliasAnalyzer alias_analyzer_;
#endif
torch::nativert::LayoutPlannerSettings settings_;
};
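
The alias analyzer is needed by LayoutPlanner only while building allocation specs, so release builds construct it as a temporary while debug builds keep it as a member for the later overlap assertion. A sketch of that NDEBUG-gated member pattern with stand-in types (FakeAnalyzer and FakePlanner are illustrative, not the real classes):

#include <iostream>
#include <string>
#include <utility>

// Stand-in for AliasAnalyzer: built from the graph, queried later in debug.
struct FakeAnalyzer {
  explicit FakeAnalyzer(std::string graph) : graph(std::move(graph)) {}
  std::string graph;
};

class FakePlanner {
 public:
  explicit FakePlanner(const std::string& graph)
#ifndef NDEBUG
      : analyzer_(graph)
#endif
  {
#ifndef NDEBUG
    auto& analyzer = analyzer_;  // debug: reuse the retained member
#else
    auto analyzer = FakeAnalyzer(graph);  // release: temporary, dropped after planning
#endif
    std::cout << "planning against: " << analyzer.graph << "\n";
  }

 private:
#ifndef NDEBUG
  FakeAnalyzer analyzer_;  // kept alive so debug checks can query it later
#endif
};

int main() {
  FakePlanner planner("graph(%x): return (%x)");
}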

View File

@ -11,15 +11,14 @@ UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
op_(getOperatorForTarget(
std::get<std::string>(node->attributes()[0].value))),
schema_(op_.schema()),
arguments_(prefillStackWithStaticArgs(node, schema_)) {
arguments_(prefillStackWithStaticArgs(node, schema_)),
numOutputs_(static_cast<int>(schema_.returns().size())) {
for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
if (schemaArg.alias_info() != nullptr &&
schemaArg.alias_info()->isWrite()) {
mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
}
}
numOutputs_ = schema_.returns().size();
}
void UnsafeAutoFunctionalizeKernel::computeInternal(

View File

@ -62,7 +62,7 @@ c10::Device inferTargetDevice(
} // namespace
inline constexpr std::string_view kSymIntOps[] = {
inline constexpr std::array<std::string_view, 7> kSymIntOps = {
"_operator.floordiv",
"_operator.mod",
"torch.sym_int",
@ -72,7 +72,7 @@ inline constexpr std::string_view kSymIntOps[] = {
"torch.sym_min",
};
inline constexpr std::string_view kSymBoolOps[] = {
inline constexpr std::array<std::string_view, 8> kSymBoolOps = {
"_operator.eq",
"_operator.ne",
"_operator.le",
@ -83,14 +83,14 @@ inline constexpr std::string_view kSymBoolOps[] = {
"torch.sym_not",
};
inline constexpr std::string_view kSymFloatOps[] = {
inline constexpr std::array<std::string_view, 4> kSymFloatOps = {
"torch._sym_sqrt",
"math.trunc",
"_operator.neg",
"_operator.truediv",
};
inline constexpr std::string_view kScalarBinaryOps[] = {
inline constexpr std::array<std::string_view, 4> kScalarBinaryOps = {
"_operator.mul",
"_operator.add",
"_operator.sub",
@ -124,10 +124,11 @@ void KernelFactory::registerHandler(
ExecutionKernels KernelFactory::initializeNodeKernels(
const Graph& graph,
std::shared_ptr<Weights> weights,
const std::shared_ptr<Weights>& weights,
const torch::nativert::ExecutorConfig& executorConfig,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader,
const MakeProxyExecutorFn& makeProxyExecutorFunc) {
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
@ -216,7 +217,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
*subgraph, weights, executorConfig, placement);
CHECK(executionKernels.delegateExecutors.empty())
<< "HigherOrderKernel does not support delegates";
CHECK(executionKernels.constFoldingExecutions.size() == 0)
CHECK(executionKernels.constFoldingExecutions.empty())
<< "HigherOrderKernel does not support const folding";
if (executorConfig.maxParallelOps > 1) {
graphExecutors.emplace_back(

View File

@ -74,10 +74,10 @@ class KernelFactory {
ExecutionKernels initializeNodeKernels(
const Graph& graph,
std::shared_ptr<Weights> weights,
const std::shared_ptr<Weights>& weights,
const torch::nativert::ExecutorConfig& executorConfig,
const Placement& placement,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader = nullptr,
const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);