Files
pytorch/test/test_xnnpack_integration.py
Aaron Gokaslan 292af3cc89 [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply the ruff rule about implicit string concatenation; it autofixes strings that are all the same type and on the same line. These lines were likely broken up by autoformatters in the past. All fixes are automated using the autofixes in ISC001.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2025-02-04 19:07:04 +00:00

1500 lines
54 KiB
Python

# Owner(s): ["oncall: mobile"]
import io
import itertools
import unittest
from hypothesis import assume, given, strategies as st
import torch
import torch.backends.xnnpack
import torch.testing._internal.hypothesis_utils as hu
from torch.nn import functional as F
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
IS_FBCODE,
run_tests,
slowTest,
TEST_WITH_TSAN,
TestCase,
)
from torch.utils.mobile_optimizer import optimize_for_mobile
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKOps(TestCase):
    """Compare eager ``torch.ops.prepacked`` ops with their functional references.

    Each test draws a configuration through hypothesis, evaluates the
    ``torch.nn.functional`` op and the matching prepack/run pair on the same
    random tensors, and requires both results to agree within
    ``rtol=1e-2`` / ``atol=1e-3``.
    """

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """``linear_clamp_prepack`` + ``linear_clamp_run`` must match ``F.linear``."""
        # Prepend the batch dimension to the drawn shape.
        full_shape = [batch_size, *data_shape]
        x = torch.rand(full_shape)
        weight = torch.rand((weight_output_dim, full_shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None

        reference_out = F.linear(x, weight, bias)
        packed = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        prepacked_out = torch.ops.prepacked.linear_clamp_run(x, packed)

        torch.testing.assert_close(reference_out, prepacked_out, rtol=1e-2, atol=1e-3)

    @given(
        input_size=st.integers(2, 32),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        """Same comparison as ``test_linear`` but with a one-dimensional input."""
        x = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, x.shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None

        reference_out = F.linear(x, weight, bias)
        packed = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        prepacked_out = torch.ops.prepacked.linear_clamp_run(x, packed)

        torch.testing.assert_close(reference_out, prepacked_out, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        """``conv2d_clamp_prepack`` + ``conv2d_clamp_run`` must match ``F.conv2d``."""
        in_channels = input_channels_per_group * groups
        out_channels = output_channels_per_group * groups
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilations = (dilation, dilation)

        # Reject draws where the padded input is smaller than the dilated kernel.
        assume(dilation * (kernel_h - 1) + 1 <= height + 2 * pad_h)
        assume(dilation * (kernel_w - 1) + 1 <= width + 2 * pad_w)

        x = torch.rand((batch_size, in_channels, height, width))
        if format is not None:
            x = x.contiguous(memory_format=format)
        weight = torch.rand(
            (out_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(out_channels) if use_bias else None

        reference_out = F.conv2d(x, weight, bias, stride, padding, dilations, groups)
        packed = torch.ops.prepacked.conv2d_clamp_prepack(
            weight, bias, stride, padding, dilations, groups
        )
        prepacked_out = torch.ops.prepacked.conv2d_clamp_run(x, packed)

        torch.testing.assert_close(reference_out, prepacked_out, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(1, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        """The transposed-conv prepack/run pair must match ``F.conv_transpose2d``."""
        in_channels = input_channels_per_group * groups
        out_channels = output_channels_per_group * groups
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        output_padding = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)

        # Reject draws where the padded input is smaller than the dilated kernel.
        assume(dilation * (kernel_h - 1) + 1 <= height + 2 * pad_h)
        assume(dilation * (kernel_w - 1) + 1 <= width + 2 * pad_w)
        # Output padding has to be strictly below both the stride and the dilation.
        assume(output_pad_h < min(stride_h, dilation))
        assume(output_pad_w < min(stride_w, dilation))

        x = torch.rand((batch_size, in_channels, height, width))
        if format is not None:
            x = x.contiguous(memory_format=format)
        # The transposed weight carries the input channels in its first dimension.
        weight = torch.rand(
            (in_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(out_channels) if use_bias else None

        # Note that groups/dilation is in reverse order from conv2d
        reference_out = F.conv_transpose2d(
            x, weight, bias, stride, padding, output_padding, groups, dilation
        )
        packed = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
            weight, bias, stride, padding, output_padding, dilations, groups
        )
        prepacked_out = torch.ops.prepacked.conv2d_transpose_clamp_run(x, packed)

        torch.testing.assert_close(
            reference_out.contiguous(),
            prepacked_out.contiguous(),
            rtol=1e-2,
            atol=1e-3,
        )
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKSerDes(TestCase):
    """Check prepacked modules before and after a TorchScript save/load round trip.

    Each test scripts a reference module and a prepacked module built from the
    same weights, compares their outputs, then serializes both to an in-memory
    buffer, reloads them, and compares again on a freshly drawn input.
    """

    @staticmethod
    def _save_and_reload(scripted_module):
        """Write ``scripted_module`` to an in-memory buffer and load it back."""
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """Scripted ``F.linear`` and prepacked linear agree, also after reload."""

        class Linear(torch.nn.Module):
            """Reference module wrapping ``F.linear``."""

            def __init__(self, weight, bias=None):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearPrePacked(torch.nn.Module):
            """Module that packs weight and bias once at construction time."""

            def __init__(self, weight, bias=None):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(
                    weight, bias
                )

            def forward(self, x):
                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)

        full_shape = [batch_size, *data_shape]
        weight = torch.rand((weight_output_dim, full_shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None

        reference = torch.jit.script(Linear(weight, bias))
        prepacked = torch.jit.script(LinearPrePacked(weight, bias))

        x = torch.rand(full_shape)
        torch.testing.assert_close(reference(x), prepacked(x), rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        x = torch.rand(full_shape)
        reloaded_reference = self._save_and_reload(reference)
        reloaded_prepacked = self._save_and_reload(prepacked)
        torch.testing.assert_close(
            reloaded_reference(x), reloaded_prepacked(x), rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Scripted ``F.conv2d`` and prepacked conv2d agree, also after reload."""

        class Conv2D(torch.nn.Module):
            """Reference module wrapping ``F.conv2d``."""

            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DPrePacked(torch.nn.Module):
            """Module that packs the convolution parameters at construction time."""

            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
                    weight, bias, strides, paddings, dilations, groups
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        in_channels = input_channels_per_group * groups
        out_channels = output_channels_per_group * groups
        input_shape = (batch_size, in_channels, height, width)
        conv_args = ((stride_h, stride_w), (pad_h, pad_w), (dilation, dilation), groups)

        # Reject draws where the padded input is smaller than the dilated kernel.
        assume(dilation * (kernel_h - 1) + 1 <= height + 2 * pad_h)
        assume(dilation * (kernel_w - 1) + 1 <= width + 2 * pad_w)

        x = torch.rand(input_shape)
        if format is not None:
            x = x.contiguous(memory_format=format)
        weight = torch.rand(
            (out_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(out_channels) if use_bias else None

        reference = torch.jit.script(Conv2D(weight, bias, *conv_args))
        prepacked = torch.jit.script(Conv2DPrePacked(weight, bias, *conv_args))
        torch.testing.assert_close(reference(x), prepacked(x), rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        x = torch.rand(input_shape)
        if format is not None:
            x = x.contiguous(memory_format=format)
        reloaded_reference = self._save_and_reload(reference)
        reloaded_prepacked = self._save_and_reload(prepacked)
        torch.testing.assert_close(
            reloaded_reference(x), reloaded_prepacked(x), rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Scripted ``F.conv_transpose2d`` and its prepacked form agree, also after reload."""

        class Conv2DT(torch.nn.Module):
            """Reference module wrapping ``F.conv_transpose2d``."""

            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        class Conv2DTPrePacked(torch.nn.Module):
            """Module that packs the transposed-conv parameters at construction time."""

            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.packed_weight_bias = (
                    torch.ops.prepacked.conv2d_transpose_clamp_prepack(
                        weight,
                        bias,
                        strides,
                        paddings,
                        output_paddings,
                        dilations,
                        groups,
                    )
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(
                    x, self.packed_weight_bias
                )

        in_channels = input_channels_per_group * groups
        out_channels = output_channels_per_group * groups
        input_shape = (batch_size, in_channels, height, width)
        conv_args = (
            (stride_h, stride_w),
            (pad_h, pad_w),
            (output_pad_h, output_pad_w),
            (dilation, dilation),
            groups,
        )

        # Reject draws where the padded input is smaller than the dilated kernel.
        assume(dilation * (kernel_h - 1) + 1 <= height + 2 * pad_h)
        assume(dilation * (kernel_w - 1) + 1 <= width + 2 * pad_w)
        # Output padding has to be strictly below both the stride and the dilation.
        assume(output_pad_h < min(stride_h, dilation))
        assume(output_pad_w < min(stride_w, dilation))

        x = torch.rand(input_shape)
        if format is not None:
            x = x.contiguous(memory_format=format)
        # The transposed weight carries the input channels in its first dimension.
        weight = torch.rand(
            (in_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(out_channels) if use_bias else None

        reference = torch.jit.script(Conv2DT(weight, bias, *conv_args))
        prepacked = torch.jit.script(Conv2DTPrePacked(weight, bias, *conv_args))
        torch.testing.assert_close(reference(x), prepacked(x), rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        x = torch.rand(input_shape)
        if format is not None:
            x = x.contiguous(memory_format=format)
        reloaded_reference = self._save_and_reload(reference)
        reloaded_prepacked = self._save_and_reload(prepacked)
        torch.testing.assert_close(
            reloaded_reference(x), reloaded_prepacked(x), rtol=1e-2, atol=1e-3
        )

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        linear_weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_combined_model(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        linear_weight_output_dim,
        use_bias,
        format,
    ):
        """A conv2d -> permute -> linear -> relu model agrees with its prepacked twin."""

        class M(torch.nn.Module):
            """Reference model built from ``F.conv2d``, ``F.linear`` and ``F.relu``."""

            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return F.relu(o)

        class MPrePacked(torch.nn.Module):
            """Same model, with conv and linear parameters packed at construction."""

            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = (
                    torch.ops.prepacked.conv2d_clamp_prepack(
                        conv_weight, conv_bias, strides, paddings, dilations, groups
                    )
                )
                self.linear_clamp_run_weight_bias = (
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
                )

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(
                    x, self.conv2d_clamp_run_weight_bias
                )
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(
                    o, self.linear_clamp_run_weight_bias
                )
                return F.relu(o)

        in_channels = input_channels_per_group * groups
        out_channels = output_channels_per_group * groups
        input_shape = (batch_size, in_channels, height, width)
        conv_args = ((stride_h, stride_w), (pad_h, pad_w), (dilation, dilation), groups)

        # Reject draws where the padded input is smaller than the dilated kernel.
        assume(dilation * (kernel_h - 1) + 1 <= height + 2 * pad_h)
        assume(dilation * (kernel_w - 1) + 1 <= width + 2 * pad_w)

        x = torch.rand(input_shape)
        if format is not None:
            x = x.contiguous(memory_format=format)
        conv_weight = torch.rand(
            (out_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = torch.rand(out_channels) if use_bias else None

        # This is done just to find the output shape of the result
        # so that the shape of weight for the following linear layer
        # can be determined.
        linear_in_features = F.conv2d(x, conv_weight, conv_bias, *conv_args).shape[1]
        linear_weight = torch.rand((linear_weight_output_dim, linear_in_features))
        linear_bias = torch.rand(linear_weight_output_dim) if use_bias else None

        model_args = (conv_weight, conv_bias, linear_weight, linear_bias, *conv_args)
        reference = torch.jit.script(M(*model_args))
        prepacked = torch.jit.script(MPrePacked(*model_args))
        torch.testing.assert_close(reference(x), prepacked(x), rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        x = torch.rand(input_shape)
        # NOTE(review): unlike the sibling tests, the reload path always uses
        # channels_last and ignores ``format`` -- confirm this is intentional.
        x = x.contiguous(memory_format=torch.channels_last)
        reloaded_reference = self._save_and_reload(reference)
        reloaded_prepacked = self._save_and_reload(prepacked)
        torch.testing.assert_close(
            reloaded_reference(x), reloaded_prepacked(x), rtol=1e-2, atol=1e-3
        )
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKRewritePass(TestCase):
    """Tests for the JIT passes that rewrite graphs to use ``prepacked::*`` ops.

    Each test builds small modules, runs them through
    ``validate_transformed_module``, and checks both the rewritten graph
    (via ``FileCheck`` patterns) and the numerical output.
    """

    @staticmethod
    def validate_transformed_module(
        # To please flake
        self,
        pattern_count_map,
        data_shape,
        prepack_removal=False,
        fuse_clamping_ops=False,
    ):
        """Rewrite a module with the prepacked-op passes and verify graph and output.

        The module is processed twice, once via ``torch.jit.script`` and once
        via ``torch.jit.trace``. For each variant the reference output is
        recorded before any pass runs, the passes are applied, the model is
        saved to and reloaded from an in-memory buffer, the reloaded graph is
        checked against ``pattern_count_map``, and the reloaded output is
        compared to the reference (``rtol=1e-2``, ``atol=1e-3``).

        Args:
            self: The module under test. This is a static method; callers pass
                the module instance explicitly as the first argument.
            pattern_count_map: Maps a ``FileCheck`` pattern to an expectation:
                ``0`` means the pattern must be present, ``-1`` means it must
                be absent, and any other value is an exact occurrence count.
            data_shape: ``size`` argument for the ``torch.normal(1, 20, ...)``
                call that creates the input tensor.
            prepack_removal: If true, freeze the module and run
                ``_jit_pass_fold_prepacking_ops``.
            fuse_clamping_ops: If true, freeze the module and run
                ``_jit_pass_fuse_clamp_w_prepacked_linear_conv``.
        """
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            # Reference output is taken before the graph is rewritten.
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            # Both optional passes below operate on a frozen module.
            if fuse_clamping_ops or prepack_removal:
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
            if fuse_clamping_ops:
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
            if prepack_removal:
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            # Round-trip through serialization; all checks use the reloaded model.
            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            for pattern, v in pattern_count_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_scripted_model.graph
                    )
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    def test_linear(self):
        """Check the rewrite on linear, conv2d, conv_transpose2d and combined models.

        Also covers folding of the prepack ops and fusion of relu/hardtanh
        (in-place and out-of-place) into the prepacked ops, plus two cases
        where a hardtanh is expected to remain in the graph.
        """
        data_shape = [2, 3, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearNoBias(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, None)

        # Linear with bias pattern.
        # In these maps -1 means "must not appear" and a positive number is an
        # exact occurrence count (see validate_transformed_module).
        pattern_count_map = {
            "Tensor = prim::CallFunction": -1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Linear(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            LinearNoBias(), pattern_count_map, data_shape
        )

        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        output_pad_h = output_pad_w = 0
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (
            output_channels,
            input_channels_per_group,
            kernel_h,
            kernel_w,
        )
        # The transposed-conv weight has the input channels in its first dimension.
        conv_transpose_weight_shape = (
            input_channels,
            output_channels_per_group,
            kernel_h,
            kernel_w,
        )
        conv_bias_shape = output_channels

        class Conv2D(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DT(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_transpose_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2D(), pattern_count_map, data_shape
        )

        # Unused: the transposed-conv check below passes data_shape instead.
        transpose_data_shape = (batch_size, input_channels, height, width)  # noqa: F841
        transpose_pattern_count_map = {
            "Tensor = aten::conv_transpose2d": -1,
            "prepacked::conv2d_transpose_clamp_prepack": 1,
            "prepacked::conv2d_transpose_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2DT(), transpose_pattern_count_map, data_shape
        )

        # Run one eager conv2d only to learn the channel count that feeds the
        # linear layer of the combined model below.
        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class M(torch.nn.Module):
            # conv2d -> activation -> permute -> linear -> activation
            def __init__(self, activation_fn=F.relu):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.conv_bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups
                self.activation_fn = activation_fn

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = self.activation_fn(o)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return self.activation_fn(o)

        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape
        )
        # With prepack_removal=True the prepack calls are expected to be gone
        # from the graph, leaving only the run ops.
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["Tensor = prim::CallFunction"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )

        # Not inplace relu fusion test.
        # Without fusion both relu calls are expected to remain.
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )
        # With fuse_clamping_ops=True no relu is expected to remain.
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace relu fusion test.
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Not inplace hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh_": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh_"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        class MFusionAntiPattern(torch.nn.Module):
            # linear -> relu -> hardtanh: two activations follow a single linear.
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.relu(o)
                o = F.hardtanh(o)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh cannot be.
            "aten::relu": -1,  # relu is fused.
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPattern(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        class MFusionAntiPatternParamMinMax(torch.nn.Module):
            # hardtanh bounds are computed from the input tensor at run time.
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                min = x[0, 0]
                max = min + 10
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.hardtanh(o, min, max)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh cannot be.
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPatternParamMinMax(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

    def test_decomposed_linear(self):
        """Check that addmm, matmul+add_ and bare matmul forms are rewritten to prepacked linear."""
        data_shape = [2, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class DecomposedLinearAddmm(torch.nn.Module):
            # Linear expressed as addmm(bias, x, weight.t()).
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                return torch.addmm(self.bias, x, weight_t)

        class DecomposedLinearMatmulAdd(torch.nn.Module):
            # Linear expressed as matmul followed by an in-place bias add.
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                y = torch.matmul(x, weight_t)
                res = y.add_(self.bias)
                return res

        class DecomposedLinearMatmul(torch.nn.Module):
            # Bare matmul; the bias parameter is created but not used in forward.
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                res = torch.matmul(x, weight_t)
                return res

        # Linear with bias pattern.
        pattern_count_map = {
            "Tensor = prim::CallFunction": -1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearAddmm(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearMatmulAdd(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearMatmul(), pattern_count_map, data_shape
        )
@unittest.skipUnless(
torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class TestXNNPACKConv1dTransformPass(TestCase):
@staticmethod
def validate_transform_conv1d_to_conv2d(
self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape
):
input_data = torch.normal(1, 20, size=data_shape)
for jit_method in ["script", "trace"]:
module_instance = self
if jit_method == "script":
scripted_model = torch.jit.script(module_instance)
else:
scripted_model = torch.jit.trace(module_instance, input_data)
scripted_model.eval()
ref_result = scripted_model(input_data)
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
optimized_scripted_model = optimize_for_mobile(scripted_model)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
deserialized_scripted_model = torch.jit.load(buffer)
for pattern, v in pattern_count_transformed_map.items():
if v == 0:
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
elif v == -1:
FileCheck().check_not(pattern).run(
deserialized_scripted_model.graph
)
else:
FileCheck().check_count(pattern, v, exactly=True).run(
deserialized_scripted_model.graph
)
transformed_result = deserialized_scripted_model(input_data)
torch.testing.assert_close(
ref_result, transformed_result, rtol=1e-2, atol=1e-3
)
optimized_buffer = io.BytesIO()
torch.jit.save(optimized_scripted_model, optimized_buffer)
optimized_buffer.seek(0)
deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
for pattern, v in pattern_count_optimized_map.items():
if v == 0:
FileCheck().check(pattern).run(
deserialized_optimized_scripted_model.graph
)
elif v == -1:
FileCheck().check_not(pattern).run(
deserialized_optimized_scripted_model.graph
)
else:
FileCheck().check_count(pattern, v, exactly=True).run(
deserialized_optimized_scripted_model.graph
)
xnnpack_result = deserialized_optimized_scripted_model(input_data)
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
@unittest.skipIf(IS_FBCODE, "T137513244")
def test_conv1d_basic(self):
batch_size_list = range(1, 3)
input_channels_per_group_list = range(10, 12)
width_list = range(10, 12)
output_channels_per_group_list = range(10, 12)
groups_list = range(1, 3)
kernel_list = range(1, 4)
stride_list = range(1, 3)
padding_list = range(0, 3)
dilation_list = range(1, 3)
for hparams in itertools.product(
batch_size_list,
input_channels_per_group_list,
width_list,
output_channels_per_group_list,
groups_list,
kernel_list,
stride_list,
padding_list,
dilation_list,
):
(
batch_size,
input_channels_per_group,
width,
output_channels_per_group,
groups,
kernel,
stride,
padding,
dilation,
) = hparams
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
conv_bias_shape = output_channels
class Conv1D(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.weight = torch.nn.Parameter(
torch.rand(conv_weight_shape), requires_grad=False
)
self.bias = torch.nn.Parameter(
torch.rand(conv_bias_shape), requires_grad=False
)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
def forward(self, x):
return F.conv1d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
data_shape = (batch_size, input_channels, width)
pattern_count_transformed_map = {
"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": 1,
}
pattern_count_optimized_map = {
"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": -1,
"prepacked::conv2d_clamp_run": 1,
}
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
Conv1D(),
pattern_count_transformed_map,
pattern_count_optimized_map,
data_shape,
)
# See https://github.com/pytorch/pytorch/issues/46066
@slowTest
def test_conv1d_with_relu_fc(self):
batch_size_list = range(1, 3)
input_channels_per_group_list = range(10, 12)
width_list = range(10, 12)
output_channels_per_group_list = range(10, 12)
groups_list = range(1, 3)
kernel_list = range(1, 4)
stride_list = range(1, 3)
padding_list = range(0, 3)
dilation_list = range(1, 3)
output_features_list = range(1, 3)
for hparams in itertools.product(
batch_size_list,
input_channels_per_group_list,
width_list,
output_channels_per_group_list,
groups_list,
kernel_list,
stride_list,
padding_list,
dilation_list,
output_features_list,
):
(
batch_size,
input_channels_per_group,
width,
output_channels_per_group,
groups,
kernel,
stride,
padding,
dilation,
output_features,
) = hparams
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
conv_bias_shape = output_channels
conv_output_width = (
int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
)
fc_weight_shape = (output_features, output_channels * conv_output_width)
fc_bias_shape = output_features
class Net(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv_weight = torch.nn.Parameter(
torch.rand(conv_weight_shape), requires_grad=False
)
self.conv_bias = torch.nn.Parameter(
torch.rand(conv_bias_shape), requires_grad=False
)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.fc_weight = torch.nn.Parameter(
torch.rand(fc_weight_shape), requires_grad=False
)
self.fc_bias = torch.nn.Parameter(
torch.rand(fc_bias_shape), requires_grad=False
)
def forward(self, x):
x = F.conv1d(
x,
self.conv_weight,
self.conv_bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
x = F.relu(x)
x = x.view(x.size(0), -1)
x = F.linear(x, self.fc_weight, self.fc_bias)
return x
data_shape = (batch_size, input_channels, width)
pattern_count_transformed_map = {
"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": 1,
}
pattern_count_optimized_map = {
"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": -1,
"prepacked::conv2d_clamp_run": 1,
}
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
Net(),
pattern_count_transformed_map,
pattern_count_optimized_map,
data_shape,
)
# Script entry point: hand control to the run_tests helper imported from
# torch.testing._internal.common_utils when this file is executed directly.
if __name__ == "__main__":
run_tests()