Avoid one unnecessary memory allocation in XNNPACK integration. (#35350)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35350

Currently we call input.contiguous() on the input tensor, resulting in an unnecessary allocation and copy whenever the input is not contiguous with regard to the requested memory format. In such scenarios, that call re-allocates and copies the input tensor into contiguous storage, only for this newly allocated tensor to be used as the source of another copy into the final destination. If we instead copy into the destination directly in these circumstances, we save one allocation and one copy.

Differential Revision: D20656798
Test Plan: Imported from OSS
Pulled By: AshkanAliabadi
fbshipit-source-id: 3f8c51df4d1fd386fa9473e7024621a7b7c6e86c
Committed by: Facebook GitHub Bot
Parent: c33ea41f9c
Commit: d0ce94d20e
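
The pattern the commit message describes can be illustrated at the tensor level. A minimal Python sketch, assuming illustrative names (src, dst); the actual change lives in the C++ XNNPACK integration code, not in Python:

    import torch

    # src is non-contiguous for the requested (default) memory format;
    # dst stands in for the operator's pre-allocated destination buffer.
    src = torch.rand(2, 5, 4, 3).permute(0, 3, 2, 1)  # shape (2, 3, 4, 5), non-contiguous
    dst = torch.empty(2, 3, 4, 5)

    # Before: .contiguous() allocates and copies into a temporary, and
    # copy_ then copies that temporary into dst -- two copies, one
    # extra allocation.
    dst.copy_(src.contiguous())

    # After: copy_ already handles non-contiguous sources, so copying
    # straight into dst saves the intermediate allocation and copy.
    dst.copy_(src)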
@@ -20,10 +20,13 @@ class TestXNNPACKOps(TestCase):
     @given(batch_size=st.integers(0, 3),
            data_shape=hu.array_shapes(1, 3, 2, 64),
            weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
-    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
+    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
         data_shape = [batch_size] + list(data_shape)
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((weight_output_dim, data_shape[-1]))
         if use_bias:
             bias = torch.rand((weight_output_dim))
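
The guard in the hunk above exists because torch.channels_last is only defined for rank-4 (NCHW) tensors, while data_shape here can have anywhere from two to four dimensions; torch.preserve_format and torch.contiguous_format work at any rank. A quick sketch of the constraint:

    import torch

    x = torch.rand(2, 3)                                      # 2-D tensor
    x = x.contiguous(memory_format=torch.contiguous_format)   # fine at any rank
    # x.contiguous(memory_format=torch.channels_last)         # RuntimeError: needs rank 4

    y = torch.rand(2, 3, 4, 5)                                # 4-D (NCHW) tensor
    y = y.contiguous(memory_format=torch.channels_last)       # OK: strides become NHWC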
@@ -47,7 +50,8 @@ class TestXNNPACKOps(TestCase):
            pad_h=st.integers(0, 2),
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_conv2d(self,
                     batch_size,
                     input_channels_per_group,
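
The new format parameter is drawn with Hypothesis's st.sampled_from, which picks one element of the given list per generated example; None is included so some examples keep the default layout produced by torch.rand. A self-contained sketch of how that strategy behaves (the function name is illustrative):

    import torch
    from hypothesis import given, strategies as st

    @given(format=st.sampled_from([None, torch.preserve_format,
                                   torch.contiguous_format, torch.channels_last]))
    def check_format_choices(format):
        # Each generated example is either None or a torch.memory_format.
        assert format is None or isinstance(format, torch.memory_format)

    check_format_choices()  # Hypothesis invokes the body across sampled values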
@@ -62,7 +66,8 @@ class TestXNNPACKOps(TestCase):
                     pad_h,
                     pad_w,
                     dilation,
-                    use_bias):
+                    use_bias,
+                    format):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         kernels = (kernel_h, kernel_w)
@@ -75,6 +80,8 @@ class TestXNNPACKOps(TestCase):
                 dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         bias = None
         if use_bias:
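
In test_conv2d the input is always 4-D, so any sampled format can be applied unconditionally. Note that a memory-format conversion changes only the strides (the physical layout), never the logical shape or values, which is why the reference and prepacked paths can still be compared elementwise. A small sketch:

    import torch

    x = torch.rand(1, 3, 8, 8)
    nhwc = x.contiguous(memory_format=torch.channels_last)

    print(x.shape == nhwc.shape)      # True: logical shape unchanged
    print(x.stride(), nhwc.stride())  # (192, 64, 8, 1) vs (192, 1, 24, 3)
    print(torch.equal(x, nhwc))       # True: values are identical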
@@ -95,8 +102,9 @@ class TestXNNPACKSerDes(TestCase):
     @given(batch_size=st.integers(0, 3),
            data_shape=hu.array_shapes(1, 3, 2, 64),
            weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
-    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
+    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
         class Linear(torch.nn.Module):
             def __init__(self, weight, bias=None):
                 super(Linear, self).__init__()
@@ -123,12 +131,16 @@ class TestXNNPACKSerDes(TestCase):
         scripted_linear = torch.jit.script(Linear(weight, bias))
         scripted_linear_clamp_prepacked = torch.jit.script(LinearPrePacked(weight, bias))
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         ref_result = scripted_linear(input_data)
         output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
         torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         buffer = io.BytesIO()
         torch.jit.save(scripted_linear, buffer)
         buffer.seek(0)
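
The SerDes tests round-trip each scripted module through an in-memory buffer exactly as the hunk shows. A minimal, self-contained version of that save/load pattern, using a stock torch.nn.Linear in place of the test's prepacked module:

    import io
    import torch

    scripted = torch.jit.script(torch.nn.Linear(4, 2))

    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)  # serialize the TorchScript module
    buffer.seek(0)                    # rewind before reading it back
    restored = torch.jit.load(buffer)

    x = torch.rand(3, 4)
    torch.testing.assert_allclose(scripted(x), restored(x))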
@@ -154,7 +166,8 @@ class TestXNNPACKSerDes(TestCase):
            pad_h=st.integers(0, 2),
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_conv2d(self,
                     batch_size,
                     input_channels_per_group,
@@ -169,7 +182,8 @@ class TestXNNPACKSerDes(TestCase):
                     pad_h,
                     pad_w,
                     dilation,
-                    use_bias):
+                    use_bias,
+                    format):
         class Conv2D(torch.nn.Module):
             def __init__(self, weight, bias, strides, paddings, dilations, groups):
                 super(Conv2D, self).__init__()
@@ -205,6 +219,8 @@ class TestXNNPACKSerDes(TestCase):
                 dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         bias = None
         if use_bias:
@@ -220,6 +236,8 @@ class TestXNNPACKSerDes(TestCase):
 
         # Serialize the modules and then deserialize
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         buffer = io.BytesIO()
         torch.jit.save(scripted_conv2d, buffer)
         buffer.seek(0)
@@ -246,7 +264,8 @@ class TestXNNPACKSerDes(TestCase):
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
            linear_weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_combined_model(self,
                             batch_size,
                             input_channels_per_group,
@@ -262,7 +281,8 @@ class TestXNNPACKSerDes(TestCase):
                             pad_w,
                             dilation,
                             linear_weight_output_dim,
-                            use_bias):
+                            use_bias,
+                            format):
         class M(torch.nn.Module):
             def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                          strides, paddings, dilations, groups):
@@ -311,6 +331,8 @@ class TestXNNPACKSerDes(TestCase):
                 dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         conv_bias = None
         if use_bias:
@@ -323,7 +345,6 @@ class TestXNNPACKSerDes(TestCase):
                   strides, paddings, dilations, groups)
         linear_input_shape = result.shape[1]
 
-        input_data = input_data.contiguous(memory_format=torch.channels_last)
         linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
         linear_bias = None
         if use_bias:
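
The deletion in the last hunk is the test-side counterpart of the C++ change: the combined conv2d-plus-linear model no longer needs its input forced into channels_last before the linear stage, because the operator now copies directly from whatever layout the preceding op produced. A sketch of the behavior being relied on, using the stock functional linear as a stand-in for the prepacked XNNPACK op:

    import torch

    x = torch.rand(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
    w = torch.rand(8, 4)  # (out_features, in_features)

    y = torch.nn.functional.linear(x, w)  # no manual .contiguous() required
    print(y.shape)  # torch.Size([2, 3, 4, 8])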