mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
improve mkldnn_linear_pointwise_binary performance for contiguous tensor with non default contiguous strides (#132019)
Fixes https://github.com/pytorch/pytorch/issues/131734 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132019 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
This commit is contained in:
committed by
PyTorch MergeBot
parent
40f8db5741
commit
fc6066b80f
@ -268,6 +268,8 @@ static Tensor mkldnn_linear_pointwise_binary(
|
||||
check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias);
|
||||
|
||||
auto input = input_t.contiguous();
|
||||
// Make sure input has default contiguous strides if it's contiguous tensors for better performance.
|
||||
input = may_convert_to_default_contiguous_strides(input);
|
||||
|
||||
auto it_binary = fusion_binary_alg_map().find(attr);
|
||||
TORCH_CHECK(
|
||||
@ -286,6 +288,7 @@ static Tensor mkldnn_linear_pointwise_binary(
|
||||
return output;
|
||||
}
|
||||
auto other_reshaped = other_t.contiguous();
|
||||
other_reshaped = may_convert_to_default_contiguous_strides(other_reshaped);
|
||||
|
||||
if (dim != 2) {
|
||||
std::vector<int64_t> output_size_reshaped = {
|
||||
|
@ -325,11 +325,15 @@ class TestMkldnnFusion(JitTestCase):
|
||||
|
||||
out_feature = 20
|
||||
for pointwise_name, pointwise_fn in self._binary_list().items():
|
||||
options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
|
||||
for input_shape, bias in options:
|
||||
# Tensor with size = [1, 10] and stride = [0, 1] is contiguous tensor
|
||||
# but it's strides is not default contiguous strides.
|
||||
options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
|
||||
for (input_shape, input_stride), bias in options:
|
||||
with torch.no_grad():
|
||||
mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval()
|
||||
v = torch.randn(input_shape)
|
||||
if input_stride is not None:
|
||||
v = v.as_strided(input_shape, input_stride)
|
||||
other = torch.randn(input_shape[:-1] + [out_feature])
|
||||
ref = mod(v, other)
|
||||
attr = pointwise_name
|
||||
|
Reference in New Issue
Block a user