mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[NNC] enable fusion of conv with elementwise OP (#77157)
## Pitch Enable Conv-Eltwise fusion in NNC. ## Description This PR adds a `FuseConvWithEltwise` pass to fuse convolution with elementwise OP for TE subgraph. This pass will insert prepack and packed run ops for conv2d and enable fusion of conv2d with elementwise OPs. The fused packed run ops is implemented via external call in NNC. ## Code structure Graph rewrite pass related code is placed in: ``` torch/csrc/jit/passes/mkldnn_rewrite.h torch/csrc/jit/passes/mkldnn_rewrite.cpp ``` NNC integration of fused conv-eltwise OP via external call is located in: ``` torch/csrc/jit/tensorexpr/kernel.cpp torch/csrc/jit/tensorexpr/operators/conv2d.h torch/csrc/jit/tensorexpr/operators/conv2d.cpp torch/csrc/jit/tensorexpr/lowerings.cpp torch/csrc/jit/tensorexpr/external_functions.cpp ``` Fused prepack OP context is in: ``` aten/src/ATen/native/mkldnn/Common.h aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp aten/src/ATen/native/mkldnn/OpContext.h aten/src/ATen/native/mkldnn/OpContext.cpp ``` Fused OP implementation is done in: ``` aten/src/ATen/native/mkldnn/ConvPrepack.h aten/src/ATen/native/mkldnn/ConvPrepack.cpp ``` ## OP benchmark for conv-relu The below performance is measured on top of these two PRs to support NHWC: https://github.com/pytorch/pytorch/pull/76948 and https://github.com/pytorch/pytorch/pull/78238. - Measured on Cascade Lake 8280 - Jemalloc enabled - batch_size = 1 - Channels Last format ### Single thread: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> </head> <body link="#0563C1" vlink="#954F72"> shape | time (us)_no_fusion | time (us)_fusion | Gain -- | -- | -- | -- kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 1706.22 | 1371.97 | 19.59% kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 2499.28 | 1571.52 | 37.12% kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 4169.52 | 2738.53 | 34.32% kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 3998.77 | 3085.85 | 22.83% kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 673.73 | 430.81 | 36.06% kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 1101.87 | 801.07 | 27.30% kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 4692.91 | 3116.13 | 33.60% kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 3310.64 | 2503.39 | 24.38% </body> </html> ### 4 threads: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> </head> <body link="#0563C1" vlink="#954F72"> shape | time (us)_no_fusion | time (us)_fusion | Gain -- | -- | -- | -- kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 360.07 | 321.21 | 10.79% kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 391.49 | 323.17 | 17.45% kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 536.4 | 465.97 | 13.13% kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 674.98 | 616.32 | 8.69% kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 160.97 | 70.05 | 56.48% kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 215.81 | 182.6 | 15.39% kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 658.45 | 576.97 | 12.37% kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 702.18 | 566.39 | 19.34% </body> </html> ### 1 socket (28 cores): <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/chunyuan/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> </head> <body link="#0563C1" vlink="#954F72"> shape | time (us)_no_fusion | time (us)_fusion | Gain -- | -- | -- | -- kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 149.92 | 103.78 | 30.78% kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 192.76 | 110.87 | 42.48% kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 160.67 | 127.24 | 20.81% kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 212.45 | 180.55 | 15.02% kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 114.57 | 50.58 | 55.85% kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 198.64 | 70.6 | 64.46% kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 281.35 | 155.8 | 44.62% kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 262.15 | 162.94 | 37.84% </body> </html> ## UT ``` test/test_mkldnn_fusion.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/77157 Approved by: https://github.com/ZolotukhinM
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c83ec8f61
commit
693a8dd04c
118
test/test_mkldnn_fusion.py
Normal file
118
test/test_mkldnn_fusion.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Owner(s): ["module: mkldnn"]
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
from test_tensorexpr import warmup_and_run_forward
|
||||
|
||||
FUSION_GROUP = 'prim::TensorExprGroup'
|
||||
|
||||
|
||||
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
|
||||
class TestMkldnnFusion(JitTestCase):
|
||||
def assertFused(self, graph, fused_patterns):
|
||||
for pat in fused_patterns:
|
||||
self.assertGraphContainsExactly(graph, pat, 0)
|
||||
|
||||
def _check_model(self, m, x):
|
||||
old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
|
||||
torch._C._debug_set_fusion_group_inlining(False)
|
||||
|
||||
old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
|
||||
old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(False)
|
||||
|
||||
m.eval()
|
||||
with torch.no_grad():
|
||||
script = torch.jit.script(m)
|
||||
script = torch.jit.freeze(script)
|
||||
|
||||
with torch.no_grad():
|
||||
y = warmup_and_run_forward(script, x)
|
||||
y = script(x)
|
||||
y_ref = m(x)
|
||||
|
||||
graph = script.graph_for(*x)
|
||||
self.assertEqual(y, y_ref)
|
||||
|
||||
torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
|
||||
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
|
||||
torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)
|
||||
return graph
|
||||
|
||||
def test_single_conv(self):
|
||||
class M(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, bias, **kwargs):
|
||||
super(M, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
res = self.conv(x)
|
||||
return res
|
||||
|
||||
for memory_format, enabled in [
|
||||
[torch.contiguous_format, False],
|
||||
[torch.channels_last, True],
|
||||
]:
|
||||
input_size = 224
|
||||
batch_size = 1
|
||||
kernel_size = 3
|
||||
options = itertools.product([True, False], [1, 2], [1, 4])
|
||||
for bias, dilation, groups in options:
|
||||
iC = 3 * groups
|
||||
oC = 10 * groups
|
||||
m = M(iC,
|
||||
oC,
|
||||
bias,
|
||||
kernel_size=(kernel_size, kernel_size),
|
||||
stride=2,
|
||||
padding=1,
|
||||
dilation=dilation,
|
||||
groups=groups).to(memory_format=memory_format)
|
||||
x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
|
||||
graph = self._check_model(m, x)
|
||||
if enabled:
|
||||
self.assertFused(graph, ['aten::conv2d'])
|
||||
self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
|
||||
else:
|
||||
self.assertGraphContains(graph, kind='aten::conv2d')
|
||||
|
||||
def test_conv_eltwise(self):
|
||||
class M(nn.Module):
|
||||
def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
|
||||
super(M, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
|
||||
self.eltwise = eltwise_fn
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.eltwise(x)
|
||||
return x
|
||||
|
||||
for memory_format, enabled in [
|
||||
[torch.contiguous_format, False],
|
||||
[torch.channels_last, True],
|
||||
]:
|
||||
for eltwise_fn in [torch.relu]:
|
||||
for bias in [True, False]:
|
||||
for oC in [1, 10]:
|
||||
m = M(eltwise_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
|
||||
x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)
|
||||
|
||||
graph = self._check_model(m, x)
|
||||
if enabled:
|
||||
self.assertFused(graph, ['aten::conv2d', 'aten::' + eltwise_fn.__name__])
|
||||
self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
|
||||
else:
|
||||
self.assertGraphContains(graph, kind='aten::conv2d')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
Reference in New Issue
Block a user