mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: The JIT **SubgraphRewriter** does not keep the output type after rewriting the old graph. For example, in profiling mode the old graph carries the old operator's tensor shape info, but after replacing the old operator with a new one via **SubgraphRewriter**, that shape info is eliminated. The motivation: I want to replace the PyTorch convolution with a custom convolution. I first register **aten::_convolution** as a profiled node so that its input and output shapes are recorded, and then use a graph rewrite to replace it with **aten::conv2d**, at which point the tensor shape info is eliminated. I want to use the input sizes for some preprocessing before replacing **aten::conv2d** with the custom convolution.

Before rewrite:
```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %x : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::_convolution(%x.1, %weight, %4, %3, %2, %3, %6, %2, %7, %6, %6, %5, %5), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%x, %z, %7) # jit_test.py:24:0
  return (%16)
```

After rewrite using **aten::conv2d** (the shape info on %18 is lost):
```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Tensor = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```

Expected result after replacing **aten::_convolution** with **aten::conv2d**:
```
graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65453
Reviewed By: zdevito
Differential Revision: D31162489
Pulled By: ZolotukhinM
fbshipit-source-id: 0d1c1d607cb612df47c64f173d9f4c9e8b1d6c49
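For context, the rewrite described above is driven by the same `SubgraphRewriter` API exercised in the test file below (`RegisterRewritePattern` and `runOnGraph`). Here is a minimal sketch of the **aten::_convolution** → **aten::conv2d** rewrite, assuming the operator argument lists shown in the dumps; the helper name `rewriteConvToConv2d` and its IR strings are illustrative and not taken from this file:
```
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// Hypothetical helper: replaces aten::_convolution with aten::conv2d using
// the RegisterRewritePattern/runOnGraph API exercised by the tests below.
void rewriteConvToConv2d(std::shared_ptr<torch::jit::Graph>& graph) {
  // aten::_convolution takes 13 inputs, as in the "before" dump.
  std::string pattern = R"IR(
graph(%x, %w, %b, %stride, %padding, %dilation, %transposed, %out_padding, %groups, %benchmark, %deterministic, %cudnn, %allow_tf32):
  %r = aten::_convolution(%x, %w, %b, %stride, %padding, %dilation, %transposed, %out_padding, %groups, %benchmark, %deterministic, %cudnn, %allow_tf32)
  return (%r))IR";

  // aten::conv2d reuses 7 of those inputs, as in the "after" dump.
  std::string replacement = R"IR(
graph(%x, %w, %b, %stride, %padding, %dilation, %transposed, %out_padding, %groups, %benchmark, %deterministic, %cudnn, %allow_tf32):
  %r = aten::conv2d(%x, %w, %b, %stride, %padding, %dilation, %groups)
  return (%r))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph);
}
```
With this patch, the replacement node's output takes over the matched node's type, so `%18` keeps its `Float(2, 3, 20, 20, ...)` annotation as in the expected dump instead of degrading to plain `Tensor`.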
311 lines
7.9 KiB
C++
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace testing;

TEST(SubgraphRewriterTest, FilterMatch) {
  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;

  parseIR(pattern, &pattern_graph, vmap);

  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  auto b_is_one = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 1;
  };

  auto b_is_two = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 2;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);

  // b is constant, so the match will succeed.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, b_is_constant);
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant and the value is one, so the match will succeed.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_one});
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant but the value is not two, so the match will fail.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_two});
    FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g);
  }
}

TEST(SubgraphRewriterTest, FilterNoMatch) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;

  parseIR(pattern, &pattern_graph, vmap);

  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    // b_node is not prim::Assign, so this won't match and we'll skip the
    // rewrite.
    return b_node->kind() == prim::Assign;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph, filter);

  FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
}

TEST(SubgraphRewriterTest, MultiOutput) {
  {
    auto graph = std::make_shared<Graph>();

    // Basic multi-output pattern rewriting.
    parseIR(
        R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  %c = c::ccc(%b)

  %x1, %x2 = a::aaa(%c, %a2)
  %y = b::bbb(%x1)
  %z = d::ddd(%y)
  return (%z))IR",
        graph.get());

    std::string pattern = R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  return (%b, %a2))IR";

    std::string replacement = R"IR(
graph(%a, %b):
  %x, %y = ab::ababab(%a, %b)
  return (%x, %y))IR";

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(pattern, replacement);

    auto g = graph->copy();
    rewriter.runOnGraph(g);
    // The pattern occurs twice in the input graph, so we expect two
    // occurrences of the fused op after the rewrite.
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g);
  }
  {
    auto graph = std::make_shared<Graph>();

    // Mimic a real model case.
    parseIR(
        R"IR(
graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4):
  %a1 = aa::aaa(%x1, %k)
  %b1_1, %b1_2 = bb::bbb(%y1, %a1)
  %a2 = aa::aaa(%x2, %k)
  %b2_1, %b2_2 = bb::bbb(%y2, %a2)
  %a3 = aa::aaa(%x3, %k)
  %b3_1, %b3_2 = bb::bbb(%y3, %a3)
  %a4 = aa::aaa(%x4, %k)
  %b4_1, %b4_2 = bb::bbb(%y4, %a4)
  %c = cc::ccc(%b4_1)
  %d1 = dd::ddd(%b1_2, %m)
  %e1 = ee::eee(%b1_1, %d1)
  %d2 = dd::ddd(%b2_2, %m)
  %e2 = ee::eee(%b2_1, %d2)
  %d3 = dd::ddd(%b3_2, %m)
  %e3 = ee::eee(%b3_1, %d3)
  %d4 = dd::ddd(%b4_2, %m)
  %e4 = ee::eee(%b4_1, %d4)
  return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4)
)IR",
        graph.get());

    std::string pattern = R"IR(
graph(%a, %b, %c, %d):
  %y0 = aa::aaa(%b, %c)
  %y1, %y2 = bb::bbb(%a, %y0)
  %y3 = dd::ddd(%y2, %d)
  return (%y3, %y1))IR";

    std::string replacement = R"IR(
graph(%a, %b, %c, %d):
  %x, %y = ab::ababab(%a, %b, %c, %d)
  return (%x, %y))IR";

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(pattern, replacement);

    auto g = graph->copy();
    rewriter.runOnGraph(g);
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g);
  }
  {
    auto graph = std::make_shared<Graph>();

    // A case where no rewriting should occur due to data dependencies.
    parseIR(
        R"IR(
graph(%x, %y):
  %a = aa::aaa(%x)
  %b = bb::bbb(%a)
  %e = ee::eee(%b)
  %c = cc::ccc(%y)
  %d = dd::ddd(%b, %c)
  %f = ff::fff(%b, %d)
  return (%f)
)IR",
        graph.get());

    std::string pattern = R"IR(
graph(%a, %c):
  %b = bb::bbb(%a)
  %d = dd::ddd(%b, %c)
  return (%d, %b))IR";

    std::string replacement = R"IR(
graph(%a, %c):
  %d, %b = db::fused(%a, %c)
  return (%d, %b))IR";

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(pattern, replacement);

    auto g = graph->copy();
    rewriter.runOnGraph(g);
    // We should not perform the replacement on the given graph due to data
    // dependency constraints: the output %b is used in %e, which precedes the
    // def of the input %c.
    FileCheck().check_not("db::fused")->run(*g);
  }
}

TEST(SubgraphRewriterTest, OutputType) {
  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;

  parseIR(pattern, &pattern_graph, vmap);

  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
graph(%0):
  %a : Float(10, 20) = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c : Float(10, 20) = c::ccc(%a, %b)
  return (%c))IR",
        graph.get());

    // The output has shape info, which should be preserved by the rewrite.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck()
        .check("Float(10, 20) = d::ddd")
        ->check_not("c::ccc")
        ->run(*graph);
  }
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
        graph.get());

    // The output has no shape info, so the replacement's output stays a
    // plain Tensor.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck().check("Tensor = d::ddd")->check_not("c::ccc")->run(*graph);
  }
}

} // namespace jit
} // namespace torch