diff --git a/build_variables.bzl b/build_variables.bzl index 7926e36592e4..c3c99014d9f4 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -632,6 +632,8 @@ libtorch_nativert_sources = [ "torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp", "torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp", "torch/nativert/graph/passes/SubgraphRewriter.cpp", + "torch/nativert/graph/passes/pass_manager/GraphPasses.cpp", + "torch/nativert/graph/passes/pass_manager/PassManager.cpp", ] torch_mobile_tracer_sources = [ diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 8b5ca51b6301..822ed7c3bd99 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -37,6 +37,8 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp ${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp ${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp + ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp + ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp ) add_executable(test_nativert diff --git a/test/cpp/nativert/test_pass_manager.cpp b/test/cpp/nativert/test_pass_manager.cpp new file mode 100644 index 000000000000..d3e5d6585978 --- /dev/null +++ b/test/cpp/nativert/test_pass_manager.cpp @@ -0,0 +1,33 @@ +#include + +#include +#include + +#include + +using namespace ::testing; +using namespace torch::nativert; + +TEST(PassManagerTest, TestEmptyPass) { + GraphPassManager manager({"EmptyPass"}); + EXPECT_FALSE(manager.run(Graph::createGraph().get())); +} + +TEST(PassPipelineTest, TestConcat) { + GraphPassPipeline p1({"test"}); + EXPECT_EQ(p1.size(), 1); + EXPECT_EQ(p1.at(0), "test"); + p1.concat({"test1", "test2"}); + EXPECT_EQ(p1.at(0), "test"); + EXPECT_EQ(p1.at(1), "test1"); + EXPECT_EQ(p1.at(2), "test2"); +} + +TEST(PassPipelineTest, TestPushFront) { + GraphPassPipeline p1({"test"}); + EXPECT_EQ(p1.size(), 1); + EXPECT_EQ(p1.at(0), "test"); + p1.push_front("test1"); + EXPECT_EQ(p1.at(0), "test1"); + EXPECT_EQ(p1.at(1), "test"); +} diff --git a/torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h b/torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h new file mode 100644 index 000000000000..28a7f77aa8a1 --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::nativert { + +using PassSignature = std::function; +using GraphPassIdentifier = std::string; + +class GraphPass { + public: + GraphPass(GraphPassIdentifier&& name, PassSignature&& pass) + : name_(std::move(name)), pass_(std::move(pass)) {} + + const GraphPassIdentifier& name() const { + return name_; + } + + const PassSignature& get() const { + return pass_; + } + + private: + GraphPassIdentifier name_; + PassSignature pass_; +}; + +class GraphPassRegistry { + public: + static GraphPassRegistry& get() { + static GraphPassRegistry instance; + return instance; + } + + static void add_pass(GraphPassIdentifier&& name, PassSignature&& pass) { + GraphPassRegistry::get().add_pass( + GraphPass(std::move(name), std::move(pass))); + } + + void add_pass(GraphPass&& pass) { + if (auto it = registry_.find(pass.name()); it != registry_.end()) { + LOG(WARNING) << "Pass " << pass.name() << " already registered"; + return; + } + + GraphPassIdentifier name = pass.name(); + + LOG(INFO) << "Pass " << name << " registered"; + registry_.insert({std::move(name), std::move(pass)}); + } + + void remove_pass(const GraphPassIdentifier& name) { + if (!registry_.erase(name)) { + LOG(WARNING) << "Pass " << name << " not registered but tried to remove"; + return; + } + LOG(INFO) << "Pass " << name << " unregistered"; + } + + const GraphPass& get_pass(const GraphPassIdentifier& name) { + auto it = registry_.find(name); + if (it == registry_.end()) { + throw std::runtime_error("Pass " + name + " not registered to get"); + } + return it->second; + } + + private: + GraphPassRegistry() { + LOG(INFO) << "Creating GraphPassRegistry"; + } + + std::map registry_; + + public: + GraphPassRegistry(GraphPassRegistry const&) = delete; + void operator=(GraphPassRegistry const&) = delete; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp b/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp new file mode 100644 index 000000000000..7a838b2a651f --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp @@ -0,0 +1,92 @@ +#include + +#include +#include + +namespace torch::nativert { + +void register_base_passes() { + GraphPassRegistry::add_pass("EmptyPass", [](Graph*) { return false; }); + + GraphPassRegistry::add_pass( + "LinearDynamicFp16UnpackedWeight", [](Graph* graph) { + std::string p = R"( + graph(%i, %w, %b): + %out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b) + return (%out_0))"; + + std::string p_1 = R"( + graph(%i, %w, %b): + %out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b) + return (%out_0))"; + + std::string p_new = R"( + graph(%i, %w, %b): + %pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b) + %out_0 = torch.ops.quantized.linear_dynamic_fp16.default(X=%i, W_prepack=%pw) + return (%out_0))"; + + SubgraphRewriter rewriter("LinearDynamicFp16UnpackedWeight"); + rewriter.registerRewritePattern(p, p_new); + rewriter.registerRewritePattern(p_1, p_new); + return rewriter.run(graph); + }); + + GraphPassRegistry::add_pass( + "LinearReluDynamicFp16UnpackedWeight", [](Graph* graph) { + std::string p = R"( + graph(%i, %w, %b): + %out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b) + %out_1 = torch.ops.aten.relu.default(self=%out_0) + return (%out_1))"; + + std::string p_1 = R"( + graph(%i, %w, %b): + %out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b) + %out_1 = torch.ops.aten.relu.default(self=%out_0) + return (%out_1))"; + + std::string p_new = R"( + graph(%i, %w, %b): + %pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b) + %out_0 = torch.ops.quantized.linear_relu_dynamic_fp16.default(X=%i, W_prepack=%pw) + return (%out_0))"; + + SubgraphRewriter rewriter("LinearReluDynamicFp16UnpackedWeight"); + rewriter.registerRewritePattern(p, p_new); + rewriter.registerRewritePattern(p_1, p_new); + return rewriter.run(graph); + }); + + GraphPassRegistry::add_pass("CleanUpDeadNodes", [](Graph* graph) { + return graph->cleanupDeadNodes(); + }); + + GraphPassRegistry::add_pass("RemoveDetach", [](Graph* graph) { + std::vector nodesToDestroy; + + for (auto& node : graph->nodes()) { + if (node.target() == "torch.ops.aten.detach.default") { + nodesToDestroy.push_back(&node); + graph->replaceAllUses(node.outputs()[0], node.inputs()[0].value); + } + } + + VLOG(1) << "[GraphPasses] Removed " << nodesToDestroy.size() + << " aten.detach nodes"; + + const bool mutated = !nodesToDestroy.empty(); + + for (Node* node : nodesToDestroy) { + node->destroy(); + } + + graph->renumberValues(); + graph->finalize(); + graph->lint(); + + return mutated; + }); +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/passes/pass_manager/GraphPasses.h b/torch/nativert/graph/passes/pass_manager/GraphPasses.h new file mode 100644 index 000000000000..f62564448652 --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/GraphPasses.h @@ -0,0 +1,7 @@ +#pragma once + +namespace torch::nativert { + +void register_base_passes(); + +} // namespace torch::nativert diff --git a/torch/nativert/graph/passes/pass_manager/PassManager.cpp b/torch/nativert/graph/passes/pass_manager/PassManager.cpp new file mode 100644 index 000000000000..e023f223ed6f --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/PassManager.cpp @@ -0,0 +1,52 @@ +#include + +#include + +#include +#include + +namespace torch::nativert { + +GraphPassManager::GraphPassManager( + GraphPassPipeline pipeline, + PassManagerOptions opts) + : pipeline_(std::move(pipeline)), opts_(opts) { + static c10::once_flag flag; + c10::call_once(flag, [&]() { register_base_passes(); }); +} + +bool GraphPassManager::run(Graph* graph) { + bool changed = false; + for (const auto& pass_name : pipeline_) { + changed |= run_pass(graph, pass_name); + } + return changed; +} + +bool GraphPassManager::run_pass(Graph* graph, const GraphPassIdentifier& name) { + const auto& pass = GraphPassRegistry::get().get_pass(name); + + bool changed = pass_pre_run_hook(graph, pass); + changed |= (pass.get())(graph); + changed |= pass_post_run_hook(graph, pass); + + return changed; +} + +bool GraphPassManager::pass_pre_run_hook(Graph* graph, const GraphPass& pass) { + if (opts_.logGraphBetweenPasses()) { + LOG(INFO) << "Before pass: " << pass.name() << "\n" + << graph->toString() << "-------------------------"; + } + return false; +} + +bool GraphPassManager::pass_post_run_hook(Graph* graph, const GraphPass& pass) { + if (opts_.logGraphBetweenPasses()) { + LOG(INFO) << "After pass: " << pass.name() << "\n" + << graph->toString() << "-------------------------"; + } + return false; +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/passes/pass_manager/PassManager.h b/torch/nativert/graph/passes/pass_manager/PassManager.h new file mode 100644 index 000000000000..22ce0144bcd8 --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/PassManager.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include +#include + +namespace torch::nativert { + +using torch::nativert::Graph; +using torch::nativert::GraphPass; + +class PassManagerOptions { + public: + /* GETTERS */ + bool logGraphBetweenPasses() const { + return log_graph_between_passes_; + } + + /* SETTERS */ + PassManagerOptions& setLogGraphBetweenPasses(bool log_graph_between_passes) { + log_graph_between_passes_ = log_graph_between_passes; + return *this; + } + + private: + bool log_graph_between_passes_{false}; +}; + +class GraphPassManager { + public: + explicit GraphPassManager( + GraphPassPipeline pipeline, + PassManagerOptions opts = {}); + ~GraphPassManager() = default; + + bool run(Graph* graph); + + const GraphPassPipeline& pipeline() const { + return pipeline_; + } + + const PassManagerOptions& opts() const { + return opts_; + } + + private: + std::unique_ptr create_pass(GraphPassIdentifier id); + + bool run_pass(Graph* graph, const GraphPassIdentifier& config); + bool pass_pre_run_hook(Graph* graph, const GraphPass& pass); + bool pass_post_run_hook(Graph* graph, const GraphPass& pass); + + const GraphPassPipeline pipeline_; + const PassManagerOptions opts_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/graph/passes/pass_manager/PassPipeline.h b/torch/nativert/graph/passes/pass_manager/PassPipeline.h new file mode 100644 index 000000000000..634e7436ec01 --- /dev/null +++ b/torch/nativert/graph/passes/pass_manager/PassPipeline.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace torch::nativert { + +using GraphPassIdentifier = std::string; + +class GraphPassPipeline : public std::vector { + public: + using std::vector::vector; + + void push_front(GraphPassIdentifier pass) { + std::vector::insert(begin(), std::move(pass)); + } + + // concats the passed pipeline to the end of the current + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + void concat(GraphPassPipeline&& other) { + std::move(other.begin(), other.end(), std::back_inserter(*this)); + } +}; + +} // namespace torch::nativert