Handle 1D input for xnnpack::linear (#54986)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54986

If the input is 1D xnnpack::linear fails while aten::linear makes it (1, D) and continues

Test Plan: buck test //caffe2/test:xnnpack_integration  -- TestXNNPACKOps

Reviewed By: kimishpatel

Differential Revision: D27441966

fbshipit-source-id: dfb2c23b91247632e0e3fd2482056a503c246c39
This commit is contained in:
Akshit Khurana
2021-03-31 14:40:48 -07:00
committed by Facebook GitHub Bot
parent fb1c193eed
commit 5c3963373a
2 changed files with 29 additions and 2 deletions

View File

@ -36,6 +36,22 @@ class TestXNNPACKOps(TestCase):
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
@given(input_size=st.integers(2, 32),
weight_output_dim=st.integers(2, 64),
use_bias=st.booleans())
def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
input_data = torch.rand(input_size)
weight = torch.rand((weight_output_dim, input_data.shape[-1]))
if use_bias:
bias = torch.rand((weight_output_dim))
else:
bias = None
ref_result = F.linear(input_data, weight, bias)
packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
@given(batch_size=st.integers(0, 3),
input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64),