mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[nativert] add memory overlap debug assertion (#157290)
Summary: better safe than sorry. This will throw if memory overlap is detected when using planned tensors while debug mode is enabled -- this will make our planning unit tests more robust.

Test Plan: ci

Rollback Plan:

Differential Revision: D77327841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157290

Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
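The invariant behind the new assertion can be shown in isolation: collect the inclusive [start, end] byte interval of every planned tensor in a std::multiset (whose lexicographic pair ordering sorts intervals by start, then by width) and require that each interval begins strictly after the previous one ends. A minimal standalone sketch of that check follows; it is illustrative only, not the actual nativert API.

#include <cassert>
#include <cstddef>
#include <set>
#include <utility>

// Sketch: report whether any two inclusive byte intervals overlap once they
// are sorted by (start, end) via std::multiset's pair ordering.
bool has_overlap(const std::multiset<std::pair<size_t, size_t>>& intervals) {
  if (intervals.size() < 2) {
    return false; // fewer than two intervals cannot overlap
  }
  auto it = intervals.begin();
  size_t prev_end = it->second;
  while (++it != intervals.end()) {
    if (it->first <= prev_end) {
      return true; // current slice starts inside the previous slice
    }
    prev_end = it->second;
  }
  return false;
}

int main() {
  // [0, 7] and [8, 15] touch but do not overlap; [4, 11] overlaps both.
  assert(!has_overlap({{0, 7}, {8, 15}}));
  assert(has_overlap({{0, 7}, {4, 11}, {8, 15}}));
}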
test/cpp/nativert/CMakeLists.txt
@@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS
     ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
     ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
     ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
+    ${TORCH_ROOT}/torch/nativert/executor/Executor.cpp
+    ${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp
+    ${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp
+    ${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp
+    ${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp
+    ${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp
+    ${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
+    ${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
+    ${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
 )

 add_executable(test_nativert
test/cpp/nativert/test_alias_analyzer.cpp (new file, 182 lines)
@@ -0,0 +1,182 @@
#include <gtest/gtest.h>

#include <fmt/format.h>

#include <numeric> // for std::accumulate

#include <torch/nativert/executor/memory/AliasAnalyzer.h>
#include <torch/nativert/graph/Graph.h>

#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/kernels/KernelFactory.h>

using namespace ::testing;
using namespace torch::nativert;

using AliasTestCase = std::tuple<
    std::string /* value */,
    AllocationLifetime,
    bool /* is_alias */,
    bool /* is_storage_associated_with_output */,
    c10::FastSet<std::string> /* source(s) */>;

class AliasAnalyzerTests : public testing::Test {
  void SetUp() override {}

  void TearDown() override {
    test_cases.clear();
    model.clear();
  }

 public:
  void setTestCases(std::vector<AliasTestCase> cases) {
    test_cases = std::move(cases);
  }

  void setModel(std::string m) {
    model = std::move(m);
  }

  void run() {
    EXPECT_FALSE(test_cases.empty());
    EXPECT_FALSE(model.empty());

    ExecutorConfig cfg;
    cfg.enableStaticCPUKernels = true;

    auto graph = stringToGraph(model);
    auto kernels = KernelFactory().initializeNodeKernels(
        *graph, nullptr, cfg, {}, nullptr);
    auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);

    AliasAnalyzer analyzer(*graph, kernelSchemas);

    for (auto& [value, lifetime, is_alias, is_storage_associated_with_output, srcs] :
         test_cases) {
      LOG(INFO) << fmt::format(
          "running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}",
          value,
          lifetime.start,
          lifetime.end,
          is_alias,
          is_storage_associated_with_output,
          srcs.empty() ? "{}"
                       : std::accumulate(
                             srcs.begin(),
                             srcs.end(),
                             std::string{},
                             [](std::string cur, const std::string& src) {
                               cur.append(",");
                               cur.append(src);
                               return cur;
                             }));
      auto* v = graph->getValue(value);
      EXPECT_EQ(analyzer.lifetime(v), lifetime);
      EXPECT_EQ(analyzer.is_alias(v), is_alias);
      EXPECT_EQ(
          analyzer.is_storage_associated_with_output(v),
          is_storage_associated_with_output);
      const auto* resolved_srcs = analyzer.get_sources_of_alias(v);
      if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) {
        EXPECT_FALSE(srcs.empty());
        EXPECT_EQ(resolved_srcs->size(), srcs.size());
        for (const auto& resolved_src : *resolved_srcs) {
          EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1);
        }
        EXPECT_TRUE(srcs.empty());
      } else {
        EXPECT_TRUE(srcs.empty());
      }
    }
  }

 private:
  std::string model;
  std::vector<AliasTestCase> test_cases;
};

TEST_F(AliasAnalyzerTests, TestNoAlias) {
  setModel(R"(
    graph(%y0, %y1):
      %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
      %res = torch.ops.aten.clone.default(self=%out_t, memory_format=None)
      return (%res))");

  setTestCases({
      {"out_t", AllocationLifetime(1, 2), false, false, {}},
      {"res", AllocationLifetime(2, 3), false, true, {}},
  });

  run();
}

TEST_F(AliasAnalyzerTests, TestSimpleAlias) {
  setModel(R"(
    graph(%y0, %y1):
      %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
      %res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
      return (%res))");

  setTestCases({
      {"out_t", AllocationLifetime(1, 3), false, true, {}},
      {"res", AllocationLifetime(2, 3), true, false, {"out_t"}},
  });

  run();
}

TEST_F(AliasAnalyzerTests, TestDeepAlias) {
  setModel(R"(
    graph(%y0, %y1):
      %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
      %a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1)
      %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
      return (%res))");

  setTestCases({
      {"out_t", AllocationLifetime(1, 4), false, true, {}},
      {"a1", AllocationLifetime(2, 4), true, false, {"out_t"}},
      {"res", AllocationLifetime(3, 4), true, false, {"out_t"}},
  });

  run();
}

TEST_F(AliasAnalyzerTests, TestPackedListUnpack) {
  setModel(R"(
    graph(%a, %b, %c, %d):
      %input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d)
      %x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list)
      return (%x1, %x3))");

  setTestCases({
      {"a", AllocationLifetime(0, 2), false, false, {}},
      {"x0", AllocationLifetime(2, 2), true, false, {"a"}},
      {"b", AllocationLifetime(0, 3), false, true, {}},
      {"x1", AllocationLifetime(2, 3), true, false, {"b"}},
      {"c", AllocationLifetime(0, 2), false, false, {}},
      {"x2", AllocationLifetime(2, 2), true, false, {"c"}},
      {"d", AllocationLifetime(0, 3), false, true, {}},
      {"x3", AllocationLifetime(2, 3), true, false, {"d"}},
  });

  run();
}

TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) {
  setModel(R"(
    graph(%y0, %y1):
      %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1)
      %out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1)
      %a1 = prim.VarStack(l0=%out_t, l1=%out_t2)
      %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1)
      return (%res))");

  setTestCases({
      {"out_t", AllocationLifetime(1, 5), false, true, {}},
      {"out_t2", AllocationLifetime(2, 5), false, true, {}},
      {"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}},
      {"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}},
  });

  run();
}
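Reading the test tables above: each AliasTestCase row is (value, lifetime, is_alias, is_storage_associated_with_output, sources), where AllocationLifetime(start, end) spans the producing node index through the last node that needs the value (inputs start at node 0, and returned values live until the output node). A toy mirror of that bookkeeping, assuming inclusive bounds as used above:

#include <cassert>
#include <cstddef>

// Toy stand-in for AllocationLifetime: start is the node that produces the
// value, end is the last node that uses it (inclusive on both sides).
struct Lifetime {
  size_t start;
  size_t end;
  bool contains(size_t t) const { return t >= start && t <= end; }
};

int main() {
  // From TestSimpleAlias: %out_t is produced by node 1 (the matmul) and,
  // because %res aliases it and is returned, stays alive through node 3
  // (the output node).
  Lifetime out_t{1, 3};
  assert(out_t.contains(2)); // alive while the slice (node 2) executes
  assert(!out_t.contains(0));
}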
torch/nativert/executor/ExecutionFrame.h
@@ -46,13 +46,14 @@ class ExecutionFrame {
   }

   template <typename CB>
-  auto withMemoryPlanner(CB&& cb) {
+  auto withManagedMemory(CB&& cb) {
     if (!layoutManager_) {
-      return std::forward<CB>(cb)();
+      return std::forward<CB>(cb)(nullptr);
     }

     LayoutManagerGuard guard(*layoutManager_);
-    return std::forward<CB>(cb)();
+    return std::forward<CB>(cb)(
+        const_cast<const LayoutManager*>(layoutManager_.get()));
   }

   std::vector<c10::IValue> tryMoveUserOutputs();
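The signature change above means every withManagedMemory caller now receives the LayoutManager pointer, which is nullptr when planning is disabled. A simplified sketch of that contract, using toy types rather than the real ExecutionFrame:

#include <iostream>
#include <utility>

struct LayoutManager {}; // toy stand-in

// Simplified shape of the new contract: the callback always takes a
// const LayoutManager*, which is nullptr when no planning is active.
template <typename CB>
auto withManagedMemory(const LayoutManager* lm, CB&& cb) {
  return std::forward<CB>(cb)(lm);
}

int main() {
  LayoutManager manager;
  auto body = [](const LayoutManager* lm) {
    std::cout << (lm != nullptr ? "planned\n" : "unplanned\n");
  };
  withManagedMemory(&manager, body); // prints "planned"
  withManagedMemory(nullptr, body);  // prints "unplanned"
}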
torch/nativert/executor/Executor.cpp
@@ -19,30 +19,31 @@ namespace torch::nativert {

 Executor::Executor(
     torch::nativert::ExecutorConfig executorConfig,
     std::shared_ptr<Graph> graph,
-    std::shared_ptr<Weights> weights,
-    const Placement& placement,
-    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
-    const MakeProxyExecutorFn& makeProxyExecutorFunc)
+    const std::shared_ptr<Weights>& weights,
+    Placement placement,
+    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
+        pytorchStreamReader,
+    MakeProxyExecutorFn makeProxyExecutorFunc)
     : executorConfig_(std::move(executorConfig)),
       graph_(std::move(graph)),
-      placement_(placement),
+      placement_(std::move(placement)),
       constantFolder_(
           executorConfig_.runConstFolding
               ? std::optional<ConstantFolder>(*graph_)
               : std::nullopt),
-      makeProxyExecutorFunc_(makeProxyExecutorFunc),
+      makeProxyExecutorFunc_(std::move(makeProxyExecutorFunc)),
       executionFrames_(executorConfig_.maxNumConcurrentThreads),
       clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads),
       numExecutionFrames_(0),
       lastClearedTimestamp_(getCurrentTimestampSeconds()) {
   if (weights) {
-    initialize(std::move(weights), std::move(pytorchStreamReader));
+    initialize(weights, pytorchStreamReader);
   }
 }

 void Executor::initialize(
-    std::shared_ptr<Weights> weights,
-    std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+    const std::shared_ptr<Weights>& weights,
+    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
         pytorchStreamReader) {
   auto start = std::chrono::high_resolution_clock::now();

@@ -51,7 +52,7 @@ void Executor::initialize(
       weights,
       executorConfig_,
       placement_,
-      std::move(pytorchStreamReader),
+      pytorchStreamReader,
       makeProxyExecutorFunc_);

   if (constantFolder_.has_value()) {

@@ -113,13 +114,14 @@ void Executor::atomicSwapWeights(std::shared_ptr<Weights> weights) {
   }
 }

-void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
+void Executor::maybeRunConstantFolding(
+    const std::shared_ptr<Weights>& weights) {
   for (auto& execution : constFoldingExecutions_) {
     ExecutionFrame constFoldingFrame(execution.executor->graph());
     std::vector<c10::IValue> inputs;
     inputs.reserve(graph_->signature().inputsToWeights().size());
     for (const auto& [_, name] : graph_->signature().inputsToWeights()) {
-      inputs.push_back(weights->at(name));
+      inputs.emplace_back(weights->at(name));
     }

     auto outputs = execution.executor->execute(constFoldingFrame, inputs);

@@ -130,7 +132,7 @@ void Executor::maybeRunConstantFolding(std::shared_ptr<Weights> weights) {
   }
 }

-void Executor::processWeights(std::shared_ptr<Weights> weights) {
+void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
   maybeRunConstantFolding(weights);
   if (constantFolder_.has_value()) {
     constantFolder_->evaluate(*weights);

@@ -352,10 +354,10 @@ std::vector<c10::IValue> Executor::execute(
 }

 ProfileMetrics Executor::benchmarkIndividualNodes(
-    std::vector<std::vector<c10::IValue>> inputsList,
+    const std::vector<std::vector<c10::IValue>>& inputsList,
     const uint32_t warmupRuns,
     const uint32_t mainRuns) {
-  CHECK(inputsList.size() > 0) << "Need at least one input to benchmark";
+  CHECK(!inputsList.empty()) << "Need at least one input to benchmark";
   CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run";

   for (const auto& inputs : inputsList) {

@@ -378,8 +380,9 @@ int64_t Executor::getCurrentTimestampSeconds() const {

 std::vector<DelegateExecutor*> Executor::getDelegates() {
   std::vector<DelegateExecutor*> delegates;
+  delegates.reserve(delegateExecutors_.size());
   for (const auto& delegateExecutor : delegateExecutors_) {
-    delegates.push_back(delegateExecutor.get());
+    delegates.emplace_back(delegateExecutor.get());
   }
   return delegates;
 }
torch/nativert/executor/Executor.h
@@ -79,11 +79,11 @@ class Executor {
   Executor(
       torch::nativert::ExecutorConfig executorConfig,
       std::shared_ptr<Graph> graph,
-      std::shared_ptr<Weights> weights,
-      const Placement& placement = Placement(),
-      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+      const std::shared_ptr<Weights>& weights,
+      Placement placement = Placement(),
+      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr,
-      const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);
+      MakeProxyExecutorFn makeProxyExecutorFunc = nullptr);

   std::shared_ptr<Weights> getWeights() {
     std::shared_ptr<Weights> ret;

@@ -91,7 +91,7 @@ class Executor {
     return ret;
   }

-  void processWeights(std::shared_ptr<Weights> weights);
+  void processWeights(const std::shared_ptr<Weights>& weights);
   void atomicSwapWeights(std::shared_ptr<Weights> weights);

   // This API only returns the flattened UserOutputs,

@@ -106,7 +106,7 @@ class Executor {
       const ITreeSpec& inputTreeSpec);

   ProfileMetrics benchmarkIndividualNodes(
-      std::vector<std::vector<c10::IValue>> inputsList,
+      const std::vector<std::vector<c10::IValue>>& inputsList,
       const uint32_t warmupRuns,
       const uint32_t mainRuns);

@@ -141,8 +141,8 @@ class Executor {
   c10::Synchronized<std::shared_ptr<Weights>> weights_;

   void initialize(
-      std::shared_ptr<Weights> weights,
-      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+      const std::shared_ptr<Weights>& weights,
+      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader);

   ExecutorFramePtr getExecutorFrameFromPool();

@@ -171,7 +171,7 @@ class Executor {
     ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
   };

-  void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
+  void maybeRunConstantFolding(const std::shared_ptr<Weights>& weights);
   void validateInputs(const std::vector<c10::IValue>& inputs) const;

   // Helper method to get current timestamp in seconds
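Two parameter-passing idioms recur throughout the signature changes above: shared_ptr arguments that are only read become const references (skipping an atomic refcount bump), while arguments that are stored become by-value sinks that get moved into the member. A small sketch of both, with illustrative types rather than the nativert classes:

#include <memory>
#include <string>
#include <utility>

struct Weights {}; // toy stand-in

struct Holder {
  // Sink parameter: taken by value and moved into the member, so an rvalue
  // argument is moved rather than deep-copied.
  explicit Holder(std::string name) : name_(std::move(name)) {}

  // Read-only use: a const reference avoids the atomic refcount increment
  // that copying a shared_ptr would incur.
  static bool has(const std::shared_ptr<Weights>& weights) {
    return weights != nullptr;
  }

  std::string name_;
};

int main() {
  auto w = std::make_shared<Weights>();
  Holder h("executor"); // the temporary string is moved, not copied
  return Holder::has(w) ? 0 : 1;
}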
torch/nativert/executor/GraphExecutorBase.cpp
@@ -32,7 +32,7 @@ void GraphExecutorBase::fillUserInputs(

 ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
     ExecutionFrame& executionFrame,
-    std::vector<std::vector<c10::IValue>> inputsList,
+    const std::vector<std::vector<c10::IValue>>& inputsList,
     const uint32_t warmupRuns,
     const uint32_t mainRuns) {
   // TODO: add support for memory profiling

@@ -112,7 +112,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
   results.totalNodesCount = numNodes;
   for (const auto& r : results.timePerNodeType) {
     const std::string& target = r.first;
-    results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
+    results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime;
   }
   return results;
 }
torch/nativert/executor/GraphExecutorBase.h
@@ -51,7 +51,7 @@ class GraphExecutorBase {

   ProfileMetrics benchmarkIndividualNodes(
       ExecutionFrame& executionFrame,
-      std::vector<std::vector<c10::IValue>> inputs,
+      const std::vector<std::vector<c10::IValue>>& inputs,
       const uint32_t warmup_runs,
       const uint32_t main_runs);
torch/nativert/executor/ParallelGraphExecutor.cpp
@@ -22,11 +22,13 @@ ThreadPoolExecutor::~ThreadPoolExecutor() {
 }

 C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() {
+  // NOLINTNEXTLINE(misc-use-internal-linkage)
   thread_local moodycamel::ProducerToken ptok(*work_);
   return ptok;
 }

 C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() {
+  // NOLINTNEXTLINE(misc-use-internal-linkage)
   thread_local moodycamel::ConsumerToken ctok(*work_);
   return ctok;
 }

@@ -39,7 +41,7 @@ void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) {
 void ThreadPoolExecutor::start(int32_t numThreads) {
   stopped_ = false;
   for (int32_t i = 0; i < numThreads; ++i) {
-    threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this));
+    threads_.emplace_back(&ThreadPoolExecutor::loop, this);
   }
 }

@@ -62,16 +64,17 @@ void ThreadPoolExecutor::loop() {

 void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) {
   session->addWork();
-  work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session));
+  work_->enqueue(ptok(), [unit, this, session] { unit->run(this, session); });
   sem_->release();
 }

 void ThreadPoolExecutor::add(
     SessionState* session,
-    std::vector<WorkUnit*>::const_iterator&& begin,
-    const std::vector<WorkUnit*>::const_iterator&& end) {
+    std::vector<WorkUnit*>::const_iterator begin,
+    const std::vector<WorkUnit*>::const_iterator& end) {
   const auto count = end - begin;

+  // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
   switch (count) {
     case 0: {
       return;

@@ -86,16 +89,17 @@ void ThreadPoolExecutor::add(
   std::vector<Work> runnables;
   runnables.reserve(count);
   for (; begin != end; ++begin) {
-    runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session));
+    runnables.emplace_back(
+        [capture0 = *begin, this, session] { capture0->run(this, session); });
   }

   work_->enqueue_bulk(ptok(), runnables.begin(), count);
-  sem_->release(count);
+  sem_->release(static_cast<int32_t>(count));
 }

 void ThreadPoolExecutor::stop() {
   stopped_ = true;
-  sem_->release(threads_.size());
+  sem_->release(static_cast<int32_t>(threads_.size()));

   std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); });
   threads_.clear();
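The enqueue changes above swap std::bind for lambdas; the init-capture (capture0 = *begin) pins down the iterator's current value at capture time. A minimal comparison of the two forms:

#include <functional>
#include <iostream>

struct WorkUnit {
  void run(int session) { std::cout << "run " << session << '\n'; }
};

int main() {
  WorkUnit unit;
  int session = 7;

  // Before: std::bind packages the member-function pointer plus arguments.
  std::function<void()> bound = std::bind(&WorkUnit::run, &unit, session);

  // After: an init-capture lambda states the same thing directly and is
  // generally easier for compilers to inline.
  std::function<void()> lambda = [capture0 = &unit, session] {
    capture0->run(session);
  };

  bound();  // run 7
  lambda(); // run 7
}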
@@ -136,10 +140,10 @@ void ThreadPoolExecutor::run(
 }

 void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) {
-  thread_local std::vector<WorkUnit*> newWorkUnits;
-  thread_local c10::InferenceMode mode;
+  /* thread_local */ std::vector<WorkUnit*> newWorkUnits;
+  /* thread_local */ c10::InferenceMode mode;

-  WorkUnit* unit = this;
+  /* thread_local */ WorkUnit* unit = this;

   while (true) {
     unit->kernel->compute(session->frame());

@@ -219,7 +223,7 @@ ParallelGraphExecutor::ParallelGraphExecutor(
     }
   }

-  executor_.start(executorConfig.maxParallelOps);
+  executor_.start(static_cast<int32_t>(executorConfig.maxParallelOps));
 }

 std::vector<c10::IValue> ParallelGraphExecutor::execute(
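The /* thread_local */ markers above record that these locals used to be thread_local; the semantic difference is whether state survives across calls on the same thread, as this small standalone sketch shows:

#include <cassert>
#include <vector>

// A thread_local function-scope variable persists across calls on the same
// thread, while a plain local is fresh per call. The diff switches the
// scratch state in WorkUnit::run to the per-call flavor.
int persistent_count() {
  thread_local std::vector<int> scratch;
  scratch.push_back(0);
  return static_cast<int>(scratch.size());
}

int fresh_count() {
  std::vector<int> scratch; // re-created on every call
  scratch.push_back(0);
  return static_cast<int>(scratch.size());
}

int main() {
  persistent_count();
  assert(persistent_count() == 2); // grew across calls
  assert(fresh_count() == 1);
  assert(fresh_count() == 1); // always starts empty
}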
torch/nativert/executor/ParallelGraphExecutor.h
@@ -46,8 +46,8 @@ class ThreadPoolExecutor {
   void add(SessionState* session, WorkUnit* unit);
   void add(
       SessionState* session,
-      std::vector<WorkUnit*>::const_iterator&& begin,
-      const std::vector<WorkUnit*>::const_iterator&& end);
+      std::vector<WorkUnit*>::const_iterator begin,
+      const std::vector<WorkUnit*>::const_iterator& end);

   C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok();
   C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok();
torch/nativert/executor/SerialGraphExecutor.cpp
@@ -14,11 +14,17 @@ std::vector<c10::IValue> SerialGraphExecutor::execute(

 std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
     ExecutionFrame& executionFrame) {
-  executionFrame.withMemoryPlanner([&]() {
+  executionFrame.withManagedMemory([&](const LayoutManager* layout_manager) {
     // Execute kernels for all nodes except prim.Input and prim.Output
     for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
       nodeKernels_[nodeIdx]->compute(executionFrame);

+#ifndef NDEBUG
+      if (layout_manager != nullptr) {
+        layout_manager->assert_no_overlapping_storages(nodeIdx);
+      }
+#endif
+
       // don't free intermediate values when static memory planning is enabled
       if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
         // Free the intermediate values that are not used anymore
torch/nativert/executor/memory/AliasAnalyzer.cpp
@@ -23,18 +23,32 @@ AliasAnalyzer::AliasAnalyzer(
     maybe_update_aliases_from_schema(node, schemas);
   }

-  maybe_extend_lifetimes(graph);
-
+  // squash_deep_aliases will populate aliases_
+  // with a mapping from each alias to its backed
+  // source (i.e., the value that owns the underlying
+  // dataptr for said alias)
+  squash_deep_aliases(graph);
+
+  // set all non-aliasing outputs. outputs
+  // that are aliased will be set later when
+  // lifetimes are extended
   for (const auto* output : graph.outputs()) {
     if (!is_alias(output)) {
-      values_associated_with_outputs_.insert(output);
+      values_associated_with_outputs_.emplace(output);
     }
   }

+  maybe_extend_lifetimes(graph);
   log_state();
-}

+  alive_values_at_time_.resize(graph.nodes().size());
+  for (const auto& [v, lifetime] : lifetimes_) {
+    for (const auto t : c10::irange(lifetime.start, lifetime.end + 1)) {
+      alive_values_at_time_[t].emplace_back(v);
+    }
+  }
+}

 bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack(
     const Node& node,

@@ -63,7 +77,7 @@ bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack(
     create_or_update_lifetime(input, i);
     create_or_update_lifetime(output, i);

-    aliases_[output].insert(input);
+    aliases_[output].emplace(input);
   }

   return true;

@@ -96,7 +110,7 @@ void AliasAnalyzer::maybe_update_aliases_from_schema(
       VLOG(1) << node.target()
               << " may contain input/output alias: " << input->id() << " -> "
               << output->id();
-      aliases_[output].insert(input);
+      aliases_[output].emplace(input);
     }
   }
 }
@@ -109,6 +123,56 @@ void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) {
   }
 }

+void AliasAnalyzer::squash_deep_aliases(const Graph& graph) {
+  for (auto& node : graph.nodes()) {
+    for (const auto& output : node.outputs()) {
+      auto aliasIt = aliases_.find(output);
+      if (aliasIt == aliases_.end()) {
+        continue;
+      }
+
+      c10::FastSet<const Value*> filtered_srcs;
+
+      auto& srcs = aliasIt->second;
+      for (const auto* src : srcs) {
+        // check if this source is an alias itself,
+        // making 'output' a deep alias (i.e.,
+        // an alias of an alias)
+
+        // we want aliases_[x] to return the value from which x
+        // inherits its dataptr.
+        // as such, we only want to add sources that meet this
+        // criterion (i.e., those that are not themselves aliases).
+        // in practice, there can only be 1 value that meets this
+        // criterion (at a time), but there are some cases where
+        // this is ambiguous (e.g., where the spec doesn't exist,
+        // dealing with variadics)
+        auto srcAliasIt = aliases_.find(src);
+        if (srcAliasIt == aliases_.end()) {
+          filtered_srcs.emplace(src);
+          continue;
+        }
+
+        // since we are going from the beginning of the graph
+        // to the end of the graph we can assume that these
+        // aliases, which have already been visited, have already
+        // been squashed.
+        auto& srcs_of_src = srcAliasIt->second;
+        for (const auto* src_of_src : srcs_of_src) {
+          // if the source of the source is not an alias
+          // (i.e., it has ownership over its data ptr)
+          // then we want to add it as a source of 'output'
+          if (aliases_.find(src_of_src) == aliases_.end()) {
+            filtered_srcs.emplace(src_of_src);
+          }
+        }
+      }
+
+      srcs = std::move(filtered_srcs);
+    }
+  }
+}
+
 void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) {
   c10::FastSet<const Value*> extended;
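A toy model of the squashing pass, using plain standard containers in place of the graph types: a deep alias (%x2 -> %x1 -> %x0) ends up mapped to the value that actually owns the storage (%x0). Names here are illustrative only.

#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>

int main() {
  // aliases maps each alias to its immediate source(s), as after the
  // schema-driven pass but before squashing.
  std::map<std::string, std::set<std::string>> aliases = {
      {"x1", {"x0"}}, // x1 aliases x0 (x0 owns the data ptr)
      {"x2", {"x1"}}, // x2 aliases x1, i.e., a deep alias of x0
  };

  // Walk values in graph order; earlier aliases are already squashed.
  for (auto& [alias, srcs] : aliases) {
    std::set<std::string> filtered;
    for (const auto& src : srcs) {
      auto it = aliases.find(src);
      if (it == aliases.end()) {
        filtered.insert(src); // src owns its storage: keep it
      } else {
        for (const auto& deep : it->second) {
          if (aliases.find(deep) == aliases.end()) {
            filtered.insert(deep); // hop over the intermediate alias
          }
        }
      }
    }
    srcs = std::move(filtered);
  }

  assert(aliases["x2"] == std::set<std::string>{"x0"});
}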
@@ -129,10 +193,11 @@ void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) {

     VLOG(1) << "extended EOL of value " << src->id() << " to " << eol;

-    extended.insert(src);
+    extended.emplace(src);

-    if (eol == graph.nodes().size() - 1 /* aliases output */) {
-      values_associated_with_outputs_.insert(src);
+    if (aliases_.find(src) == aliases_.end() &&
+        eol == graph.nodes().size() - 1 /* aliases output */) {
+      values_associated_with_outputs_.emplace(src);
     }
   }
 }
torch/nativert/executor/memory/AliasAnalyzer.h
@@ -14,26 +14,38 @@ class AliasAnalyzer {
       const Graph& graph,
       const c10::FastMap<std::string /* target */, FunctionSchema>& schemas);

-  C10_ALWAYS_INLINE const AllocationLifetime& lifetime(
+  const c10::FastSet<const Value*>* get_sources_of_alias(
       const Value* value) const {
+    const auto it = aliases_.find(value);
+    if (it == aliases_.end()) {
+      return nullptr;
+    }
+    return &it->second;
+  }
+
+  const AllocationLifetime& lifetime(const Value* value) const {
     return lifetimes_.at(value);
   }

-  C10_ALWAYS_INLINE bool is_alias(const Value* value) const {
+  bool is_alias(const Value* value) const {
     return aliases_.find(value) != aliases_.end();
   }

-  C10_ALWAYS_INLINE bool is_storage_associated_with_output(
-      const Value* value) const {
+  bool is_storage_associated_with_output(const Value* value) const {
     return values_associated_with_outputs_.find(value) !=
         values_associated_with_outputs_.end();
   }

-  C10_ALWAYS_INLINE const c10::FastSet<const Value*>&
-  values_associated_with_output_storage() const {
+  const c10::FastSet<const Value*>& values_associated_with_output_storage()
+      const {
     return values_associated_with_outputs_;
   }

+  const std::vector<const Value*>& alive_values_at_time(size_t time) const {
+    TORCH_CHECK_LT(time, alive_values_at_time_.size());
+    return alive_values_at_time_[time];
+  }
+
  private:
   // listunpack operations who take a list that has
   // been created with a listpack operation should

@@ -72,14 +84,35 @@ class AliasAnalyzer {
   // even if they aren't explicitly considered outputs)
   void maybe_extend_lifetimes(const Graph& graph);

+  // in the event that we have aliases-of-aliases
+  // we want to make sure that the 'sources'
+  // are propagated
+  //
+  // e.g.,
+  // %x0 = ...
+  // %x1 = some_aliasing_op(x0)
+  // %x2 = some_aliasing_op(x1)
+  //
+  // we want aliases_[x2] = x0
+  // instead of aliases_[x2] = x1
+  //
+  // the result is that aliases_ will contain a
+  // mapping from each alias to its backed
+  // source (i.e., the value that owns its
+  // associated dataptr)
+  void squash_deep_aliases(const Graph& graph);
+
   void log_state() const;

-  // mapping from alias to the set of values that it aliases
+  // mapping from alias to its source
   c10::FastMap<const Value*, c10::FastSet<const Value*>> aliases_;
   c10::FastMap<const Value*, AllocationLifetime> lifetimes_;
   // non-aliasing outputs or non-aliasing intermediates that are aliased by
   // outputs
   c10::FastSet<const Value*> values_associated_with_outputs_;
+  // alive_values_at_time_[i] = values that are "alive" during the
+  // computation of node i
+  std::vector<std::vector<const Value*>> alive_values_at_time_;
 };

 } // namespace torch::nativert
torch/nativert/executor/memory/LayoutManager.cpp
@@ -4,6 +4,7 @@

 #include <c10/core/CPUAllocator.h>
 #include <c10/util/Enumerate.h>
+#include <c10/util/FbcodeMaps.h>

 namespace torch::nativert {

@@ -147,6 +148,9 @@ void LayoutManager::populate_tensor_values() {
   planned_tensors_max_nbytes_local_.resize(value_ids.size());

   for (const auto&& [i, v] : c10::enumerate(value_ids)) {
+#ifndef NDEBUG
+    value_to_vector_idx_map_[v] = i;
+#endif
     planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor();
   }

@@ -157,6 +161,165 @@ void LayoutManager::populate_tensor_values() {
   }
 }

+#ifndef NDEBUG
+void LayoutManager::assert_no_overlapping_storages(
+    size_t graph_node_idx) const {
+  if (state_ != LayoutManagerState::Running) {
+    return;
+  }
+
+  /*
+    for each value (either an input or output), ensure that the associated
+    storage slice lies within the allocated slice if it is managed (or, if
+    it is an alias, we can use the slice allocated to its source)
+    ---
+    also ensure that the current index lies within the lifetime of this value
+  */
+
+  const auto& alias_analyzer = planner_.get_alias_analyzer();
+  // get the 'active' values during the execution of nodes[graph_node_idx]
+  const auto& alive_values =
+      alias_analyzer.alive_values_at_time(graph_node_idx);
+
+  // make sure active memory intervals are non-overlapping
+  // by sorting them by start, and ensuring
+  // cur.start > prev.end for each
+  //
+  // by default, the pairs are compared lexicographically.
+  // ref: https://cplusplus.com/reference/utility/pair/operators/
+  //
+  // in our case, this means that leftmost (on the number line) intervals will
+  // come first, and if the start point of two intervals is the same, they will
+  // be sorted by their relative widths (in increasing order)
+  //
+  // e.g., the ordering for the following usage intervals
+  //
+  // |######1######|
+  //          |######2######|
+  //     |######3#####|
+  //
+  // would be [1,3,2]
+
+  std::multiset<std::pair<size_t, size_t>> intervals;
+
+  planner_.with_plan([&](const LayoutPlan& plan) {
+    // prevent recomputation from occurring
+    c10::FastSet<ValueId> checked_values;
+
+    // check that some arbitrary storage (defined by the allocation start and
+    // the size in bytes) lies within the slice allocated for value_id during
+    // planning.
+    //
+    // if the checks pass, add the interval [alloc_start, alloc_start +
+    // alloc_nbytes) to the set of intervals
+    auto check_allocation_bounds =
+        [&](ValueId value_id, size_t alloc_start, size_t alloc_end) -> void {
+      if (!checked_values.emplace(value_id).second /* already checked */) {
+        return;
+      }
+      auto& alloc = plan.allocations[value_to_vector_idx_map_.at(value_id)];
+      TORCH_CHECK_GE(alloc_start, alloc.offset);
+      TORCH_CHECK_LT(alloc_end, alloc.offset + alloc.size);
+      intervals.emplace(alloc_start, alloc_end);
+    };
+
+    // get the inclusive storage interval for some value (i.e.,
+    // [buffer_storage_start_offset, buffer_storage_start_offset +
+    // storage_nbytes]) that represents the sub-slice of the runtime-managed
+    // buffer allocated to this tensor
+    auto try_get_interval =
+        [&](ValueId value_id) -> std::optional<std::pair<size_t, size_t>> {
+      const auto& iv = parent_frame_.getIValue(value_id);
+      if (!iv.isTensor()) {
+        return std::nullopt;
+      }
+
+      const auto& storage_impl = iv.toTensor().storage().unsafeGetStorageImpl();
+      const auto storage_nbytes = storage_impl->nbytes();
+
+      if (const auto start = layout_buffer_.get_offset_from_ptr(
+              storage_impl->data_ptr().get());
+          start.has_value()) {
+        return std::make_pair(*start, *start + storage_nbytes - 1);
+      }
+
+      return std::nullopt;
+    };
+
+    for (auto v : alive_values) {
+      // sanity check lifetimes to ensure this
+      // value ~should~ be alive at this point
+      const auto& lt = alias_analyzer.lifetime(v);
+      TORCH_CHECK_GE(graph_node_idx, lt.start);
+      TORCH_CHECK_LE(graph_node_idx, lt.end);
+
+      const auto interval = try_get_interval(v->id());
+      if (C10_UNLIKELY(!interval.has_value())) {
+        continue;
+      }
+
+      auto& [v_start, v_end] = *interval;
+
+      // it's possible that v is an alias, in which case
+      // we want to try to get the source (i.e., the value
+      // that actually owns the storage)
+      //
+      // NOTE: it's possible the source is ambiguous, hence
+      // why get_sources_of_alias returns a set (although it's usually a
+      // singleton set)
+      if (const auto* srcs_of_v = alias_analyzer.get_sources_of_alias(v);
+          srcs_of_v != nullptr /* v is an alias */) {
+        // 1. v's interval is a sub-interval of ~a~ source's interval and we
+        //    want to add the source's interval to the set of intervals
+        // 2. v possibly got re-alloc'd / is not actually aliasing anything
+        //    and we want to add v's interval to the set of intervals
+        bool found_viable_source = false;
+
+        for (const auto* src_of_v : *srcs_of_v) {
+          const auto src_interval = try_get_interval(src_of_v->id());
+          if (C10_UNLIKELY(!src_interval.has_value())) {
+            continue;
+          }
+
+          auto& [src_of_v_start, src_of_v_end] = *src_interval;
+
+          if (v_start >= src_of_v_start && v_end <= src_of_v_end) {
+            check_allocation_bounds(
+                src_of_v->id(), src_of_v_start, src_of_v_end);
+            found_viable_source = true;
+            break;
+          }
+        }
+
+        if (!found_viable_source) {
+          check_allocation_bounds(v->id(), v_start, v_end);
+        }
+      } else /* if v isn't an alias */ {
+        check_allocation_bounds(v->id(), v_start, v_end);
+      }
+    }
+  });
+
+  // if we have fewer than two active intervals,
+  // it isn't possible to have overlap...
+  if (intervals.size() < 2) {
+    return;
+  }
+
+  // ensure that no 'active' buffer intervals are overlapping
+  auto it = intervals.begin();
+  size_t prev_end = it->second;
+  while (++it != intervals.end()) {
+    TORCH_CHECK_LT(prev_end, it->first /* cur_start */);
+    prev_end = it->second;
+  }
+}
+#endif
+
 void LayoutManager::try_update_historical_max_nbytes() {
   for (const auto i : c10::irange(planned_tensors_.size())) {
     auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes());
torch/nativert/executor/memory/LayoutManager.h
@@ -24,6 +24,20 @@ struct ContiguousLayoutBuffer {
   ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) =
       delete;

+  std::optional<size_t> get_offset_from_ptr(void* offset_ptr) const {
+    void* raw_ptr = data_ptr_.get();
+    if (!raw_ptr || !offset_ptr) {
+      return std::nullopt;
+    }
+
+    auto offset = reinterpret_cast<uint8_t*>(offset_ptr) -
+        reinterpret_cast<uint8_t*>(raw_ptr);
+
+    return offset < 0 || static_cast<size_t>(offset) >= size_
+        ? std::nullopt
+        : std::optional(offset);
+  }
+
   void* get_ptr_with_offset(size_t offset) {
     void* raw_ptr = data_ptr_.get();
     TORCH_CHECK_NOTNULL(raw_ptr);
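get_offset_from_ptr reduces to pointer arithmetic against the buffer base plus a bounds check. A standalone sketch of the same computation (offset_in is a hypothetical helper with the same logic):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>

// A pointer interior to the buffer maps to its byte offset; anything
// outside (or null) maps to nullopt.
std::optional<size_t> offset_in(void* base, size_t size, void* p) {
  if (!base || !p) {
    return std::nullopt;
  }
  auto diff = reinterpret_cast<uint8_t*>(p) - reinterpret_cast<uint8_t*>(base);
  return diff < 0 || static_cast<size_t>(diff) >= size
      ? std::nullopt
      : std::optional<size_t>(diff);
}

int main() {
  uint8_t buffer[16];
  assert(offset_in(buffer, sizeof(buffer), buffer + 4) == 4);
  assert(!offset_in(buffer, sizeof(buffer), buffer + 16)); // one past the end
}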
@@ -148,10 +162,32 @@ class LayoutManager {
       torch::nativert::LayoutManagerSettings settings = {});
   ~LayoutManager() = default;

+  // this is a debugging function. it will slow things down SIGNIFICANTLY
+  // so please ensure this isn't called unless you really need it
+  //
+  // it checks a few things in between node executions...
+  //
+  // 1. ensures all 'alive' values are within the bounds of their lifetimes
+  //    - this is the definition of a sanity check since the live-sets are
+  //      built from the lifetimes lol. if this fails, something is very
+  //      very wrong
+  // 2. ensures that all planned values are within the bounds of their
+  //    allocated storage buffer slices
+  //    - if the value is an alias, ensure the alias is within the bounds
+  //      of the source value
+  // 3. ensures that all planned value data-ptrs are non-overlapping
+#ifndef NDEBUG
+  void assert_no_overlapping_storages(
+      size_t
+          graph_node_idx /* the graph node that is currently being computed */)
+      const;
+#endif
+
+ private:
   friend class LayoutManagerGuard;

   void allocate();
   void deallocate_and_plan();

  private:
 #ifdef LayoutPlannerTests_TEST_FRIENDS
   LayoutPlannerTests_TEST_FRIENDS;
 #endif

@@ -178,6 +214,9 @@ class LayoutManager {

   std::vector<const at::Tensor*> planned_tensors_;
   std::vector<size_t> planned_tensors_max_nbytes_local_;
+#ifndef NDEBUG
+  c10::FastMap<ValueId, size_t> value_to_vector_idx_map_;
+#endif

   ContiguousLayoutBuffer layout_buffer_;
   ContiguousStorageImplBuffer storage_impl_buffer_;
torch/nativert/executor/memory/LayoutPlanner.cpp
@@ -16,9 +16,18 @@ LayoutPlanner::LayoutPlanner(
     const c10::FastMap<std::string /* target */, FunctionSchema>& kernelSchemas,
     const std::vector<bool>& persistentValues,
     const torch::nativert::LayoutPlannerSettings& settings)
-    : managed_values_(graph.values().size()), settings_(settings) {
-  auto value_to_allocation_spec = c10::FastMap<const Value*, AllocationSpec>{};
+    : managed_values_(graph.values().size()),
+#ifndef NDEBUG
+      alias_analyzer_(graph, kernelSchemas),
+#endif
+      settings_(settings) {
+#ifndef NDEBUG
+  auto& alias_analyzer = alias_analyzer_;
+#else
+  auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas);
+#endif

+  auto value_to_allocation_spec = c10::FastMap<const Value*, AllocationSpec>{};

   std::set<const Value*> input_values_set_;
   for (const auto* nv : graph.userInputs()) {
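The constructor change above uses a pattern worth spelling out: in debug builds the alias analyzer persists as a member (so assert_no_overlapping_storages can consult it later), while release builds construct it as a throwaway local. A toy reduction of that #ifndef NDEBUG arrangement, with stand-in types:

#include <iostream>
#include <string>

struct Analyzer {
  std::string state = "analyzed";
};

struct Planner {
#ifndef NDEBUG
  Analyzer analyzer_; // kept alive for later debug assertions
#endif

  Planner() {
#ifndef NDEBUG
    Analyzer& analyzer = analyzer_;
#else
    Analyzer analyzer{}; // same analysis, discarded after construction
#endif
    std::cout << "planning with " << analyzer.state << " graph\n";
  }
};

int main() {
  Planner p;
}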
torch/nativert/executor/memory/LayoutPlanner.h
@@ -8,6 +8,7 @@
 #include <c10/util/FbcodeMaps.h>
 #include <c10/util/LeftRight.h>

+#include <torch/nativert/executor/memory/AliasAnalyzer.h>
 #include <torch/nativert/executor/memory/FunctionSchema.h>
 #include <torch/nativert/executor/memory/LayoutPlannerAlgorithm.h>
 #include <torch/nativert/executor/memory/LayoutPlannerSettings.h>

@@ -61,7 +62,17 @@ class LayoutPlanner {
   const std::vector<ValueId>& get_planned_values() const;
   const std::vector<ValueId>& get_unplanned_values() const;

-  C10_ALWAYS_INLINE bool is_managed(ValueId id) {
+#ifndef NDEBUG
+  const AliasAnalyzer& get_alias_analyzer() const {
+    return alias_analyzer_;
+  }
+#endif
+
+  size_t num_values() const {
+    return managed_values_.size();
+  }
+
+  bool is_managed(ValueId id) {
     TORCH_CHECK_LT(static_cast<size_t>(id), managed_values_.size());
     return managed_values_[id];
   }

@@ -120,6 +131,9 @@ class LayoutPlanner {
   LayoutPlannerAlgorithm* algorithm_;
   c10::LeftRight<LayoutPlan> plan_;

+#ifndef NDEBUG
+  AliasAnalyzer alias_analyzer_;
+#endif
   torch::nativert::LayoutPlannerSettings settings_;
 };
torch/nativert/kernels/AutoFunctionalizeKernel.cpp
@@ -11,15 +11,14 @@ UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
       op_(getOperatorForTarget(
           std::get<std::string>(node->attributes()[0].value))),
       schema_(op_.schema()),
-      arguments_(prefillStackWithStaticArgs(node, schema_)) {
+      arguments_(prefillStackWithStaticArgs(node, schema_)),
+      numOutputs_(static_cast<int>(schema_.returns().size())) {
   for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
     if (schemaArg.alias_info() != nullptr &&
         schemaArg.alias_info()->isWrite()) {
       mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
     }
   }
-
-  numOutputs_ = schema_.returns().size();
 }

 void UnsafeAutoFunctionalizeKernel::computeInternal(
torch/nativert/kernels/KernelFactory.cpp
@@ -62,7 +62,7 @@ c10::Device inferTargetDevice(

 } // namespace

-inline constexpr std::string_view kSymIntOps[] = {
+inline constexpr std::array<std::string_view, 7> kSymIntOps = {
     "_operator.floordiv",
     "_operator.mod",
     "torch.sym_int",

@@ -72,7 +72,7 @@ inline constexpr std::string_view kSymIntOps[] = {
     "torch.sym_min",
 };

-inline constexpr std::string_view kSymBoolOps[] = {
+inline constexpr std::array<std::string_view, 8> kSymBoolOps = {
     "_operator.eq",
     "_operator.ne",
     "_operator.le",

@@ -83,14 +83,14 @@ inline constexpr std::string_view kSymBoolOps[] = {
     "torch.sym_not",
 };

-inline constexpr std::string_view kSymFloatOps[] = {
+inline constexpr std::array<std::string_view, 4> kSymFloatOps = {
     "torch._sym_sqrt",
     "math.trunc",
     "_operator.neg",
     "_operator.truediv",
 };

-inline constexpr std::string_view kScalarBinaryOps[] = {
+inline constexpr std::array<std::string_view, 4> kScalarBinaryOps = {
     "_operator.mul",
     "_operator.add",
     "_operator.sub",
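Replacing the C arrays with std::array keeps the element count in the type and plays well with standard algorithms; a small sketch of the payoff (kOps here is a hypothetical table, not one of the lists above):

#include <algorithm>
#include <array>
#include <iostream>
#include <string_view>

inline constexpr std::array<std::string_view, 3> kOps = {
    "_operator.mul",
    "_operator.add",
    "_operator.sub",
};

// std::array exposes begin()/end(), so lookups need no sizeof arithmetic.
bool is_known_op(std::string_view target) {
  return std::find(kOps.begin(), kOps.end(), target) != kOps.end();
}

int main() {
  std::cout << std::boolalpha << is_known_op("_operator.add") << '\n'; // true
  std::cout << kOps.size() << '\n'; // 3, carried by the type itself
}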
@@ -124,10 +124,11 @@ void KernelFactory::registerHandler(

 ExecutionKernels KernelFactory::initializeNodeKernels(
     const Graph& graph,
-    std::shared_ptr<Weights> weights,
+    const std::shared_ptr<Weights>& weights,
     const torch::nativert::ExecutorConfig& executorConfig,
     const Placement& placement,
-    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> pytorchStreamReader,
+    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
+        pytorchStreamReader,
     const MakeProxyExecutorFn& makeProxyExecutorFunc) {
   std::vector<std::unique_ptr<OpKernel>> nodeKernels;
   std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;

@@ -216,7 +217,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
           *subgraph, weights, executorConfig, placement);
       CHECK(executionKernels.delegateExecutors.empty())
           << "HigherOrderKernel does not support delegates";
-      CHECK(executionKernels.constFoldingExecutions.size() == 0)
+      CHECK(executionKernels.constFoldingExecutions.empty())
           << "HigherOrderKernel does not support const folding";
       if (executorConfig.maxParallelOps > 1) {
         graphExecutors.emplace_back(
torch/nativert/kernels/KernelFactory.h
@@ -74,10 +74,10 @@ class KernelFactory {

   ExecutionKernels initializeNodeKernels(
       const Graph& graph,
-      std::shared_ptr<Weights> weights,
+      const std::shared_ptr<Weights>& weights,
       const torch::nativert::ExecutorConfig& executorConfig,
       const Placement& placement,
-      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
           pytorchStreamReader = nullptr,
       const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);