Avoid one unnecessary memory allocation in XNNPACK integration. (#35350)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35350

Currently we call input.contiguous() on the input tensor resulting in an
unecessary allocation and copy in cases where the input is not contiguous
with regards to the requested memory format.  The reason is that in such
scenarios, this call re-allocates and copies the input tensor into
contiguous storage, only for this newly allocated tensor to be used as
the source of another copy to the final destination.  Instead, if we copy
into the destination directly in such circumstances, we will save an
allocation and a copy.

Differential Revision: D20656798

Test Plan: Imported from OSS

Pulled By: AshkanAliabadi

fbshipit-source-id: 3f8c51df4d1fd386fa9473e7024621a7b7c6e86c
This commit is contained in:
Ashkan Aliabadi
2020-04-02 21:31:44 -07:00
committed by Facebook GitHub Bot
parent c33ea41f9c
commit d0ce94d20e
5 changed files with 64 additions and 27 deletions

View File

@ -20,10 +20,13 @@ class TestXNNPACKOps(TestCase):
@given(batch_size=st.integers(0, 3),
data_shape=hu.array_shapes(1, 3, 2, 64),
weight_output_dim=st.integers(2, 64),
use_bias=st.booleans())
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
data_shape = [batch_size] + list(data_shape)
input_data = torch.rand(data_shape)
if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
input_data = input_data.contiguous(memory_format=format)
weight = torch.rand((weight_output_dim, data_shape[-1]))
if use_bias:
bias = torch.rand((weight_output_dim))
@ -47,7 +50,8 @@ class TestXNNPACKOps(TestCase):
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 2),
use_bias=st.booleans())
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d(self,
batch_size,
input_channels_per_group,
@ -62,7 +66,8 @@ class TestXNNPACKOps(TestCase):
pad_h,
pad_w,
dilation,
use_bias):
use_bias,
format):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
kernels = (kernel_h, kernel_w)
@ -75,6 +80,8 @@ class TestXNNPACKOps(TestCase):
dilations[1] * (kernels[1] - 1) + 1)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
bias = None
if use_bias:
@ -95,8 +102,9 @@ class TestXNNPACKSerDes(TestCase):
@given(batch_size=st.integers(0, 3),
data_shape=hu.array_shapes(1, 3, 2, 64),
weight_output_dim=st.integers(2, 64),
use_bias=st.booleans())
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
class Linear(torch.nn.Module):
def __init__(self, weight, bias=None):
super(Linear, self).__init__()
@ -123,12 +131,16 @@ class TestXNNPACKSerDes(TestCase):
scripted_linear = torch.jit.script(Linear(weight, bias))
scripted_linear_clamp_prepacked = torch.jit.script(LinearPrePacked(weight, bias))
input_data = torch.rand(data_shape)
if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
input_data = input_data.contiguous(memory_format=format)
ref_result = scripted_linear(input_data)
output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
# Serialize the modules and then deserialize
input_data = torch.rand(data_shape)
if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
input_data = input_data.contiguous(memory_format=format)
buffer = io.BytesIO()
torch.jit.save(scripted_linear, buffer)
buffer.seek(0)
@ -154,7 +166,8 @@ class TestXNNPACKSerDes(TestCase):
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 2),
use_bias=st.booleans())
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d(self,
batch_size,
input_channels_per_group,
@ -169,7 +182,8 @@ class TestXNNPACKSerDes(TestCase):
pad_h,
pad_w,
dilation,
use_bias):
use_bias,
format):
class Conv2D(torch.nn.Module):
def __init__(self, weight, bias, strides, paddings, dilations, groups):
super(Conv2D, self).__init__()
@ -205,6 +219,8 @@ class TestXNNPACKSerDes(TestCase):
dilations[1] * (kernels[1] - 1) + 1)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
bias = None
if use_bias:
@ -220,6 +236,8 @@ class TestXNNPACKSerDes(TestCase):
# Serialize the modules and then deserialize
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
buffer = io.BytesIO()
torch.jit.save(scripted_conv2d, buffer)
buffer.seek(0)
@ -246,7 +264,8 @@ class TestXNNPACKSerDes(TestCase):
pad_w=st.integers(0, 2),
dilation=st.integers(1, 2),
linear_weight_output_dim=st.integers(2, 64),
use_bias=st.booleans())
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_combined_model(self,
batch_size,
input_channels_per_group,
@ -262,7 +281,8 @@ class TestXNNPACKSerDes(TestCase):
pad_w,
dilation,
linear_weight_output_dim,
use_bias):
use_bias,
format):
class M(torch.nn.Module):
def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
strides, paddings, dilations, groups):
@ -311,6 +331,8 @@ class TestXNNPACKSerDes(TestCase):
dilations[1] * (kernels[1] - 1) + 1)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
conv_bias = None
if use_bias:
@ -323,7 +345,6 @@ class TestXNNPACKSerDes(TestCase):
strides, paddings, dilations, groups)
linear_input_shape = result.shape[1]
input_data = input_data.contiguous(memory_format=torch.channels_last)
linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
linear_bias = None
if use_bias: