Handle 1D input for xnnpack::linear (#54986)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54986

If the input is 1D, xnnpack::linear fails, while aten::linear reshapes it to (1, D) and continues.

Test Plan: buck test //caffe2/test:xnnpack_integration -- TestXNNPACKOps

Reviewed By: kimishpatel

Differential Revision: D27441966

fbshipit-source-id: dfb2c23b91247632e0e3fd2482056a503c246c39
commit 5c3963373a
parent fb1c193eed
committed by Facebook GitHub Bot
@@ -36,6 +36,22 @@ class TestXNNPACKOps(TestCase):
             output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
             torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
+    @given(input_size=st.integers(2, 32),
+           weight_output_dim=st.integers(2, 64),
+           use_bias=st.booleans())
+    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
+        input_data = torch.rand(input_size)
+        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
+        if use_bias:
+            bias = torch.rand((weight_output_dim))
+        else:
+            bias = None
+        ref_result = F.linear(input_data, weight, bias)
+        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
+        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
+        torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
+
+
     @given(batch_size=st.integers(0, 3),
            input_channels_per_group=st.integers(1, 32),
            height=st.integers(5, 64),
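A minimal sketch (not part of this patch) of the convention the fix aligns the prepacked path with: aten::linear already accepts a 1D input of shape (D,), treating it as (1, D) and returning a 1D result. The sizes below are arbitrary illustration values; the prepacked ops themselves are exercised by the test above on XNNPACK-enabled builds.

# Illustration only: shows how aten::linear treats a 1D input, which is the
# reference behavior the new test compares prepacked::linear_clamp_run against.
import torch
import torch.nn.functional as F

D, out_dim = 8, 4                     # arbitrary sizes for the sketch
x = torch.rand(D)                     # 1D input, shape (D,)
w = torch.rand(out_dim, D)            # weight, shape (out_dim, D)
b = torch.rand(out_dim)               # bias, shape (out_dim,)

ref = F.linear(x, w, b)               # aten::linear accepts the 1D input directly
assert ref.shape == (out_dim,)

# The same result via an explicit reshape, spelling out "treats it as (1, D)":
reshaped = F.linear(x.unsqueeze(0), w, b).squeeze(0)
assert torch.allclose(ref, reshaped)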