mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Port fuse_linear from pytorch/tvm (#25623)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25623 Port over fuse_linear pass from pytorch/tvm project, we'll need this in backend specific quantization pass to match aten::linear and swap it with quantized linear Test Plan: python test/test_jit.py 'TestJit.test_fuse_linear' Imported from OSS Differential Revision: D17208890 fbshipit-source-id: f4ff3889ae4525797d3b986f46ae37e50ea49116
This commit is contained in:
committed by
Facebook Github Bot
parent
18a0040fec
commit
be82239c86
@ -423,6 +423,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/fuse_linear.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
|
||||
|
@ -1273,6 +1273,36 @@ graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w
|
||||
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
|
||||
.run(str(get_forward(m.sub).graph))
|
||||
|
||||
def test_fuse_linear(self):
    # Verifies that torch._C._jit_pass_fuse_linear collapses the decomposed
    # forms of a linear layer back into a single aten::linear node.
    # Three input graphs are checked, one per pattern the pass handles:
    #   1) addmm(bias, input, weight^T, beta, alpha)
    #   2) matmul(input, weight^T) followed by in-place add_ of bias
    #   3) matmul(input, weight^T) with no bias at all
    # The "# CHECK" / "# CHECK-NOT" lines inside each IR string are
    # FileCheck directives, evaluated against the graph after the pass runs.
    input_strs = ["""
graph(%input, %weight, %bias, %4):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::addmm
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
    return (%res)""", """
graph(%input, %weight, %bias, %4):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::matmul
    # CHECK-NOT: aten::add_
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    %res = aten::add_(%output, %bias, %4)
    return (%res)""", """
graph(%input, %weight):
    # CHECK-NOT: aten::t
    # CHECK-NOT: aten::matmul
    # CHECK: aten::linear
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    return (%output)"""]
    for input_str in input_strs:
        graph = parse_ir(input_str)
        # Run the fusion pass under test, mutating `graph` in place.
        torch._C._jit_pass_fuse_linear(graph)
        # FileCheck extracts the CHECK directives from input_str and
        # matches them against the transformed graph.
        FileCheck().run(input_str, graph)
|
||||
|
||||
def test_pattern_based_rewrite(self):
|
||||
# mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
|
||||
# --> mulmul(mulmul(x,y,z), x, y)
|
||||
|
@ -112,6 +112,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/passes/peephole.cpp",
|
||||
"torch/csrc/jit/passes/python_print.cpp",
|
||||
"torch/csrc/jit/passes/quantization.cpp",
|
||||
"torch/csrc/jit/passes/fuse_linear.cpp",
|
||||
"torch/csrc/jit/passes/remove_expands.cpp",
|
||||
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
|
||||
"torch/csrc/jit/passes/shape_analysis.cpp",
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/decompose_ops.h>
|
||||
#include <torch/csrc/jit/passes/erase_number_types.h>
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/graph_fuser.h>
|
||||
#include <torch/csrc/jit/passes/inline_fork_wait.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
@ -166,6 +167,9 @@ void initJITBindings(PyObject* module) {
|
||||
"_jit_pass_quant_fusion",
|
||||
[](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
|
||||
.def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
|
||||
.def(
|
||||
"_jit_pass_fuse_linear",
|
||||
[](std::shared_ptr<Graph>& g) { return FuseLinear(g); })
|
||||
.def(
|
||||
"_jit_pass_quantlint",
|
||||
[](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
|
||||
|
52
torch/csrc/jit/passes/fuse_linear.cpp
Normal file
52
torch/csrc/jit/passes/fuse_linear.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void FuseLinear(std::shared_ptr<Graph>& graph) {
|
||||
std::string addmm_pattern = R"IR(
|
||||
graph(%input, %weight, %bias, %4):
|
||||
%weight_t = aten::t(%weight)
|
||||
%res = aten::addmm(%bias, %input, %weight_t, %4, %4)
|
||||
return (%res))IR";
|
||||
std::string matmul_add_pattern = R"IR(
|
||||
graph(%input, %weight, %bias, %4):
|
||||
%weight_t = aten::t(%weight)
|
||||
%output = aten::matmul(%input, %weight_t)
|
||||
%res = aten::add_(%output, %bias, %4)
|
||||
return (%res))IR";
|
||||
std::string fused_linear = R"IR(
|
||||
graph(%input, %weight, %bias, %4):
|
||||
%res = aten::linear(%input, %weight, %bias)
|
||||
return (%res))IR";
|
||||
|
||||
std::string matmul_pattern = R"IR(
|
||||
graph(%input, %weight):
|
||||
%weight_t = aten::t(%weight)
|
||||
%output = aten::matmul(%input, %weight_t)
|
||||
return (%output))IR";
|
||||
std::string fused_linear_bias_none = R"IR(
|
||||
graph(%input, %weight):
|
||||
%bias: Tensor? = prim::Constant()
|
||||
%res = aten::linear(%input, %weight, %bias)
|
||||
return (%res))IR";
|
||||
|
||||
// replace addmm pattern to linear
|
||||
SubgraphRewriter addmm_to_linear;
|
||||
addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear);
|
||||
addmm_to_linear.runOnGraph(graph);
|
||||
|
||||
// replace matmul + add pattern to linear
|
||||
SubgraphRewriter matmuladd_to_linear;
|
||||
matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear);
|
||||
matmuladd_to_linear.runOnGraph(graph);
|
||||
|
||||
// replace matmul with bias=None pattern to linear
|
||||
SubgraphRewriter matmul_to_linear;
|
||||
matmul_to_linear.RegisterRewritePattern(
|
||||
matmul_pattern, fused_linear_bias_none);
|
||||
matmul_to_linear.runOnGraph(graph);
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
17
torch/csrc/jit/passes/fuse_linear.h
Normal file
17
torch/csrc/jit/passes/fuse_linear.h
Normal file
@ -0,0 +1,17 @@
|
||||
/** \brief Fuses decomposed linear patterns into a single aten::linear node
 * for easier pattern matching in later passes.
 */
#pragma once

#include <torch/csrc/jit/ir.h>

namespace torch {
namespace jit {

/** \brief Matches the decomposed forms the JIT emits for a linear layer
 * (aten::addmm, or aten::matmul optionally followed by aten::add_) and
 * rewrites each occurrence in \p graph into a single aten::linear node.
 *
 * This pass can be deleted once the JIT emits aten::linear directly.
 */
TORCH_API void FuseLinear(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch
|
Reference in New Issue
Block a user