[nativert] oss pass graph pass registration (#160859)

Summary: as titled — registers nativert graph passes (pass registry, pass manager, and pipeline) in the open-source build.

Test Plan:
CI

Rollback Plan:

Differential Revision: D80368343

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160859
Approved by: https://github.com/georgiaphillips
This commit is contained in:
dolpm
2025-08-18 22:23:38 +00:00
committed by PyTorch MergeBot
parent 82c7a1eb4b
commit b439675ae2
9 changed files with 354 additions and 0 deletions

View File

@ -632,6 +632,8 @@ libtorch_nativert_sources = [
"torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp",
"torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp",
"torch/nativert/graph/passes/SubgraphRewriter.cpp",
"torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
"torch/nativert/graph/passes/pass_manager/PassManager.cpp",
]
torch_mobile_tracer_sources = [

View File

@ -37,6 +37,8 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp
${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/SubgraphRewriter.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
)
add_executable(test_nativert

View File

@ -0,0 +1,33 @@
#include <gtest/gtest.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/passes/pass_manager/PassManager.h>
#include <torch/csrc/jit/testing/file_check.h>
using namespace ::testing;
using namespace torch::nativert;
// A pipeline consisting solely of the no-op EmptyPass must leave a fresh
// graph untouched, so run() reports no mutation.
TEST(PassManagerTest, TestEmptyPass) {
  GraphPassManager manager({"EmptyPass"});
  auto graph = Graph::createGraph();
  EXPECT_FALSE(manager.run(graph.get()));
}
// concat() must append the other pipeline's passes after the existing
// entries, preserving their relative order.
TEST(PassPipelineTest, TestConcat) {
  GraphPassPipeline pipeline({"test"});
  EXPECT_EQ(pipeline.size(), 1);
  EXPECT_EQ(pipeline.at(0), "test");
  pipeline.concat({"test1", "test2"});
  const char* expected[] = {"test", "test1", "test2"};
  for (size_t i = 0; i < 3; ++i) {
    EXPECT_EQ(pipeline.at(i), expected[i]);
  }
}
// push_front() must place the new pass before every existing entry.
TEST(PassPipelineTest, TestPushFront) {
  GraphPassPipeline pipeline({"test"});
  EXPECT_EQ(pipeline.size(), 1);
  EXPECT_EQ(pipeline.at(0), "test");
  pipeline.push_front("test1");
  EXPECT_EQ(pipeline.at(0), "test1");
  EXPECT_EQ(pipeline.at(1), "test");
}

View File

@ -0,0 +1,84 @@
#pragma once
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

#include <c10/util/Logging.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {
// Signature every graph pass implements: mutate the graph in place and
// return true iff anything was changed.
using PassSignature = std::function<bool(Graph*)>;
// Key under which a pass is registered and looked up.
using GraphPassIdentifier = std::string;

// Immutable (name, callable) pair describing one registered graph pass.
class GraphPass {
 public:
  GraphPass(GraphPassIdentifier&& name, PassSignature&& pass)
      : name_(std::move(name)), pass_(std::move(pass)) {}

  // Identifier the pass was registered under.
  const GraphPassIdentifier& name() const { return name_; }

  // The callable itself; invoke as get()(graph).
  const PassSignature& get() const { return pass_; }

 private:
  GraphPassIdentifier name_;
  PassSignature pass_;
};
class GraphPassRegistry {
public:
static GraphPassRegistry& get() {
static GraphPassRegistry instance;
return instance;
}
static void add_pass(GraphPassIdentifier&& name, PassSignature&& pass) {
GraphPassRegistry::get().add_pass(
GraphPass(std::move(name), std::move(pass)));
}
void add_pass(GraphPass&& pass) {
if (auto it = registry_.find(pass.name()); it != registry_.end()) {
LOG(WARNING) << "Pass " << pass.name() << " already registered";
return;
}
GraphPassIdentifier name = pass.name();
LOG(INFO) << "Pass " << name << " registered";
registry_.insert({std::move(name), std::move(pass)});
}
void remove_pass(const GraphPassIdentifier& name) {
if (!registry_.erase(name)) {
LOG(WARNING) << "Pass " << name << " not registered but tried to remove";
return;
}
LOG(INFO) << "Pass " << name << " unregistered";
}
const GraphPass& get_pass(const GraphPassIdentifier& name) {
auto it = registry_.find(name);
if (it == registry_.end()) {
throw std::runtime_error("Pass " + name + " not registered to get");
}
return it->second;
}
private:
GraphPassRegistry() {
LOG(INFO) << "Creating GraphPassRegistry";
}
std::map<std::string, GraphPass> registry_;
public:
GraphPassRegistry(GraphPassRegistry const&) = delete;
void operator=(GraphPassRegistry const&) = delete;
};
} // namespace torch::nativert

View File

@ -0,0 +1,92 @@
#include <torch/nativert/graph/passes/pass_manager/GraphPasses.h>
#include <torch/nativert/graph/passes/SubgraphRewriter.h>
#include <torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h>
namespace torch::nativert {
// Registers nativert's built-in graph passes with the global
// GraphPassRegistry. Invoked lazily — and exactly once per process — from the
// GraphPassManager constructor via c10::call_once.
void register_base_passes() {
// Sentinel pass that never mutates the graph; useful in tests.
GraphPassRegistry::add_pass("EmptyPass", [](Graph*) { return false; });
// Rewrites aten.linear (and the unpacked-weight quantized variant) into the
// prepacked form: linear_prepack_fp16 followed by linear_dynamic_fp16.
GraphPassRegistry::add_pass(
"LinearDynamicFp16UnpackedWeight", [](Graph* graph) {
// Pattern 1: plain aten.linear.
std::string p = R"(
graph(%i, %w, %b):
%out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b)
return (%out_0))";
// Pattern 2: dynamic fp16 linear with unpacked weights.
std::string p_1 = R"(
graph(%i, %w, %b):
%out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b)
return (%out_0))";
// Replacement: prepack the weight/bias once, then run the packed kernel.
std::string p_new = R"(
graph(%i, %w, %b):
%pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b)
%out_0 = torch.ops.quantized.linear_dynamic_fp16.default(X=%i, W_prepack=%pw)
return (%out_0))";
SubgraphRewriter rewriter("LinearDynamicFp16UnpackedWeight");
rewriter.registerRewritePattern(p, p_new);
rewriter.registerRewritePattern(p_1, p_new);
// run() reports whether any rewrite fired.
return rewriter.run(graph);
});
// Same fusion as above, but for linear followed by relu, targeting the fused
// linear_relu_dynamic_fp16 kernel.
GraphPassRegistry::add_pass(
"LinearReluDynamicFp16UnpackedWeight", [](Graph* graph) {
// Pattern 1: aten.linear + relu.
std::string p = R"(
graph(%i, %w, %b):
%out_0 = torch.ops.aten.linear.default(input=%i, weight=%w, bias=%b)
%out_1 = torch.ops.aten.relu.default(self=%out_0)
return (%out_1))";
// Pattern 2: unpacked-weight dynamic fp16 linear + relu.
std::string p_1 = R"(
graph(%i, %w, %b):
%out_0 = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default(X=%i, weight=%w, bias=%b)
%out_1 = torch.ops.aten.relu.default(self=%out_0)
return (%out_1))";
// Replacement: prepack, then the fused linear+relu kernel.
std::string p_new = R"(
graph(%i, %w, %b):
%pw = torch.ops.quantized.linear_prepack_fp16.default(W=%w, B=%b)
%out_0 = torch.ops.quantized.linear_relu_dynamic_fp16.default(X=%i, W_prepack=%pw)
return (%out_0))";
SubgraphRewriter rewriter("LinearReluDynamicFp16UnpackedWeight");
rewriter.registerRewritePattern(p, p_new);
rewriter.registerRewritePattern(p_1, p_new);
return rewriter.run(graph);
});
// Thin wrapper over the graph's own dead-node elimination.
GraphPassRegistry::add_pass("CleanUpDeadNodes", [](Graph* graph) {
return graph->cleanupDeadNodes();
});
// Removes aten.detach nodes by rerouting each one's consumers to its input.
GraphPassRegistry::add_pass("RemoveDetach", [](Graph* graph) {
std::vector<Node*> nodesToDestroy;
for (auto& node : graph->nodes()) {
if (node.target() == "torch.ops.aten.detach.default") {
nodesToDestroy.push_back(&node);
// Reroute users of the detach output to the detach input so the node
// becomes dead before we destroy it below.
graph->replaceAllUses(node.outputs()[0], node.inputs()[0].value);
}
}
VLOG(1) << "[GraphPasses] Removed " << nodesToDestroy.size()
<< " aten.detach nodes";
const bool mutated = !nodesToDestroy.empty();
for (Node* node : nodesToDestroy) {
node->destroy();
}
// NOTE(review): renumber/finalize/lint run even when no node was removed —
// presumably cheap, but could be guarded by `mutated` if profiling shows it.
graph->renumberValues();
graph->finalize();
graph->lint();
return mutated;
});
}
} // namespace torch::nativert

View File

@ -0,0 +1,7 @@
#pragma once
namespace torch::nativert {
// Registers the built-in nativert graph passes (defined in GraphPasses.cpp)
// with the global GraphPassRegistry. Calling it more than once is harmless:
// the registry ignores duplicate names with a warning.
void register_base_passes();
} // namespace torch::nativert

View File

@ -0,0 +1,52 @@
#include <torch/nativert/graph/passes/pass_manager/PassManager.h>
#include <c10/util/CallOnce.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/passes/pass_manager/GraphPasses.h>
namespace torch::nativert {
GraphPassManager::GraphPassManager(
    GraphPassPipeline pipeline,
    PassManagerOptions opts)
    : pipeline_(std::move(pipeline)), opts_(opts) {
  // The built-in passes are registered exactly once per process, no matter
  // how many managers get constructed.
  static c10::once_flag base_passes_registered;
  c10::call_once(base_passes_registered, register_base_passes);
}
// Executes every pass in the pipeline, in order. Every pass runs even after
// an earlier one reports a change; the per-pass flags are OR-ed together so
// the caller learns whether anything mutated the graph.
bool GraphPassManager::run(Graph* graph) {
  bool any_mutation = false;
  for (const auto& pass_name : pipeline_) {
    any_mutation = run_pass(graph, pass_name) || any_mutation;
  }
  return any_mutation;
}
// Looks up `name` in the registry (throws std::runtime_error if unknown) and
// runs it bracketed by the pre/post hooks. All three results are combined,
// though the hooks currently always report false.
bool GraphPassManager::run_pass(Graph* graph, const GraphPassIdentifier& name) {
  const auto& pass = GraphPassRegistry::get().get_pass(name);
  const bool pre_changed = pass_pre_run_hook(graph, pass);
  const bool pass_changed = pass.get()(graph);
  const bool post_changed = pass_post_run_hook(graph, pass);
  return pre_changed || pass_changed || post_changed;
}
// Purely observational: optionally dumps the graph before the pass runs.
// Never reports a mutation.
bool GraphPassManager::pass_pre_run_hook(Graph* graph, const GraphPass& pass) {
  if (!opts_.logGraphBetweenPasses()) {
    return false;
  }
  LOG(INFO) << "Before pass: " << pass.name() << "\n"
            << graph->toString() << "-------------------------";
  return false;
}
// Mirror of the pre-run hook: optionally dumps the graph after the pass
// runs, and never reports a mutation.
bool GraphPassManager::pass_post_run_hook(Graph* graph, const GraphPass& pass) {
  if (!opts_.logGraphBetweenPasses()) {
    return false;
  }
  LOG(INFO) << "After pass: " << pass.name() << "\n"
            << graph->toString() << "-------------------------";
  return false;
}
} // namespace torch::nativert

View File

@ -0,0 +1,58 @@
#pragma once
#include <memory>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/passes/pass_manager/PassPipeline.h>
namespace torch::nativert {
// NOTE(review): both using-declarations are redundant — this header is
// already inside namespace torch::nativert, where Graph and GraphPass are
// declared; consider removing them.
using torch::nativert::Graph;
using torch::nativert::GraphPass;
// Behavior knobs for GraphPassManager, with chainable builder-style setters.
class PassManagerOptions {
 public:
  // True when the manager should log the full graph before and after each
  // pass (see the pre/post run hooks).
  bool logGraphBetweenPasses() const {
    return log_graph_between_passes_;
  }

  // Chainable setter for the graph-logging knob.
  PassManagerOptions& setLogGraphBetweenPasses(bool log_graph_between_passes) {
    log_graph_between_passes_ = log_graph_between_passes;
    return *this;
  }

 private:
  // Off by default: dumping whole graphs on every pass is noisy.
  bool log_graph_between_passes_{false};
};
// Runs an ordered pipeline of registered graph passes over a Graph.
class GraphPassManager {
 public:
  // Registers the built-in passes (once per process) and stores the
  // pipeline and options by value.
  explicit GraphPassManager(
      GraphPassPipeline pipeline,
      PassManagerOptions opts = {});
  ~GraphPassManager() = default;

  // Executes every pass in the pipeline, in order. Returns true if any pass
  // reported that it mutated the graph. Throws std::runtime_error (from the
  // registry) if a pipeline entry names an unregistered pass.
  bool run(Graph* graph);

  const GraphPassPipeline& pipeline() const {
    return pipeline_;
  }

  const PassManagerOptions& opts() const {
    return opts_;
  }

 private:
  // NOTE: the previously declared `create_pass` helper was never defined
  // anywhere in this change, so the dead declaration has been removed.

  // Looks up `config` in the GraphPassRegistry and executes it, bracketed by
  // the hooks below.
  bool run_pass(Graph* graph, const GraphPassIdentifier& config);
  // The hooks optionally log the graph around each pass; they always return
  // false (no mutation).
  bool pass_pre_run_hook(Graph* graph, const GraphPass& pass);
  bool pass_post_run_hook(Graph* graph, const GraphPass& pass);

  const GraphPassPipeline pipeline_;
  const PassManagerOptions opts_;
};
} // namespace torch::nativert

View File

@ -0,0 +1,24 @@
#pragma once
#include <torch/nativert/graph/passes/pass_manager/GraphPassRegistry.h>
namespace torch::nativert {
// Key under which a pass is registered and looked up.
using GraphPassIdentifier = std::string;

// A sequence of pass identifiers, executed front-to-back by the
// GraphPassManager. Exposes std::vector's full interface via inheritance.
// NOTE(review): std::vector has a non-virtual destructor — never own a
// GraphPassPipeline through a pointer to the base class.
class GraphPassPipeline : public std::vector<GraphPassIdentifier> {
 public:
  using std::vector<GraphPassIdentifier>::vector;

  // Prepends `pass` so it runs before everything currently queued.
  void push_front(GraphPassIdentifier pass) {
    emplace(begin(), std::move(pass));
  }

  // Appends the contents of `other` to the end of this pipeline; `other`'s
  // elements are left in a moved-from state.
  void concat(GraphPassPipeline&& other) {
    reserve(size() + other.size());
    for (auto& id : other) {
      push_back(std::move(id));
    }
  }
};
} // namespace torch::nativert