Add preprocessing that fuses decomposed linear into linear. (#37937)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37937

Sometimes traced models don't preserve aten::linear ops; they get decomposed
into addmm or matmul + add. Adding this preprocessing step helps us catch more
lowerable linear nodes.
Please see the test for an example.
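
For illustration, a minimal sketch (not part of this diff, and requiring a build with USE_XNNPACK=1 per the test guard below) of the kind of module this catches: its scripted graph contains aten::addmm rather than aten::linear, mirroring the DecomposedLinearAddmm test case in the diff.

    import torch

    # Hypothetical module for illustration: computes a linear layer via
    # torch.addmm, so the scripted graph has no aten::linear node.
    class DecomposedLinear(torch.nn.Module):
        def __init__(self, in_features=32, out_features=24):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.rand(out_features, in_features), requires_grad=False)
            self.bias = torch.nn.Parameter(torch.rand(out_features), requires_grad=False)

        def forward(self, x):
            # bias + x @ weight.t() is equivalent to F.linear(x, weight, bias).
            return torch.addmm(self.bias, x, self.weight.t())

    m = torch.jit.script(DecomposedLinear()).eval()
    # With the FuseLinear preprocessing added in this diff, the pass should
    # first rewrite the addmm pattern to aten::linear and then lower it to
    # prepacked::linear_clamp_prepack / prepacked::linear_clamp_run.
    torch._C._jit_pass_insert_prepacked_ops(m._c)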

Test Plan: python test/test_xnnpack_integration.py

Reviewed By: xcheng16

Differential Revision: D21428069

fbshipit-source-id: 6c4ea3335eaf5722852c639fb4ee593746bb408f
Author: Kimish Patel
Date: 2020-05-07 18:03:48 -07:00
Committed by: Facebook GitHub Bot
Parent: 376c9a40dc
Commit: 002f5ec51b
2 changed files with 142 additions and 52 deletions
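
To see the effect end to end, one can dump the graph before and after the pass; a sketch under the same assumptions as above, mirroring the DecomposedLinearMatmulAdd test case from the diff below.

    import torch

    class MatmulAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.rand(24, 32), requires_grad=False)
            self.bias = torch.nn.Parameter(torch.rand(24), requires_grad=False)

        def forward(self, x):
            # matmul followed by in-place add: another decomposed-linear spelling.
            return torch.matmul(x, self.weight.t()).add_(self.bias)

    m = torch.jit.script(MatmulAdd()).eval()
    print(m.graph)  # expected: aten::matmul / aten::add_, no aten::linear
    torch._C._jit_pass_insert_prepacked_ops(m._c)
    print(m.graph)  # expected: prepacked::linear_clamp_run after the rewrite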

test/test_xnnpack_integration.py

@@ -378,39 +378,42 @@ class TestXNNPACKSerDes(TestCase):
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
class TestXNNPACKRewritePass(TestCase):
@staticmethod
def validate_transformed_module(
# To please flake
self,
pattern_count_map,
data_shape,
prepack_removal=False,
fuse_clamping_ops=False):
module_instance = self
scripted_model = torch.jit.script(module_instance)
scripted_model.eval()
input_data = torch.normal(1, 20, size=data_shape)
ref_result = scripted_model(input_data)
torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
if fuse_clamping_ops or prepack_removal:
scripted_model._c = torch._C._freeze_module(scripted_model._c)
if fuse_clamping_ops:
torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
if (prepack_removal):
torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
deserialized_scripted_model = torch.jit.load(buffer)
for pattern, v in pattern_count_map.items():
if (v == 0):
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
elif (v == -1):
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
else:
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
xnnpack_result = deserialized_scripted_model(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
def test_linear(self):
def validate_transformed_module(
module_instance,
pattern_count_map,
data_shape,
prepack_removal=False,
fuse_clamping_ops=False):
scripted_model = torch.jit.script(module_instance)
scripted_model.eval()
input_data = torch.normal(1, 20, size=data_shape)
ref_result = scripted_model(input_data)
torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
if fuse_clamping_ops or prepack_removal:
scripted_model._c = torch._C._freeze_module(scripted_model._c)
if fuse_clamping_ops:
torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
if (prepack_removal):
torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
deserialized_scripted_model = torch.jit.load(buffer)
for pattern, v in pattern_count_map.items():
if (v == 0):
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
elif (v == -1):
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
else:
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
xnnpack_result = deserialized_scripted_model(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
data_shape = [2, 3, 32]
weight_output_dim = 24
weight_shape = (weight_output_dim, data_shape[-1])
@@ -436,8 +439,8 @@ class TestXNNPACKRewritePass(TestCase):
pattern_count_map = {"Tensor = prim::CallFunction": -1,
"prepacked::linear_clamp_prepack": 1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(Linear(), pattern_count_map, data_shape)
validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(Linear(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
# Conv params
batch_size = 2
@@ -477,7 +480,7 @@ class TestXNNPACKRewritePass(TestCase):
pattern_count_map = {"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": 1,
"prepacked::conv2d_clamp_run": 1}
validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
input_data = torch.rand((batch_size, input_channels, height, width))
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
@@ -513,11 +516,11 @@ class TestXNNPACKRewritePass(TestCase):
"prepacked::conv2d_clamp_run": 1,
"prepacked::linear_clamp_prepack": 1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(M(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["Tensor = prim::CallFunction"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
# Not inplace relu fusion test.
pattern_count_map = {"aten::relu": 2,
@@ -525,11 +528,16 @@ class TestXNNPACKRewritePass(TestCase):
"prepacked::conv2d_clamp_run": 1,
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
pattern_count_map["aten::relu"] = -1
validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(),
pattern_count_map,
data_shape,
prepack_removal=True,
fuse_clamping_ops=True)
# Inplace relu fusion test.
pattern_count_map = {"aten::relu": 2,
@@ -537,12 +545,20 @@ class TestXNNPACKRewritePass(TestCase):
"prepacked::conv2d_clamp_run": 1,
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(M(F.relu_), pattern_count_map, data_shape, prepack_removal=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.relu_),
pattern_count_map,
data_shape,
prepack_removal=True)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
pattern_count_map["aten::relu"] = -1
validate_transformed_module(M(F.relu_), pattern_count_map, data_shape,
prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.relu_),
pattern_count_map,
data_shape,
prepack_removal=True,
fuse_clamping_ops=True)
# Not inplace hardtanh fusion test.
pattern_count_map = {"aten::hardtanh": 2,
@@ -550,12 +566,20 @@ class TestXNNPACKRewritePass(TestCase):
"prepacked::conv2d_clamp_run": 1,
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.hardtanh),
pattern_count_map,
data_shape,
prepack_removal=True)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
pattern_count_map["aten::hardtanh"] = -1
validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape,
prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.hardtanh),
pattern_count_map,
data_shape,
prepack_removal=True,
fuse_clamping_ops=True)
# Inplace hardtanh fusion test.
pattern_count_map = {"aten::hardtanh_": 2,
@@ -563,12 +587,20 @@ class TestXNNPACKRewritePass(TestCase):
"prepacked::conv2d_clamp_run": 1,
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.hardtanh_),
pattern_count_map,
data_shape,
prepack_removal=True)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
pattern_count_map["aten::hardtanh_"] = -1
validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape,
prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
M(F.hardtanh_),
pattern_count_map,
data_shape,
prepack_removal=True,
fuse_clamping_ops=True)
class MFusionAntiPattern(torch.nn.Module):
def __init__(self):
@@ -591,8 +623,12 @@ class TestXNNPACKRewritePass(TestCase):
"aten::relu": -1, # relu is fused.
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(MFusionAntiPattern(), pattern_count_map, (16, linear_weight_shape[1]),
prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
MFusionAntiPattern(),
pattern_count_map,
(16, linear_weight_shape[1]),
prepack_removal=True,
fuse_clamping_ops=True)
class MFusionAntiPatternParamMinMax(torch.nn.Module):
def __init__(self):
@@ -615,8 +651,58 @@ class TestXNNPACKRewritePass(TestCase):
pattern_count_map = {"aten::hardtanh": 1, # hardtanh cannot be.
"prepacked::linear_clamp_prepack": -1,
"prepacked::linear_clamp_run": 1}
validate_transformed_module(MFusionAntiPatternParamMinMax(), pattern_count_map, (16, linear_weight_shape[1]),
prepack_removal=True, fuse_clamping_ops=True)
TestXNNPACKRewritePass.validate_transformed_module(
MFusionAntiPatternParamMinMax(),
pattern_count_map,
(16, linear_weight_shape[1]),
prepack_removal=True,
fuse_clamping_ops=True)
def test_decomposed_linear(self):
data_shape = [2, 32]
weight_output_dim = 24
weight_shape = (weight_output_dim, data_shape[-1])
class DecomposedLinearAddmm(torch.nn.Module):
def __init__(self):
super(DecomposedLinearAddmm, self).__init__()
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
def forward(self, x):
weight_t = self.weight.t()
return torch.addmm(self.bias, x, weight_t)
class DecomposedLinearMatmulAdd(torch.nn.Module):
def __init__(self):
super(DecomposedLinearMatmulAdd, self).__init__()
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
def forward(self, x):
weight_t = self.weight.t()
y = torch.matmul(x, weight_t)
res = y.add_(self.bias)
return res
class DecomposedLinearMatmul(torch.nn.Module):
def __init__(self):
super(DecomposedLinearMatmul, self).__init__()
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
def forward(self, x):
weight_t = self.weight.t()
res = torch.matmul(x, weight_t)
return res
# Linear with bias pattern.
pattern_count_map = {"Tensor = prim::CallFunction": -1,
"prepacked::linear_clamp_prepack": 1,
"prepacked::linear_clamp_run": 1}
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearAddmm(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
if __name__ == "__main__":

torch/csrc/jit/passes/xnnpack_rewrite.cpp

@@ -5,6 +5,7 @@
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/quantization.h>
@@ -19,6 +20,9 @@ namespace jit {
namespace {
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
// fuse decomposed linear into aten::linear
FuseLinear(graph);
std::string linear_before_inline = R"(
graph(%linear, %input, %weight, %bias):
%r = prim::CallFunction(%linear, %input, %weight, %bias)