# pytorch/test/xpu/test_fusion.py
# Owner(s): ["module: intel"]
import itertools
from typing import NamedTuple

import torch
import torch.nn as nn
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

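# A fusible pointwise post-op: the oneDNN attr string, the eager nn.Module it
# corresponds to, any extra scalars (e.g. leaky_relu's negative slope), and an
# algorithm tag (e.g. the "none"/"tanh" GELU approximation).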
class PointwisePostOp(NamedTuple):
attr: str
pointwise_module: nn.Module
scalars: list = []
    algorithm: str = ""

class TestoneDNNFusion(TestCase):
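    # Tables of unary and binary post-ops exercised by the tests below.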
def _unary_list(self):
unary_list = {
"relu": PointwisePostOp("relu", nn.ReLU()),
"sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
"tanh": PointwisePostOp("tanh", nn.Tanh()),
"hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
"swish": PointwisePostOp("swish", nn.SiLU()),
"leaky_relu": PointwisePostOp(
"leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]
),
"hardtanh": PointwisePostOp(
"hardtanh",
nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False),
scalars=[-0.5, 4],
),
"gelu_none": PointwisePostOp(
"gelu", nn.GELU(approximate="none"), algorithm="none"
),
"gelu_tanh": PointwisePostOp(
"gelu", nn.GELU(approximate="tanh"), algorithm="tanh"
),
}
        return unary_list

def _binary_list(self):
binary_list = {
"add": torch.add,
"sub": torch.sub,
"mul": torch.mul,
}
return binary_list
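
    # Eager Linear followed by an activation is the reference; the fused
    # torch.ops.mkldnn._linear_pointwise call must produce the same output.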
def test_linear_unary_fusion_ops(self, device):
class M(nn.Module):
def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.linear = torch.nn.Linear(
in_channels, out_channels, bias=bias, **kwargs
)
                self.unary = unary_fn

def forward(self, x):
x = self.linear(x)
x = self.unary(x)
                return x

for pointwise_info in self._unary_list().values():
options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
for input_shape, bias in options:
with torch.no_grad():
mod = M(
pointwise_info.pointwise_module, input_shape[-1], 10, bias
).eval()
mod = mod.to(device)
v = torch.randn(input_shape)
v = v.to(device)
ref = mod(v)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
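                    # Fused call: (input, weight, bias, attr, scalars, algorithm).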
fused = torch.ops.mkldnn._linear_pointwise(
v,
mod.linear.weight,
mod.linear.bias,
attr,
scalars,
algorithm,
)
self.assertEqual(ref, fused)
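
    # Binary variant of _linear_pointwise: the linear output is combined with
    # a second tensor via add/sub/mul.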
def test_linear_binary_fusion_ops(self, device):
class M(nn.Module):
def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.linear = torch.nn.Linear(
in_channels, out_channels, bias=bias, **kwargs
)
                self.binary = binary_fn

def forward(self, x, other):
x = self.linear(x)
x = self.binary(x, other)
                return x

out_feature = 20
in_feature = 10
for pointwise_name, pointwise_fn in self._binary_list().items():
with torch.no_grad():
input = torch.randn(4, in_feature).to(device)
model = M(pointwise_fn, in_feature, out_feature, True).eval().to(device)
other = torch.randn(4, out_feature).to(device)
ref = model(input, other)
attr = pointwise_name
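                # Fused binary call: (input, other, weight, bias, attr).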
fused = torch.ops.mkldnn._linear_pointwise(
input,
other,
model.linear.weight,
model.linear.bias,
attr,
)
self.assertEqual(ref, fused)
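
    # Conv + activation across 2d/3d convs, bias, dilation, groups, and memory
    # formats; torch.ops.mkldnn._convolution_pointwise must match eager mode.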
    def test_conv_unary_fusion_ops(self, device):
class M(nn.Module):
def __init__(
self,
unary_fn,
dim,
in_channels,
out_channels,
dilation,
groups,
bias,
**kwargs,
):
super().__init__()
self.conv = CONV_MODULES[dim](
in_channels,
out_channels,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs,
)
                self.unary = unary_fn

def forward(self, x):
x = self.conv(x)
x = self.unary(x)
return x
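
        # Spatial sizes for the 2d and 3d test inputs.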
input_shapes = {2: (112, 112), 3: (55, 55, 55)}
for pointwise_info in self._unary_list().values():
for dim in [2, 3]:
channels_last = (
torch.channels_last if dim == 2 else torch.channels_last_3d
)
options = itertools.product(
[True, False],
[1, 2],
[1, 4],
[torch.contiguous_format, channels_last],
)
for bias, dilation, groups, memory_format in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC) + input_shapes[dim]
x = torch.randn(x_shape, dtype=torch.float32).to(
memory_format=memory_format
)
mod = M(
pointwise_info.pointwise_module,
dim,
iC,
oC,
dilation,
groups,
bias,
kernel_size=3,
)
mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        x = x.to(device)
                        mod = mod.to(device)
ref = mod(x)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
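                        # Fused call: conv inputs/params, then (attr, scalars, algorithm).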
fused = torch.ops.mkldnn._convolution_pointwise(
x,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
scalars,
algorithm,
)
self.assertEqual(ref, fused)
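
    # Conv output combined with a second tensor; for "add" this also checks
    # the in-place variant torch.ops.mkldnn._convolution_pointwise_, which
    # accumulates into `other`.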
    def test_conv_binary_fusion_ops(self, device):
class M(nn.Module):
def __init__(
self,
binary_fn,
dim,
in_channels,
out_channels,
dilation,
groups,
bias,
**kwargs,
):
super().__init__()
self.conv = CONV_MODULES[dim](
in_channels,
out_channels,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs,
)
                self.binary = binary_fn

def forward(self, x, other):
x = self.conv(x)
x = self.binary(x, other)
                return x

for pointwise_name, pointwise_fn in self._binary_list().items():
            x = torch.randn(1, 3, 112, 112).to(device)
            mod = M(pointwise_fn, 2, 3, 3, 1, 1, True, kernel_size=3).to(device)
other = torch.randn_like(mod.conv(x))
with torch.no_grad():
ref = mod(x, other)
unary_attr = None
attr = pointwise_name
fused = torch.ops.mkldnn._convolution_pointwise(
x,
other,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
None,
unary_attr,
[],
None,
)
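                # The in-place variant writes the fused result into `other`.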
if attr == "add":
fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
other,
x,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
None,
unary_attr,
[],
None,
)
self.assertEqual(ref, other)
self.assertEqual(ref, fused_inplace)
self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
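

# Register the tests above for the XPU device type only.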
instantiate_device_type_tests(
TestoneDNNFusion, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
run_tests()