# Owner(s): ["module: intel"] import itertools from typing import NamedTuple import torch import torch.nn as nn from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import run_tests, TestCase CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} class PointwisePostOp(NamedTuple): attr: str pointwise_module: nn.Module scalars: list = [] algorithm: str = "" class TestoneDNNFusion(TestCase): def _unary_list(self): unary_list = { "relu": PointwisePostOp("relu", nn.ReLU()), "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()), "tanh": PointwisePostOp("tanh", nn.Tanh()), "hardswish": PointwisePostOp("hardswish", nn.Hardswish()), "swish": PointwisePostOp("swish", nn.SiLU()), "leaky_relu": PointwisePostOp( "leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1] ), "hardtanh": PointwisePostOp( "hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4], ), "gelu_none": PointwisePostOp( "gelu", nn.GELU(approximate="none"), algorithm="none" ), "gelu_tanh": PointwisePostOp( "gelu", nn.GELU(approximate="tanh"), algorithm="tanh" ), } return unary_list def _binary_list(self): binary_list = { "add": torch.add, "sub": torch.sub, "mul": torch.mul, } return binary_list def test_linear_unary_fusion_ops(self, device): class M(nn.Module): def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs): super().__init__() self.linear = torch.nn.Linear( in_channels, out_channels, bias=bias, **kwargs ) self.unary = unary_fn def forward(self, x): x = self.linear(x) x = self.unary(x) return x for pointwise_info in self._unary_list().values(): options = itertools.product([[2, 3, 10], [2, 10]], [True, False]) for input_shape, bias in options: with torch.no_grad(): mod = M( pointwise_info.pointwise_module, input_shape[-1], 10, bias ).eval() mod = mod.to(device) v = torch.randn(input_shape) v = v.to(device) ref = mod(v) attr = pointwise_info.attr scalars = pointwise_info.scalars algorithm = pointwise_info.algorithm fused = torch.ops.mkldnn._linear_pointwise( v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm, ) self.assertEqual(ref, fused) def test_linear_binary_fusion_ops(self, device): class M(nn.Module): def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): super().__init__() self.linear = torch.nn.Linear( in_channels, out_channels, bias=bias, **kwargs ) self.binary = binary_fn def forward(self, x, other): x = self.linear(x) x = self.binary(x, other) return x out_feature = 20 in_feature = 10 for pointwise_name, pointwise_fn in self._binary_list().items(): with torch.no_grad(): input = torch.randn(4, in_feature).to(device) model = M(pointwise_fn, in_feature, out_feature, True).eval().to(device) other = torch.randn(4, out_feature).to(device) ref = model(input, other) attr = pointwise_name fused = torch.ops.mkldnn._linear_pointwise( input, other, model.linear.weight, model.linear.bias, attr, ) self.assertEqual(ref, fused) def test_conv_unary_fusion_ops(self): class M(nn.Module): def __init__( self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs, ): super().__init__() self.conv = CONV_MODULES[dim]( in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs, ) self.unary = unary_fn def forward(self, x): x = self.conv(x) x = self.unary(x) return x input_shapes = {2: (112, 112), 3: (55, 55, 55)} for pointwise_info in self._unary_list().values(): for dim in [2, 3]: channels_last = ( torch.channels_last if dim == 2 else 
torch.channels_last_3d ) options = itertools.product( [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], ) for bias, dilation, groups, memory_format in options: oC = 32 * groups iC = 3 * groups x_shape = (1, iC) + input_shapes[dim] x = torch.randn(x_shape, dtype=torch.float32).to( memory_format=memory_format ) mod = M( pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3, ) mod = mod.to(memory_format=memory_format).eval() with torch.no_grad(): x = x.to("xpu") mod = mod.to("xpu") ref = mod(x) attr = pointwise_info.attr scalars = pointwise_info.scalars algorithm = pointwise_info.algorithm fused = torch.ops.mkldnn._convolution_pointwise( x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, mod.conv.groups, attr, scalars, algorithm, ) self.assertEqual(ref, fused) def test_conv_binary_fusion_ops(self): class M(nn.Module): def __init__( self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs, ): super().__init__() self.conv = CONV_MODULES[dim]( in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs, ) self.binary = binary_fn def forward(self, x, other): x = self.conv(x) x = self.binary(x, other) return x for pointwise_name, pointwise_fn in self._binary_list().items(): x = torch.randn( ( 1, 3, 112, 112, ) ).to("xpu") mod = M(pointwise_fn, 2, 3, 3, 1, 1, True, kernel_size=3).to("xpu") other = torch.randn_like(mod.conv(x)) with torch.no_grad(): ref = mod(x, other) unary_attr = None attr = pointwise_name fused = torch.ops.mkldnn._convolution_pointwise( x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, mod.conv.groups, attr, None, unary_attr, [], None, ) if attr == "add": fused_inplace = torch.ops.mkldnn._convolution_pointwise_( other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, mod.conv.groups, attr, None, unary_attr, [], None, ) self.assertEqual(ref, other) self.assertEqual(ref, fused_inplace) self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4) instantiate_device_type_tests( TestoneDNNFusion, globals(), only_for="xpu", allow_xpu=True ) if __name__ == "__main__": run_tests()