mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[nativert] oss pass graph pass registration (#160859)
Summary: att Test Plan: CI Rollback Plan: Differential Revision: D80368343 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160859 Approved by: https://github.com/georgiaphillips
This commit is contained in:
@ -632,6 +632,8 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp",
|
||||
"torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp",
|
||||
"torch/nativert/graph/passes/SubgraphRewriter.cpp",
|
||||
"torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
|
||||
"torch/nativert/graph/passes/pass_manager/PassManager.cpp",
|
||||
]
|
||||
|
||||
torch_mobile_tracer_sources = [
|
||||
|
@ -37,6 +37,8 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
|
||||
)
|
||||
|
||||
add_executable(test_nativert
|
||||
|
33
test/cpp/nativert/test_pass_manager.cpp
Normal file
33
test/cpp/nativert/test_pass_manager.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
#include <torch/nativert/graph/passes/pass_manager/PassManager.h>
|
||||
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace torch::nativert;
|
||||
|
||||
TEST(PassManagerTest, TestEmptyPass) {
|
||||
GraphPassManager manager({"EmptyPass"});
|
||||
EXPECT_FALSE(manager.run(Graph::createGraph().get()));
|
||||
}
|
||||
|
||||
TEST(PassPipelineTest, TestConcat) {
|
||||
GraphPassPipeline p1({"test"});
|
||||
EXPECT_EQ(p1.size(), 1);
|
||||
EXPECT_EQ(p1.at(0), "test");
|
||||
p1.concat({"test1", "test2"});
|
||||
EXPECT_EQ(p1.at(0), "test");
|
||||
EXPECT_EQ(p1.at(1), "test1");
|
||||
EXPECT_EQ(p1.at(2), "test2");
|
||||
}
|
||||
|
||||
TEST(PassPipelineTest, TestPushFront) {
|
||||
GraphPassPipeline p1({"test"});
|
||||
EXPECT_EQ(p1.size(), 1);
|
||||
EXPECT_EQ(p1.at(0), "test");
|
||||
p1.push_front("test1");
|
||||
EXPECT_EQ(p1.at(0), "test1");
|
||||
EXPECT_EQ(p1.at(1), "test");
|
||||
}
|
84
torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h
Normal file
84
torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h
Normal file
@ -0,0 +1,84 @@
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using PassSignature = std::function<bool(Graph*)>;
|
||||
using GraphPassIdentifier = std::string;
|
||||
|
||||
class GraphPass {
|
||||
public:
|
||||
GraphPass(GraphPassIdentifier&& name, PassSignature&& pass)
|
||||
: name_(std::move(name)), pass_(std::move(pass)) {}
|
||||
|
||||
const GraphPassIdentifier& name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
const PassSignature& get() const {
|
||||
return pass_;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphPassIdentifier name_;
|
||||
PassSignature pass_;
|
||||
};
|
||||
|
||||
class GraphPassRegistry {
|
||||
public:
|
||||
static GraphPassRegistry& get() {
|
||||
static GraphPassRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
static void add_pass(GraphPassIdentifier&& name, PassSignature&& pass) {
|
||||
GraphPassRegistry::get().add_pass(
|
||||
GraphPass(std::move(name), std::move(pass)));
|
||||
}
|
||||
|
||||
void add_pass(GraphPass&& pass) {
|
||||
if (auto it = registry_.find(pass.name()); it != registry_.end()) {
|
||||
LOG(WARNING) << "Pass " << pass.name() << " already registered";
|
||||
return;
|
||||
}
|
||||
|
||||
GraphPassIdentifier name = pass.name();
|
||||
|
||||
LOG(INFO) << "Pass " << name << " registered";
|
||||
registry_.insert({std::move(name), std::move(pass)});
|
||||
}
|
||||
|
||||
void remove_pass(const GraphPassIdentifier& name) {
|
||||
if (!registry_.erase(name)) {
|
||||
LOG(WARNING) << "Pass " << name << " not registered but tried to remove";
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Pass " << name << " unregistered";
|
||||
}
|
||||
|
||||
const GraphPass& get_pass(const GraphPassIdentifier& name) {
|
||||
auto it = registry_.find(name);
|
||||
if (it == registry_.end()) {
|
||||
throw std::runtime_error("Pass " + name + " not registered to get");
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphPassRegistry() {
|
||||
LOG(INFO) << "Creating GraphPassRegistry";
|
||||
}
|
||||
|
||||
std::map<std::string, GraphPass> registry_;
|
||||
|
||||
public:
|
||||
GraphPassRegistry(GraphPassRegistry const&) = delete;
|
||||
void operator=(GraphPassRegistry const&) = delete;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
92
torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
Normal file
92
torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
Normal file
@ -0,0 +1,92 @@
|
||||
#include <torch/nativert/graph/passes/pass_manager/GraphPasses.h>
|
||||
|
||||
#include <torch/nativert/graph/passes/SubgraphRewriter.h>
|
||||
#include <torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
void register_base_passes() {
|
||||
GraphPassRegistry::add_pass("EmptyPass", [](Graph*) { return false; });
|
||||
|
||||
GraphPassRegistry::add_pass(
|
||||
"LinearDynamicFp16UnpackedWeight", [](Graph* graph) {
|
||||
std::string p = R"(
|
||||
graph(%i, %w, %b):
|
||||
%out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b)
|
||||
return (%out_0))";
|
||||
|
||||
std::string p_1 = R"(
|
||||
graph(%i, %w, %b):
|
||||
%out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b)
|
||||
return (%out_0))";
|
||||
|
||||
std::string p_new = R"(
|
||||
graph(%i, %w, %b):
|
||||
%pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b)
|
||||
%out_0 = torch.ops.quantized.linear_dynamic_fp16.default(X=%i, W_prepack=%pw)
|
||||
return (%out_0))";
|
||||
|
||||
SubgraphRewriter rewriter("LinearDynamicFp16UnpackedWeight");
|
||||
rewriter.registerRewritePattern(p, p_new);
|
||||
rewriter.registerRewritePattern(p_1, p_new);
|
||||
return rewriter.run(graph);
|
||||
});
|
||||
|
||||
GraphPassRegistry::add_pass(
|
||||
"LinearReluDynamicFp16UnpackedWeight", [](Graph* graph) {
|
||||
std::string p = R"(
|
||||
graph(%i, %w, %b):
|
||||
%out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b)
|
||||
%out_1 = torch.ops.aten.relu.default(self=%out_0)
|
||||
return (%out_1))";
|
||||
|
||||
std::string p_1 = R"(
|
||||
graph(%i, %w, %b):
|
||||
%out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b)
|
||||
%out_1 = torch.ops.aten.relu.default(self=%out_0)
|
||||
return (%out_1))";
|
||||
|
||||
std::string p_new = R"(
|
||||
graph(%i, %w, %b):
|
||||
%pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b)
|
||||
%out_0 = torch.ops.quantized.linear_relu_dynamic_fp16.default(X=%i, W_prepack=%pw)
|
||||
return (%out_0))";
|
||||
|
||||
SubgraphRewriter rewriter("LinearReluDynamicFp16UnpackedWeight");
|
||||
rewriter.registerRewritePattern(p, p_new);
|
||||
rewriter.registerRewritePattern(p_1, p_new);
|
||||
return rewriter.run(graph);
|
||||
});
|
||||
|
||||
GraphPassRegistry::add_pass("CleanUpDeadNodes", [](Graph* graph) {
|
||||
return graph->cleanupDeadNodes();
|
||||
});
|
||||
|
||||
GraphPassRegistry::add_pass("RemoveDetach", [](Graph* graph) {
|
||||
std::vector<Node*> nodesToDestroy;
|
||||
|
||||
for (auto& node : graph->nodes()) {
|
||||
if (node.target() == "torch.ops.aten.detach.default") {
|
||||
nodesToDestroy.push_back(&node);
|
||||
graph->replaceAllUses(node.outputs()[0], node.inputs()[0].value);
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "[GraphPasses] Removed " << nodesToDestroy.size()
|
||||
<< " aten.detach nodes";
|
||||
|
||||
const bool mutated = !nodesToDestroy.empty();
|
||||
|
||||
for (Node* node : nodesToDestroy) {
|
||||
node->destroy();
|
||||
}
|
||||
|
||||
graph->renumberValues();
|
||||
graph->finalize();
|
||||
graph->lint();
|
||||
|
||||
return mutated;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
7
torch/nativert/graph/passes/pass_manager/GraphPasses.h
Normal file
7
torch/nativert/graph/passes/pass_manager/GraphPasses.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
void register_base_passes();
|
||||
|
||||
} // namespace torch::nativert
|
52
torch/nativert/graph/passes/pass_manager/PassManager.cpp
Normal file
52
torch/nativert/graph/passes/pass_manager/PassManager.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include <torch/nativert/graph/passes/pass_manager/PassManager.h>
|
||||
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
#include <torch/nativert/graph/passes/pass_manager/GraphPasses.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
GraphPassManager::GraphPassManager(
|
||||
GraphPassPipeline pipeline,
|
||||
PassManagerOptions opts)
|
||||
: pipeline_(std::move(pipeline)), opts_(opts) {
|
||||
static c10::once_flag flag;
|
||||
c10::call_once(flag, [&]() { register_base_passes(); });
|
||||
}
|
||||
|
||||
bool GraphPassManager::run(Graph* graph) {
|
||||
bool changed = false;
|
||||
for (const auto& pass_name : pipeline_) {
|
||||
changed |= run_pass(graph, pass_name);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GraphPassManager::run_pass(Graph* graph, const GraphPassIdentifier& name) {
|
||||
const auto& pass = GraphPassRegistry::get().get_pass(name);
|
||||
|
||||
bool changed = pass_pre_run_hook(graph, pass);
|
||||
changed |= (pass.get())(graph);
|
||||
changed |= pass_post_run_hook(graph, pass);
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GraphPassManager::pass_pre_run_hook(Graph* graph, const GraphPass& pass) {
|
||||
if (opts_.logGraphBetweenPasses()) {
|
||||
LOG(INFO) << "Before pass: " << pass.name() << "\n"
|
||||
<< graph->toString() << "-------------------------";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GraphPassManager::pass_post_run_hook(Graph* graph, const GraphPass& pass) {
|
||||
if (opts_.logGraphBetweenPasses()) {
|
||||
LOG(INFO) << "After pass: " << pass.name() << "\n"
|
||||
<< graph->toString() << "-------------------------";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
58
torch/nativert/graph/passes/pass_manager/PassManager.h
Normal file
58
torch/nativert/graph/passes/pass_manager/PassManager.h
Normal file
@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
#include <torch/nativert/graph/passes/pass_manager/PassPipeline.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using torch::nativert::Graph;
|
||||
using torch::nativert::GraphPass;
|
||||
|
||||
class PassManagerOptions {
|
||||
public:
|
||||
/* GETTERS */
|
||||
bool logGraphBetweenPasses() const {
|
||||
return log_graph_between_passes_;
|
||||
}
|
||||
|
||||
/* SETTERS */
|
||||
PassManagerOptions& setLogGraphBetweenPasses(bool log_graph_between_passes) {
|
||||
log_graph_between_passes_ = log_graph_between_passes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
bool log_graph_between_passes_{false};
|
||||
};
|
||||
|
||||
class GraphPassManager {
|
||||
public:
|
||||
explicit GraphPassManager(
|
||||
GraphPassPipeline pipeline,
|
||||
PassManagerOptions opts = {});
|
||||
~GraphPassManager() = default;
|
||||
|
||||
bool run(Graph* graph);
|
||||
|
||||
const GraphPassPipeline& pipeline() const {
|
||||
return pipeline_;
|
||||
}
|
||||
|
||||
const PassManagerOptions& opts() const {
|
||||
return opts_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<GraphPass> create_pass(GraphPassIdentifier id);
|
||||
|
||||
bool run_pass(Graph* graph, const GraphPassIdentifier& config);
|
||||
bool pass_pre_run_hook(Graph* graph, const GraphPass& pass);
|
||||
bool pass_post_run_hook(Graph* graph, const GraphPass& pass);
|
||||
|
||||
const GraphPassPipeline pipeline_;
|
||||
const PassManagerOptions opts_;
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
24
torch/nativert/graph/passes/pass_manager/PassPipeline.h
Normal file
24
torch/nativert/graph/passes/pass_manager/PassPipeline.h
Normal file
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
using GraphPassIdentifier = std::string;
|
||||
|
||||
class GraphPassPipeline : public std::vector<GraphPassIdentifier> {
|
||||
public:
|
||||
using std::vector<GraphPassIdentifier>::vector;
|
||||
|
||||
void push_front(GraphPassIdentifier pass) {
|
||||
std::vector<GraphPassIdentifier>::insert(begin(), std::move(pass));
|
||||
}
|
||||
|
||||
// concats the passed pipeline to the end of the current
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
|
||||
void concat(GraphPassPipeline&& other) {
|
||||
std::move(other.begin(), other.end(), std::back_inserter(*this));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user