Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add preprocessing that fuses decomposed linear into linear. (#37937)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37937

Sometimes traced models don't preserve aten::linear ops; they are decomposed into addmm or matmul + add. Adding this preprocessing step helps us catch more lowerable linear nodes. Please see the test for an example.

Test Plan: python test/test_xnnpack_integration.py

Reviewed By: xcheng16

Differential Revision: D21428069

fbshipit-source-id: 6c4ea3335eaf5722852c639fb4ee593746bb408f
Committed by: Facebook GitHub Bot
Parent: 376c9a40dc
Commit: 002f5ec51b
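The diff below touches the XNNPACK rewrite tests (test/test_xnnpack_integration.py, per the Test Plan) and the JIT XNNPACK rewrite pass. As a minimal sketch of the behavior being added, assuming a build with USE_XNNPACK=1 (the module name and shapes here are illustrative, not taken from the patch): a linear that was left decomposed as addmm is now re-fused and lowered to the prepacked XNNPACK linear ops.

import torch
from torch.testing import FileCheck

class AddmmLinear(torch.nn.Module):
    # Illustrative module: the decomposed form of F.linear(x, weight, bias),
    # i.e. addmm(bias, x, weight.t()), as a trace might leave it.
    def __init__(self):
        super(AddmmLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(24, 32), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.rand(24), requires_grad=False)

    def forward(self, x):
        return torch.addmm(self.bias, x, self.weight.t())

scripted = torch.jit.script(AddmmLinear())
scripted.eval()
# With this change the pass first re-fuses the decomposition into aten::linear,
# then swaps in the prepacked XNNPACK ops, so this graph is also lowered.
torch._C._jit_pass_insert_prepacked_ops(scripted._c)
FileCheck().check_count("prepacked::linear_clamp_run", 1, exactly=True).run(scripted.graph)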
@@ -378,39 +378,42 @@ class TestXNNPACKSerDes(TestCase):
                      " XNNPACK must be enabled for these tests."
                      " Please build with USE_XNNPACK=1.")
 class TestXNNPACKRewritePass(TestCase):
+    @staticmethod
+    def validate_transformed_module(
+            # To please flake
+            self,
+            pattern_count_map,
+            data_shape,
+            prepack_removal=False,
+            fuse_clamping_ops=False):
+        module_instance = self
+        scripted_model = torch.jit.script(module_instance)
+        scripted_model.eval()
+        input_data = torch.normal(1, 20, size=data_shape)
+        ref_result = scripted_model(input_data)
+        torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
+        if fuse_clamping_ops or prepack_removal:
+            scripted_model._c = torch._C._freeze_module(scripted_model._c)
+        if fuse_clamping_ops:
+            torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
+        if (prepack_removal):
+            torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
+
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_model, buffer)
+        buffer.seek(0)
+        deserialized_scripted_model = torch.jit.load(buffer)
+        for pattern, v in pattern_count_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+        xnnpack_result = deserialized_scripted_model(input_data)
+        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+
     def test_linear(self):
-        def validate_transformed_module(
-                module_instance,
-                pattern_count_map,
-                data_shape,
-                prepack_removal=False,
-                fuse_clamping_ops=False):
-            scripted_model = torch.jit.script(module_instance)
-            scripted_model.eval()
-            input_data = torch.normal(1, 20, size=data_shape)
-            ref_result = scripted_model(input_data)
-            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
-            if fuse_clamping_ops or prepack_removal:
-                scripted_model._c = torch._C._freeze_module(scripted_model._c)
-            if fuse_clamping_ops:
-                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
-            if (prepack_removal):
-                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
-
-            buffer = io.BytesIO()
-            torch.jit.save(scripted_model, buffer)
-            buffer.seek(0)
-            deserialized_scripted_model = torch.jit.load(buffer)
-            for pattern, v in pattern_count_map.items():
-                if (v == 0):
-                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
-                elif (v == -1):
-                    FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
-                else:
-                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
-            xnnpack_result = deserialized_scripted_model(input_data)
-            torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
-
         data_shape = [2, 3, 32]
         weight_output_dim = 24
         weight_shape = (weight_output_dim, data_shape[-1])
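A note for reading the hunks that follow: pattern_count_map drives FileCheck assertions on the rewritten, serialized-and-reloaded graph, restating the branches in the helper above. A small illustrative map, with keys and values taken from the tests in this diff:

# v == 0  -> pattern must appear at least once   (FileCheck().check)
# v == -1 -> pattern must not appear at all      (FileCheck().check_not)
# v  > 0  -> pattern must appear exactly v times (FileCheck().check_count)
pattern_count_map = {
    "Tensor = prim::CallFunction": -1,     # the Python linear call must be gone
    "prepacked::linear_clamp_prepack": 1,  # exactly one prepack node
    "prepacked::linear_clamp_run": 1,      # exactly one prepacked run node
}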
@@ -436,8 +439,8 @@ class TestXNNPACKRewritePass(TestCase):
         pattern_count_map = {"Tensor = prim::CallFunction": -1,
                              "prepacked::linear_clamp_prepack": 1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(Linear(), pattern_count_map, data_shape)
-        validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(Linear(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
 
         # Conv params
         batch_size = 2
@@ -477,7 +480,7 @@ class TestXNNPACKRewritePass(TestCase):
         pattern_count_map = {"Tensor = aten::conv2d": -1,
                              "prepacked::conv2d_clamp_prepack": 1,
                              "prepacked::conv2d_clamp_run": 1}
-        validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
         conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
@@ -513,11 +516,11 @@ class TestXNNPACKRewritePass(TestCase):
                              "prepacked::conv2d_clamp_run": 1,
                              "prepacked::linear_clamp_prepack": 1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["Tensor = prim::CallFunction"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
-        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
+        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
 
         # Not inplace relu fusion test.
         pattern_count_map = {"aten::relu": 2,
@@ -525,11 +528,16 @@ class TestXNNPACKRewritePass(TestCase):
                              "prepacked::conv2d_clamp_run": 1,
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
+        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
         pattern_count_map["aten::relu"] = -1
-        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
 
         # Inplace relu fusion test.
         pattern_count_map = {"aten::relu": 2,
@@ -537,12 +545,20 @@ class TestXNNPACKRewritePass(TestCase):
                              "prepacked::conv2d_clamp_run": 1,
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M(F.relu_), pattern_count_map, data_shape, prepack_removal=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.relu_),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
         pattern_count_map["aten::relu"] = -1
-        validate_transformed_module(M(F.relu_), pattern_count_map, data_shape,
-                                    prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.relu_),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
 
         # Not inplace hardtanh fusion test.
         pattern_count_map = {"aten::hardtanh": 2,
@@ -550,12 +566,20 @@ class TestXNNPACKRewritePass(TestCase):
                              "prepacked::conv2d_clamp_run": 1,
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.hardtanh),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
         pattern_count_map["aten::hardtanh"] = -1
-        validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape,
-                                    prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.hardtanh),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
 
         # Inplace hardtanh fusion test.
         pattern_count_map = {"aten::hardtanh_": 2,
@@ -563,12 +587,20 @@ class TestXNNPACKRewritePass(TestCase):
                              "prepacked::conv2d_clamp_run": 1,
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.hardtanh_),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
         pattern_count_map["aten::hardtanh_"] = -1
-        validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape,
-                                    prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            M(F.hardtanh_),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
 
         class MFusionAntiPattern(torch.nn.Module):
             def __init__(self):
@@ -591,8 +623,12 @@ class TestXNNPACKRewritePass(TestCase):
                              "aten::relu": -1,  # relu is fused.
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(MFusionAntiPattern(), pattern_count_map, (16, linear_weight_shape[1]),
-                                    prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            MFusionAntiPattern(),
+            pattern_count_map,
+            (16, linear_weight_shape[1]),
+            prepack_removal=True,
+            fuse_clamping_ops=True)
 
         class MFusionAntiPatternParamMinMax(torch.nn.Module):
             def __init__(self):
@@ -615,8 +651,58 @@ class TestXNNPACKRewritePass(TestCase):
         pattern_count_map = {"aten::hardtanh": 1, # hardtanh cannot be.
                              "prepacked::linear_clamp_prepack": -1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(MFusionAntiPatternParamMinMax(), pattern_count_map, (16, linear_weight_shape[1]),
-                                    prepack_removal=True, fuse_clamping_ops=True)
+        TestXNNPACKRewritePass.validate_transformed_module(
+            MFusionAntiPatternParamMinMax(),
+            pattern_count_map,
+            (16, linear_weight_shape[1]),
+            prepack_removal=True,
+            fuse_clamping_ops=True)
+
+    def test_decomposed_linear(self):
+        data_shape = [2, 32]
+        weight_output_dim = 24
+        weight_shape = (weight_output_dim, data_shape[-1])
+
+        class DecomposedLinearAddmm(torch.nn.Module):
+            def __init__(self):
+                super(DecomposedLinearAddmm, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
+
+            def forward(self, x):
+                weight_t = self.weight.t()
+                return torch.addmm(self.bias, x, weight_t)
+
+        class DecomposedLinearMatmulAdd(torch.nn.Module):
+            def __init__(self):
+                super(DecomposedLinearMatmulAdd, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
+
+            def forward(self, x):
+                weight_t = self.weight.t()
+                y = torch.matmul(x, weight_t)
+                res = y.add_(self.bias)
+                return res
+
+        class DecomposedLinearMatmul(torch.nn.Module):
+            def __init__(self):
+                super(DecomposedLinearMatmul, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))), requires_grad=False)
+
+            def forward(self, x):
+                weight_t = self.weight.t()
+                res = torch.matmul(x, weight_t)
+                return res
+
+        # Linear with bias pattern.
+        pattern_count_map = {"Tensor = prim::CallFunction": -1,
+                             "prepacked::linear_clamp_prepack": 1,
+                             "prepacked::linear_clamp_run": 1}
+        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearAddmm(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
+        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
+
 
 if __name__ == "__main__":
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
 #include <torch/csrc/jit/passes/quantization.h>
@@ -19,6 +20,9 @@ namespace jit {
 namespace {
 
 void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
+  // fuse decomposed linear into aten::linear
+  FuseLinear(graph);
+
   std::string linear_before_inline = R"(
     graph(%linear, %input, %weight, %bias):
         %r = prim::CallFunction(%linear, %input, %weight, %bias)
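A minimal sketch of what the FuseLinear preprocessing called above does before the prepack patterns run, assuming the pass is also exposed to Python as torch._C._jit_pass_fuse_linear (the module below is illustrative, not from the patch): a matmul followed by an in-place bias add is recombined into a single aten::linear node, which the subsequent prepack rewrite can then match.

import torch
from torch.testing import FileCheck

class MatmulAddLinear(torch.nn.Module):
    def __init__(self):
        super(MatmulAddLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(24, 32), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.rand(24), requires_grad=False)

    def forward(self, x):
        # x @ W^T followed by an in-place bias add, as a trace might emit it.
        return torch.matmul(x, self.weight.t()).add_(self.bias)

scripted = torch.jit.script(MatmulAddLinear())
scripted.eval()
# Recombine the decomposed pattern into aten::linear on the forward graph.
torch._C._jit_pass_fuse_linear(scripted.graph)
FileCheck().check("aten::linear").run(scripted.graph)
FileCheck().check_not("aten::matmul").run(scripted.graph)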