Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Avoid one unnecessary memory allocation in XNNPACK integration. (#35350)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35350

Currently we call input.contiguous() on the input tensor, which results in an unnecessary allocation and copy whenever the input is not contiguous with respect to the requested memory format. In that case the call re-allocates and copies the input into contiguous storage, only for this newly allocated tensor to serve as the source of yet another copy to the final destination. Copying into the destination directly in such circumstances saves one allocation and one copy.

Test Plan: Imported from OSS

Differential Revision: D20656798

Pulled By: AshkanAliabadi

fbshipit-source-id: 3f8c51df4d1fd386fa9473e7024621a7b7c6e86c
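To make the saved work concrete, here is a minimal sketch of the idea behind the new helper, written against the public ATen API only. It is an illustration under assumptions, not the PR's code: the hypothetical name copy_if_needed stands in for allocate_padded_contiguous_if_needed, plain at::empty stands in for empty_with_tail_padding, and the real helper additionally checks that the tensor is backed by the guarding (padding-aware) allocator.

#include <ATen/ATen.h>

// Sketch only: copy_if_needed is a hypothetical name; the helper in this PR is
// allocate_padded_contiguous_if_needed and also compares allocators.
at::Tensor copy_if_needed(const at::Tensor& input, at::MemoryFormat memory_format) {
  // Already laid out as requested: reuse the input, no allocation, no copy.
  if (input.is_contiguous(memory_format)) {
    return input;
  }
  // Otherwise allocate the destination once and copy the (possibly strided)
  // input straight into it. The old code path called
  // input.contiguous(memory_format) first, paying for a temporary allocation
  // and an extra copy before reaching the destination.
  at::Tensor destination = at::empty(input.sizes(), input.options(), memory_format);
  return destination.copy_(input);
}

In the diff below, the same check-then-copy structure appears in allocate_padded_contiguous_if_needed, with the allocator comparison added so that XNNPACK's tail-padding requirement is still honored.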
Committed by: Facebook GitHub Bot
Parent: c33ea41f9c
Commit: d0ce94d20e
@@ -179,8 +179,8 @@ Tensor run(
     const Tensor& input) {
   using namespace internal;
 
-  const Tensor input_nhwc = input.contiguous(MemoryFormat::ChannelsLast);
-  const Tensor padded_input_nhwc = allocate_padded_if_needed(input_nhwc);
+  const Tensor padded_input_nhwc = allocate_padded_contiguous_if_needed(
+      input, MemoryFormat::ChannelsLast);
 
   TORCH_CHECK(
       usable(padded_input_nhwc),
@@ -17,8 +17,7 @@ Tensor empty_with_tail_padding(
     const IntArrayRef size,
     const caffe2::TypeMeta dtype,
     const c10::MemoryFormat memory_format) {
-  auto* allocator_ptr = get_guarding_allocator();
-
+  auto* const allocator_ptr = get_guarding_allocator();
   const int64_t nelements = prod_intlist(size);
 
   Tensor tensor(
@@ -35,17 +34,31 @@ Tensor empty_with_tail_padding(
   return tensor.resize_(size, memory_format);
 }
 
-Tensor allocate_padded_if_needed(const Tensor& input_contig) {
-  const auto* allocator = input_contig.storage().allocator();
-  const auto* guarding_allocator = get_guarding_allocator();
-  if (allocator == guarding_allocator) {
-    return input_contig;
+Tensor allocate_padded_contiguous_if_needed(
+    const Tensor& input,
+    const c10::MemoryFormat memory_format) {
+  const auto* const allocator = input.storage().allocator();
+  const auto* const guarding_allocator = get_guarding_allocator();
+
+  // If the allocators are the same and the memory is contiguous in the requested
+  // format, then there is no need to reallocate the tensor.
+
+  if ((allocator == guarding_allocator) && input.is_contiguous(memory_format)) {
+    return input;
   }
-  Tensor padded_input =
-      empty_with_tail_padding(input_contig.sizes(), input_contig.options().dtype(),
-                              input_contig.suggest_memory_format());
-  padded_input.copy_(input_contig);
-  return padded_input;
+
+  // If there is a need to reallocate the tensor on the other hand, either because
+  // the allocators are not the same, or the allocators are the same but the input
+  // is not contiguous in the requested format, then reallocate and directly copy
+  // into destination. There is no need to allocate a temporary contiguous memory
+  // only to use it as the source of the copy operation onto our final destination.
+
+  Tensor padded_input = empty_with_tail_padding(
+      input.sizes(),
+      input.options().dtype(),
+      memory_format);
+
+  return padded_input.copy_(input);
 }
 
 } // namespace internal
@@ -9,7 +9,9 @@ namespace native {
 namespace xnnpack {
 namespace internal {
 
-Tensor allocate_padded_if_needed(const Tensor& input_contig);
+Tensor allocate_padded_contiguous_if_needed(
+    const Tensor& input,
+    c10::MemoryFormat memory_format);
 
 // TODO: Remove this function when at::native::empty() is modified to accept a
 // custom memory allocator.
@@ -109,7 +109,8 @@ Tensor run(
     const Tensor& input) {
   using namespace internal;
 
-  const Tensor padded_input = allocate_padded_if_needed(input.contiguous());
+  const Tensor padded_input = allocate_padded_contiguous_if_needed(
+      input, input.suggest_memory_format());
 
   TORCH_CHECK(
       usable(padded_input),
@@ -20,10 +20,13 @@ class TestXNNPACKOps(TestCase):
     @given(batch_size=st.integers(0, 3),
            data_shape=hu.array_shapes(1, 3, 2, 64),
            weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
-    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
+    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
         data_shape = [batch_size] + list(data_shape)
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((weight_output_dim, data_shape[-1]))
         if use_bias:
             bias = torch.rand((weight_output_dim))
@@ -47,7 +50,8 @@ class TestXNNPACKOps(TestCase):
            pad_h=st.integers(0, 2),
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_conv2d(self,
                     batch_size,
                     input_channels_per_group,
@@ -62,7 +66,8 @@ class TestXNNPACKOps(TestCase):
                     pad_h,
                     pad_w,
                     dilation,
-                    use_bias):
+                    use_bias,
+                    format):
         input_channels = input_channels_per_group * groups
         output_channels = output_channels_per_group * groups
         kernels = (kernel_h, kernel_w)
@@ -75,6 +80,8 @@ class TestXNNPACKOps(TestCase):
                  dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         bias = None
         if use_bias:
@@ -95,8 +102,9 @@ class TestXNNPACKSerDes(TestCase):
     @given(batch_size=st.integers(0, 3),
            data_shape=hu.array_shapes(1, 3, 2, 64),
            weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
-    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
+    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias, format):
         class Linear(torch.nn.Module):
             def __init__(self, weight, bias=None):
                 super(Linear, self).__init__()
@@ -123,12 +131,16 @@ class TestXNNPACKSerDes(TestCase):
         scripted_linear = torch.jit.script(Linear(weight, bias))
         scripted_linear_clamp_prepacked = torch.jit.script(LinearPrePacked(weight, bias))
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         ref_result = scripted_linear(input_data)
         output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
         torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
 
         # Serialize the modules and then deserialize
         input_data = torch.rand(data_shape)
+        if ((format is not None) and ((format != torch.channels_last) or (len(data_shape) == 4))):
+            input_data = input_data.contiguous(memory_format=format)
         buffer = io.BytesIO()
         torch.jit.save(scripted_linear, buffer)
         buffer.seek(0)
@@ -154,7 +166,8 @@ class TestXNNPACKSerDes(TestCase):
            pad_h=st.integers(0, 2),
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_conv2d(self,
                     batch_size,
                     input_channels_per_group,
@@ -169,7 +182,8 @@ class TestXNNPACKSerDes(TestCase):
                     pad_h,
                     pad_w,
                     dilation,
-                    use_bias):
+                    use_bias,
+                    format):
         class Conv2D(torch.nn.Module):
             def __init__(self, weight, bias, strides, paddings, dilations, groups):
                 super(Conv2D, self).__init__()
@@ -205,6 +219,8 @@ class TestXNNPACKSerDes(TestCase):
                  dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         bias = None
         if use_bias:
@@ -220,6 +236,8 @@ class TestXNNPACKSerDes(TestCase):
 
         # Serialize the modules and then deserialize
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         buffer = io.BytesIO()
         torch.jit.save(scripted_conv2d, buffer)
         buffer.seek(0)
@@ -246,7 +264,8 @@ class TestXNNPACKSerDes(TestCase):
            pad_w=st.integers(0, 2),
            dilation=st.integers(1, 2),
            linear_weight_output_dim=st.integers(2, 64),
-           use_bias=st.booleans())
+           use_bias=st.booleans(),
+           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
     def test_combined_model(self,
                             batch_size,
                             input_channels_per_group,
@@ -262,7 +281,8 @@ class TestXNNPACKSerDes(TestCase):
                             pad_w,
                             dilation,
                             linear_weight_output_dim,
-                            use_bias):
+                            use_bias,
+                            format):
         class M(torch.nn.Module):
             def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                          strides, paddings, dilations, groups):
@@ -311,6 +331,8 @@ class TestXNNPACKSerDes(TestCase):
                  dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
+        if (format is not None):
+            input_data = input_data.contiguous(memory_format=format)
         conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
         conv_bias = None
         if use_bias:
@@ -323,7 +345,6 @@ class TestXNNPACKSerDes(TestCase):
                   strides, paddings, dilations, groups)
         linear_input_shape = result.shape[1]
 
-        input_data = input_data.contiguous(memory_format=torch.channels_last)
         linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
         linear_bias = None
         if use_bias: