mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[static runtime] add static subgraph fusion pass (#49185)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185

This diff adds a fusion feature that lets us use the static runtime for *parts* of a graph. This will prove useful in cases where fully eliminating control flow is hard.

TODO:
[x] factor out into separate fusion file
[x] add python test case
[x] add test case for a graph that isn't fully lowered
[x] add test case for a graph that has weird list/tuple outputs

The loop example looks quite good:

```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %iters.1 : int):
  %12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
  %c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
  %c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
    block0(%i : int, %c.12 : Tensor):
      %c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
      -> (%12, %c.10)
  return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
      %4 : Tensor):
  %5 : int = prim::Constant[value=2]()
  %6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
  %2 : int = prim::Constant[value=1]()
  %c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
  return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
      %7 : Tensor,
      %8 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
  %5 : int = prim::Constant[value=2]()
  %c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
  %2 : int = prim::Constant[value=1]()
  %c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
  return (%c.10)
```

(Note: this ignores all push blocking failures!)

Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime

Reviewed By: bertmaher

Differential Revision: D25385702

fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
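The pass is driven end to end by freezing a scripted module and running `fuseStaticSubgraphs` on its forward graph, which is what the `_fuse_to_static_runtime` binding added below does. A minimal C++ sketch of that flow, assuming only what this diff adds (the wrapper function name is hypothetical):

```cpp
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/runtime/static/fusion.h>

// Hypothetical helper: freeze a scripted module, then rewrite fusible
// regions of its forward graph into prim::StaticSubgraph nodes, each of
// which is executed by its own StaticRuntime instance at run time.
void fuseModuleForStaticRuntime(torch::jit::Module& module) {
  module.eval();
  module = torch::jit::freeze_module(module);
  auto graph = module.get_method("forward").graph();
  torch::jit::fuseStaticSubgraphs(graph);
}
```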
committed by Facebook GitHub Bot
parent 95a1725a4a
commit f4226b5c90
@@ -39,6 +39,7 @@ namespace c10 {
  _(prim, FunctionalGraph)     \
  _(prim, DifferentiableGraph) \
  _(prim, TensorExprGroup)     \
  _(prim, StaticSubgraph)      \
  _(prim, If)                  \
  _(prim, Jump) /* debug */    \
  _(prim, JumpNZ) /* debug */  \
@@ -1,4 +1,5 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"
#include "test_scripts.h"
@@ -249,3 +250,34 @@ TEST(StaticRuntime, CleanUpMemory) {
    }
  }
}

TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph);
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
@@ -105,6 +105,21 @@ def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

def loop_graph(a, b, iters : int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c

def output_graph(a, b, c, iters : int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d : Dict[int, Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d

class TestStaticRuntime(TestCase):
    def test_multihead_attention_layer(self):
@@ -203,5 +218,63 @@ class TestStaticRuntime(TestCase):
        o_test = tg_a(s)[0]
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_runtime(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_runtime(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)

    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_runtime(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_runtime(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_allclose(o_ref[i], o_test[i])


if __name__ == "__main__":
    run_tests()
@@ -266,6 +266,7 @@ core_sources_full_mobile = [
]

core_sources_full = core_sources_full_mobile + [
    "torch/csrc/jit/runtime/static/fusion.cpp",
    "torch/csrc/jit/runtime/static/impl.cpp",
    "torch/csrc/jit/runtime/static/ops.cpp",
    "torch/csrc/jit/runtime/static/passes.cpp",
@@ -486,6 +486,7 @@ void AliasDb::analyzeImpl(Node* node) {
      return analyzeGradOf(node);
    // TODO: think more about TensorExpr alias correctness
    case prim::TensorExprGroup:
    case prim::StaticSubgraph:
    case prim::Constant:
    case prim::AutogradZero:
    case prim::AutogradAdd:
@@ -320,6 +320,7 @@ struct CanEmitInline {
        // by the later BailOut in createBailoutBlock and its jf_index
        // will become invalid.
        v->node()->kind() != prim::TensorExprGroup &&
        v->node()->kind() != prim::StaticSubgraph &&
        v->node()->kind() != prim::CudaFusionGroup &&
        v->node()->kind() != prim::FusionGroup &&
        v->node()->kind() != prim::BailOut && v->uses().size() == 1 &&
@@ -239,6 +239,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
      prim::CudaFusionGroup, // optimization pass adds it
      prim::CudaFusionGuard, // optimization pass adds it
      prim::TensorExprGroup, // optimization pass adds it
      prim::StaticSubgraph, // optimization pass adds it
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
@@ -276,6 +277,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
      prim::CudaFusionGroup,
      prim::DifferentiableGraph,
      prim::TensorExprGroup,
      prim::StaticSubgraph,
      prim::FunctionalGraph,
      prim::Constant,
      prim::Uninitialized,
torch/csrc/jit/runtime/static/fusion.cpp (new file, 254 lines)
@@ -0,0 +1,254 @@
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

void createFusionGroups(Block* block, AliasDb* aliasDb);

void fuseStaticSubgraphs(std::shared_ptr<Graph> graph) {
  PrepareGraphForStaticRuntime(graph);
  auto aliasDb = torch::make_unique<AliasDb>(graph);
  createFusionGroups(graph->block(), aliasDb.get());
  torch::jit::EliminateDeadCode(graph);
}
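// Each prim::StaticSubgraph node carries its fused subgraph as a Subgraph
// attribute. The operator registered below lowers that subgraph through
// PrepareForStaticRuntime, then at run time pops the node's tensor inputs
// off the stack, runs the StaticRuntime instance, and pushes the outputs.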
Operation createStaticSubgraphRuntime(const Node* node) {
|
||||
auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph));
|
||||
auto runtime = std::make_shared<torch::jit::StaticRuntime>(g);
|
||||
auto num_inputs = runtime->get_inference_module()->input_regs.size();
|
||||
return [runtime, num_inputs](Stack* stack) {
|
||||
RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
|
||||
auto inps = torch::jit::last(stack, num_inputs);
|
||||
std::vector<at::Tensor> t_inputs;
|
||||
t_inputs.reserve(num_inputs);
|
||||
for (const auto& inp : inps) {
|
||||
t_inputs.emplace_back(inp.toTensor());
|
||||
}
|
||||
torch::jit::drop(stack, num_inputs);
|
||||
auto outputs = runtime->run(t_inputs);
|
||||
for (auto& o : outputs) {
|
||||
push_one(*stack, std::move(o));
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
}
|
||||
|
||||
RegisterOperators StaticSubgraphOps({torch::jit::Operator(
|
||||
prim::StaticSubgraph,
|
||||
createStaticSubgraphRuntime,
|
||||
AliasAnalysisKind::INTERNAL_SPECIAL_CASE)});
|
||||
|
||||
#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false;                           \
  }
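// A node can be fused if it is one of a few whitelisted prim ops
// (Tuple/ListConstruct or an existing StaticSubgraph) or a functional
// operator: alias analysis marks it as a pure function, or its schema
// marks no argument as mutable.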
bool canHandle(Node* node) {
  for (Value* input : node->inputs()) {
    // TODO checks
  }

  auto kind = node->kind();
  if (kind.is_prim()) {
    REQ(kind == prim::TupleConstruct || kind == prim::ListConstruct ||
        kind == prim::StaticSubgraph);
    return true;
  }
  const Operator& op = node->getOperator();
  auto analysis = op.aliasAnalysisKind();
  if (AliasAnalysisKind::PURE_FUNCTION == analysis ||
      (AliasAnalysisKind::FROM_SCHEMA == analysis &&
       !node->schema().is_mutable())) {
    return true;
  }
  return false;
}
bool canMerge(Node* consumer, Node* producer, AliasDb* aliasDb) {
  // Only fuse within a block
  REQ(consumer->owningBlock() == producer->owningBlock());

  // Symbolic checks
  REQ(canHandle(producer) || producer->kind() == prim::StaticSubgraph);
  TORCH_INTERNAL_ASSERT(
      consumer->kind() == prim::StaticSubgraph || canHandle(consumer));

  // Alias checks
  REQ(aliasDb->couldMoveBeforeTopologically(producer, consumer));

  // Ops that return aliases can only be folded if this is the only use.
  if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze ||
      producer->kind() == prim::ConstantChunk) {
    for (auto& use : producer->output(0)->uses()) {
      REQ(use.user == consumer);
    }
  }

  return true;
}
Node* getOrCreateStaticSubgraph(Node* n, AliasDb* aliasDb) {
  if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::StaticSubgraph) {
    return n;
  }
  GRAPH_UPDATE("Creating a static subgraph::Group node from: ", *n);
  return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
      n, prim::StaticSubgraph, *aliasDb);
}

value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) {
  value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == b) {
      result.push_back(i);
    }
  }
  // Sort in reverse topological order
  std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
    return a->node()->isAfter(b->node());
  });
  return result;
}

static void debugDumpFusionGroup(const std::string& msg, Node* n) {
  GRAPH_DEBUG(msg, *n);
  if (n->kind() == prim::StaticSubgraph) {
    GRAPH_DEBUG(*n->g(attr::Subgraph));
  }
}
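// tryMerge moves `to_merge` (plus its ListConstruct input when merging an
// aten::cat) next to the fusion group; if the AliasDb says every move is
// valid, the nodes are folded into the group's subgraph.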
c10::optional<Node*> tryMerge(
    Node* fusion_group,
    Node* to_merge,
    AliasDb* aliasDb) {
  if (!canMerge(fusion_group, to_merge, aliasDb)) {
    return c10::nullopt;
  }

  std::vector<Node*> nodes_to_merge = {to_merge};

  if (to_merge->kind() == aten::cat) {
    Node* listconstruct = to_merge->input(0)->node();
    nodes_to_merge.push_back(listconstruct);
  }

  // First, try to move all the nodes we want to fuse next to the fusion
  // group.
  Node* move_point = fusion_group;
  for (auto n : nodes_to_merge) {
    GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n));
    if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) {
      GRAPH_UPDATE("Failed to move because of AliasDb checks!");
      return c10::nullopt;
    }
    move_point = n;
  }

  // Now all the nodes that we're going to fuse are moved next to the fusion
  // group, so we can safely merge them into the fusion group subgraph.
  fusion_group = getOrCreateStaticSubgraph(fusion_group, aliasDb);

  for (auto n : nodes_to_merge) {
    GRAPH_UPDATE("Merging ", getHeader(n));
    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        n, fusion_group, *aliasDb);
  }
  return fusion_group;
}
std::pair<graph_node_list::iterator, bool> createFusionGroup(
    Node* fusion_node,
    AliasDb* aliasDb) {
  fusion_node = getOrCreateStaticSubgraph(fusion_node, aliasDb);

  GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n");
  auto inputs =
      sortReverseTopological(fusion_node->inputs(), fusion_node->owningBlock());
  for (auto input : inputs) {
    debugDumpFusionGroup("Current fusion group: ", fusion_node);
    GRAPH_DEBUG("Trying to merge: ", *input->node());
    if (auto maybe_fusion_group =
            tryMerge(fusion_node, input->node(), aliasDb)) {
      // we successfully merged, so the new group's `inputs` may have
      // changed. So rescan the new group for more merging opportunities.
      return std::make_pair(
          maybe_fusion_group.value()->reverseIterator(), true);
    }
  }

  return std::make_pair(++fusion_node->reverseIterator(), false);
}

std::pair<graph_node_list::iterator, bool> scanNode(Node* n, AliasDb* aliasDb) {
  GRAPH_DEBUG("Considering node:", *n);

  if (!canHandle(n)) {
    return std::make_pair(++n->reverseIterator(), false);
  }

  return createFusionGroup(n, aliasDb);
}
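// Scan each block bottom-up, growing a fusion group at every fusible node
// and rescanning whenever a merge changes a group's inputs, until a full
// pass makes no changes (a fixed point). Nested blocks are processed
// recursively, then adjacent groups are merged pairwise where aliasing
// permits.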
void createFusionGroups(Block* block, AliasDb* aliasDb) {
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
      bool changed;
      std::tie(it, changed) = scanNode(*it, aliasDb);
      any_changed |= changed;
    }
  }

  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      createFusionGroups(b, aliasDb);
    }
  }

  // Try to merge adjacent fusion groups together. Because we have only merged
  // by looking at graph inputs, without this we would not attempt to merge
  // adjacent fusion groups that don't have a dependency on each other.

  std::vector<Node*> initial_fusion_groups;
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::StaticSubgraph) {
      initial_fusion_groups.push_back(n);
    }
  }

  Node* prev_fusion_group =
      initial_fusion_groups.size() ? initial_fusion_groups[0] : nullptr;

  for (size_t i = 1; i < initial_fusion_groups.size(); ++i) {
    // Try merging the just created fusion group into the previous one.
    // If it did not work, then put the previous fusion group into
    // fusion_groups vector - we will not touch it anymore in this loop.
    // If merging succeeded, save the merged group as the "previous" fusion
    // group so that we can try to merge the next one into it.

    Node* fusion_group = initial_fusion_groups[i];
    debugDumpFusionGroup(
        "Trying to merge into the previous fusion group: ", prev_fusion_group);
    if (auto merged_fusion_group =
            tryMerge(prev_fusion_group, fusion_group, aliasDb)) {
      prev_fusion_group = *merged_fusion_group;
      debugDumpFusionGroup(
          "Successfully merged into the previous fusion group: ",
          prev_fusion_group);
    } else {
      GRAPH_DEBUG("Cannot merge into the previous fusion group");
      prev_fusion_group = fusion_group;
    }
  }
}

} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/static/fusion.h (new file, 11 lines)
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

TORCH_API void fuseStaticSubgraphs(std::shared_ptr<Graph> graph);

} // namespace jit
} // namespace torch
@@ -4,6 +4,7 @@
#include <caffe2/core/scope_guard.h>
#include <caffe2/core/timer.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
@@ -14,14 +15,19 @@
namespace torch {
namespace jit {

namespace {
void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
void PrepareGraphForStaticRuntime(std::shared_ptr<torch::jit::Graph> graph) {
  Inline(*graph);
  ConstantPropagation(graph);
  Canonicalize(graph);
  ConstantPropagation(graph);
  RemoveTensorMutation(graph);
  ConstantPropagation(graph);
  EliminateDeadCode(graph);
}

namespace {
void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
  PrepareGraphForStaticRuntime(graph);
  FuseInferenceOpsForSparseNN(graph);
  ConstantPropagation(graph);
}
@@ -83,6 +83,9 @@ struct TORCH_API InferenceModule {
  void init();
};

TORCH_API void PrepareGraphForStaticRuntime(
    std::shared_ptr<torch::jit::Graph> g);

inline TORCH_API std::shared_ptr<InferenceModule> PrepareForStaticRuntime(
    const torch::jit::Module& m,
    InferenceModuleOptions opts = InferenceModuleOptions()) {
@@ -1,4 +1,6 @@
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
@@ -68,8 +70,23 @@ void initStaticRuntimeBindings(PyObject* module) {
          [](std::shared_ptr<torch::jit::Graph> g) {
            return StaticRuntime(PrepareForStaticRuntime(g));
          })
      .def("_jit_to_static_runtime", [](const torch::jit::Module& m) {
        return StaticRuntime(PrepareForStaticRuntime(m));
      .def(
          "_jit_to_static_runtime",
          [](const torch::jit::Module& m) {
            return StaticRuntime(PrepareForStaticRuntime(m));
          })
      .def(
          "_fuse_to_static_runtime",
          [](torch::jit::Module& module) {
            module.eval();
            module = freeze_module(module);

            Method method = module.get_method("forward");
            auto graph = method.graph();
            fuseStaticSubgraphs(graph);
          })
      .def("_fuse_to_static_runtime", [](std::shared_ptr<torch::jit::Graph> g) {
        fuseStaticSubgraphs(g);
      });
}