torchdynamo: add convolution pointwise(unary) fusion kernel (#86581)
Support unary fusion of Convolution with:
- relu
- sigmoid
- tanh
- hardswish
- leaky_relu
- hardtanh
- gelu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86581
Approved by: https://github.com/jgong5, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: d5a7e6db38
Commit: 9a7a49b254
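The test added at the bottom of this diff exercises the new fused kernel directly through torch.ops.mkldnn._convolution_pointwise. As a minimal sketch of that call pattern outside the test harness (assuming a PyTorch build with MKL-DNN/oneDNN enabled; the underscore-prefixed op is internal, and the empty scalars/algorithm defaults are inferred from the test below):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 32, kernel_size=3).eval()
x = torch.randn(1, 3, 112, 112)

with torch.no_grad():
    # Eager reference: convolution followed by a unary pointwise op.
    ref = torch.relu(conv(x))
    # Fused kernel: attr="relu", no extra scalars, no algorithm variant.
    fused = torch.ops.mkldnn._convolution_pointwise(
        x, conv.weight, conv.bias, conv.padding, conv.stride, conv.dilation,
        conv.groups, "relu", [], ""
    )
torch.testing.assert_close(ref, fused)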
@@ -13,8 +13,7 @@ from test_tensorexpr import warmup_and_run_forward
 
 FUSION_GROUP = 'prim::TensorExprGroup'
 
-
-class EltwiseFusionOp(NamedTuple):
+class PointwisePostOp(NamedTuple):
     attr : str
     pointwise_module : nn.Module
     scalars : List = []
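PointwisePostOp pairs the oneDNN post-op attr string with the eager nn.Module used as the reference implementation, plus any extra scalar arguments the post-op takes. A minimal usage sketch, mirroring the _unary_list entries further down (the gelu entries there also pass an algorithm keyword, so the full NamedTuple presumably carries an algorithm field whose definition falls outside this hunk's context):

from typing import List, NamedTuple

import torch.nn as nn

class PointwisePostOp(NamedTuple):
    attr : str
    pointwise_module : nn.Module
    scalars : List = []
    algorithm : str = ""  # assumed field; the gelu entries below pass algorithm=

relu = PointwisePostOp("relu", nn.ReLU())
leaky_relu = PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1])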
@@ -96,31 +95,31 @@ class TestMkldnnFusion(JitTestCase):
         else:
             self.assertGraphContains(graph, kind=conv_node_name)
 
-    def test_conv_eltwise(self):
+    def test_conv_unary_fusion_nnc(self):
         class M(nn.Module):
-            def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
+            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                 super(M, self).__init__()
                 self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
-                self.eltwise = eltwise_fn
+                self.unary = unary_fn
 
             def forward(self, x):
                 x = self.conv(x)
-                x = self.eltwise(x)
+                x = self.unary(x)
                 return x
 
         for memory_format, enabled in [
             [torch.contiguous_format, False],
             [torch.channels_last, True],
         ]:
-            for eltwise_fn in [torch.relu]:
+            for unary_fn in [torch.relu]:
                 for bias in [True, False]:
                     for oC in [1, 10]:
-                        m = M(eltwise_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
+                        m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
                         x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)
 
                         graph = self._check_model(m, x)
                         if enabled:
-                            self.assertFused(graph, ['aten::conv2d', 'aten::' + eltwise_fn.__name__])
+                            self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__])
                             self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                         else:
                             self.assertGraphContains(graph, kind='aten::conv2d')
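The helper _check_model lies outside this diff's context. As a hedged approximation of the kind of check this NNC test relies on, one can script the model, warm up the profiling executor, and scan the last optimized graph for prim::TensorExprGroup nodes (a sketch, not the actual helper):

import torch

def count_fusion_groups(m, x, kind='prim::TensorExprGroup'):
    scripted = torch.jit.script(m.eval())
    with torch.no_grad():
        for _ in range(3):  # profiling runs happen before TensorExpr fusion applies
            scripted(x)
    graph = torch.jit.last_executed_optimized_graph()
    return sum(1 for n in graph.nodes() if n.kind() == kind)

A flat nodes() scan ignores nested blocks (e.g. bailout branches), so this is only a rough check.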
@@ -166,40 +165,39 @@ class TestMkldnnFusion(JitTestCase):
             graph = self._check_model(m, x, trace)
             self.assertGraphContains(graph, kind='aten::_convolution')
 
-    def _eltwise_list(self):
-        eltwise_list = {
-            "relu": EltwiseFusionOp("relu", nn.ReLU()),
-            "sigmoid": EltwiseFusionOp("sigmoid", nn.Sigmoid()),
-            "tanh": EltwiseFusionOp("tanh", nn.Tanh()),
-            "hardswish": EltwiseFusionOp("hardswish", nn.Hardswish()),
-            "leaky_relu": EltwiseFusionOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
-            "hardtanh": EltwiseFusionOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
-            "gelu_none": EltwiseFusionOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
-            "gelu_tanh": EltwiseFusionOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
+    def _unary_list(self):
+        unary_list = {
+            "relu": PointwisePostOp("relu", nn.ReLU()),
+            "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
+            "tanh": PointwisePostOp("tanh", nn.Tanh()),
+            "hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
+            "leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
+            "hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
+            "gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
+            "gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
         }
-        return eltwise_list
+        return unary_list
 
-    def test_linear_eltwise(self):
+    def test_linear_unary_fusion_ops(self):
         class M(nn.Module):
-            def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
+            def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
                 super(M, self).__init__()
                 self.linear = torch.nn.Linear(
                     in_channels, out_channels, bias=bias, **kwargs
                 )
-                self.eltwise = eltwise_fn
+                self.unary = unary_fn
 
             def forward(self, x):
                 x = self.linear(x)
-                x = self.eltwise(x)
+                x = self.unary(x)
                 return x
 
-        for pointwise_name, pointwise_info in self._eltwise_list().items():
+        for pointwise_name, pointwise_info in self._unary_list().items():
             options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
             for input_shape, bias in options:
                 with torch.no_grad():
                     mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
                     v = torch.randn(input_shape)
 
                     ref = mod(v)
                     attr = pointwise_info.attr
                     scalars = pointwise_info.scalars
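The scalars and algorithm fields carry whatever the attr string alone cannot: leaky_relu's negative slope, hardtanh's clamp bounds, and gelu's approximation variant. The eager modules encode the same parameters, as this small illustration shows:

import torch
import torch.nn as nn

x = torch.randn(8)
# scalars=[0.1] is LeakyReLU's negative slope.
assert torch.allclose(nn.LeakyReLU(0.1)(x), torch.nn.functional.leaky_relu(x, 0.1))
# scalars=[-0.5, 4] are Hardtanh's min_val/max_val, i.e. a clamp.
assert torch.allclose(nn.Hardtanh(min_val=-0.5, max_val=4)(x), torch.clamp(x, -0.5, 4))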
@@ -210,5 +208,43 @@ class TestMkldnnFusion(JitTestCase):
                     self.assertEqual(ref, fused)
 
 
+    def test_conv_unary_fusion_ops(self):
+        conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+
+        class M(nn.Module):
+            def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
+                super(M, self).__init__()
+                self.conv = conv_module[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
+                self.unary = unary_fn
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.unary(x)
+                return x
+
+        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
+        for pointwise_name, pointwise_info in self._unary_list().items():
+            for dim in [2, 3]:
+                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
+                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
+                for bias, dilation, groups, memory_format in options:
+                    oC = 32 * groups
+                    iC = 3 * groups
+                    x_shape = (1, iC) + input_shapes[dim]
+                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
+                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3)
+                    mod = mod.to(memory_format=memory_format).eval()
+                    with torch.no_grad():
+                        ref = mod(x)
+                        attr = pointwise_info.attr
+                        scalars = pointwise_info.scalars
+                        algorithm = pointwise_info.algorithm
+                        fused = torch.ops.mkldnn._convolution_pointwise(
+                            x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
+                            mod.conv.groups, attr, scalars, algorithm
+                        )
+                        self.assertEqual(ref, fused)
+
+
 if __name__ == "__main__":
     run_tests()
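The new test validates the fused op numerically against the unfused eager module across every combination of bias, dilation, groups, and memory format, in both 2d and 3d, rather than inspecting a fused graph; the eager forward is the ground truth. Assuming the file is test/test_mkldnn_fusion.py (the class name and imports suggest the mkldnn fusion suite), the new test can be run in isolation with the usual unittest filter: python test/test_mkldnn_fusion.py -k test_conv_unary_fusion_ops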