Port fuse_linear from pytorch/tvm (#25623)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25623

Port over the fuse_linear pass from the pytorch/tvm project. We'll need this
in the backend-specific quantization pass to match aten::linear and swap
it with quantized linear.
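
In rough terms, the pass rewrites the decomposed form of linear (weight
transpose followed by addmm, or matmul plus add) back into a single
aten::linear call. A minimal sketch of exercising it from Python, mirroring
the new test below (torch._C.parse_ir is assumed to be exposed, as it is in
test_jit.py):

    import torch
    from torch.testing import FileCheck

    # Decomposed linear: transpose the weight, then addmm(bias, input, weight_t).
    graph_str = """
    graph(%input, %weight, %bias, %4):
        %weight_t = aten::t(%weight)
        %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
        return (%res)"""

    graph = torch._C.parse_ir(graph_str)
    torch._C._jit_pass_fuse_linear(graph)
    # After the pass, the t/addmm pair should be gone, replaced by aten::linear.
    FileCheck().check_not("aten::addmm").check("aten::linear").run(graph)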

Test Plan:
python test/test_jit.py 'TestJit.test_fuse_linear'

Imported from OSS

Differential Revision: D17208890

fbshipit-source-id: f4ff3889ae4525797d3b986f46ae37e50ea49116
Author: Jerry Zhang
Date: 2019-09-12 18:38:30 -07:00
Committed by: Facebook Github Bot
Parent: 18a0040fec
Commit: be82239c86
6 changed files with 105 additions and 0 deletions


@@ -423,6 +423,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/fuse_linear.cpp
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp


@@ -1273,6 +1273,36 @@ graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
            .run(str(get_forward(m.sub).graph))

    def test_fuse_linear(self):
        input_strs = ["""
graph(%input, %weight, %bias, %4):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::addmm
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
    return (%res)""", """
graph(%input, %weight, %bias, %4):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::matmul
    # CHECK-NOT: aten::add_
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    %res = aten::add_(%output, %bias, %4)
    return (%res)""", """
graph(%input, %weight):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::matmul
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    return (%output)"""]
        for input_str in input_strs:
            graph = parse_ir(input_str)
            torch._C._jit_pass_fuse_linear(graph)
            FileCheck().run(input_str, graph)

    def test_pattern_based_rewrite(self):
        # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
        # --> mulmul(mulmul(x,y,z), x, y)


@@ -112,6 +112,7 @@ libtorch_sources = [
"torch/csrc/jit/passes/peephole.cpp",
"torch/csrc/jit/passes/python_print.cpp",
"torch/csrc/jit/passes/quantization.cpp",
"torch/csrc/jit/passes/fuse_linear.cpp",
"torch/csrc/jit/passes/remove_expands.cpp",
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
"torch/csrc/jit/passes/shape_analysis.cpp",


@@ -19,6 +19,7 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
@@ -166,6 +167,9 @@ void initJITBindings(PyObject* module) {
"_jit_pass_quant_fusion",
[](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
.def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
.def(
"_jit_pass_fuse_linear",
[](std::shared_ptr<Graph>& g) { return FuseLinear(g); })
.def(
"_jit_pass_quantlint",
[](std::shared_ptr<Graph>& g) { return QuantLinting(g); })


@@ -0,0 +1,52 @@
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

void FuseLinear(std::shared_ptr<Graph>& graph) {
  // %4 carries addmm's beta/alpha arguments; the JIT decomposition of linear
  // uses 1 for both, which is what this rewrite assumes when it drops them.
  std::string addmm_pattern = R"IR(
    graph(%input, %weight, %bias, %4):
        %weight_t = aten::t(%weight)
        %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
        return (%res))IR";
  std::string matmul_add_pattern = R"IR(
    graph(%input, %weight, %bias, %4):
        %weight_t = aten::t(%weight)
        %output = aten::matmul(%input, %weight_t)
        %res = aten::add_(%output, %bias, %4)
        return (%res))IR";
  std::string fused_linear = R"IR(
    graph(%input, %weight, %bias, %4):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

  std::string matmul_pattern = R"IR(
    graph(%input, %weight):
        %weight_t = aten::t(%weight)
        %output = aten::matmul(%input, %weight_t)
        return (%output))IR";
  std::string fused_linear_bias_none = R"IR(
    graph(%input, %weight):
        %bias: Tensor? = prim::Constant()
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

  // Replace the addmm pattern with aten::linear.
  SubgraphRewriter addmm_to_linear;
  addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear);
  addmm_to_linear.runOnGraph(graph);

  // Replace the matmul + add_ pattern with aten::linear.
  SubgraphRewriter matmuladd_to_linear;
  matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear);
  matmuladd_to_linear.runOnGraph(graph);

  // Replace the bare matmul pattern with aten::linear (bias = None).
  SubgraphRewriter matmul_to_linear;
  matmul_to_linear.RegisterRewritePattern(
      matmul_pattern, fused_linear_bias_none);
  matmul_to_linear.runOnGraph(graph);
}

} // namespace jit
} // namespace torch
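
The third rewrite above matches a bare matmul with a transposed weight and no
bias; the replacement materializes a None bias through prim::Constant so the
result still fits aten::linear's (input, weight, bias) schema. A quick way to
observe this from Python (same assumptions as the sketch in the summary):

    import torch

    graph_str = """
    graph(%input, %weight):
        %weight_t = aten::t(%weight)
        %output = aten::matmul(%input, %weight_t)
        return (%output)"""

    graph = torch._C.parse_ir(graph_str)
    torch._C._jit_pass_fuse_linear(graph)
    # Expect a prim::Constant() producing the optional bias feeding aten::linear.
    print(graph)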


@@ -0,0 +1,17 @@
/** \brief Fuse linear patterns into a single at::linear for easier pattern
 * matching in later passes.
 */
#pragma once

#include <torch/csrc/jit/ir.h>

namespace torch {
namespace jit {

/** \brief Match the decomposed linear patterns and fuse each into a single
 * at::linear.
 *
 * This pass fuses the addmm or matmul + add patterns generated by the JIT
 * back into aten::linear. It can be deleted once the JIT emits aten::linear
 * directly.
 */
TORCH_API void FuseLinear(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
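
For context, the summary above explains that this pass exists so the
backend-specific quantization pass only has to match aten::linear rather than
every decomposed variant. A hedged sketch of that ordering, using only the
Python bindings visible in this diff (the helper name is hypothetical):

    import torch

    def fuse_then_quantize(graph):
        # Hypothetical helper: normalize decomposed linear forms first, so the
        # quantization fusion pass only needs one pattern for linear.
        torch._C._jit_pass_fuse_linear(graph)
        torch._C._jit_pass_quant_fusion(graph)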