mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
improve mkldnn_linear_pointwise performance for contiguous tensor with non default contiguous strides (#114939)
This PR will convert the stride to the default contiguous stride in `mkldnn_linear_pointwise` before calling oneDNN to run into an optimization path similar to https://github.com/pytorch/pytorch/pull/99511. Also refactored the code to provide a common utility function.
https://github.com/pytorch/pytorch/pull/111976 will ignore Dims of value 1 in Require_Stride_order. For a tensor with `size = [1, 1280]`, `stride = [0, 1]`:
**Before the above PR**, it is considered as non-contiguous, thus in the below call, it is converted to `size = [1, 1280]`, `stride = [1280,1]`:
25b83521be/torch/_inductor/ir.py (L5263)
**While after the above PR**, dims of value 1 are ignored so this tensor is already contiguous and we'll feed a tensor with `stride = [0, 1]` to oneDNN, which results in poor performance.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114939
Approved by: https://github.com/jgong5
This commit is contained in:
committed by
PyTorch MergeBot
parent
e666159e2f
commit
80d8a2a237
@ -205,11 +205,15 @@ class TestMkldnnFusion(JitTestCase):
|
||||
return x
|
||||
|
||||
for pointwise_info in self._unary_list().values():
|
||||
options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
|
||||
for input_shape, bias in options:
|
||||
# Tensor with size = [1, 10] and stride = [0, 1] is contiguous tensor
|
||||
# but it's strides is not default contiguous strides.
|
||||
options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
|
||||
for (input_shape, input_stride), bias in options:
|
||||
with torch.no_grad():
|
||||
mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
|
||||
v = torch.randn(input_shape)
|
||||
if input_stride is not None:
|
||||
v = v.as_strided(input_shape, input_stride)
|
||||
ref = mod(v)
|
||||
attr = pointwise_info.attr
|
||||
scalars = pointwise_info.scalars
|
||||
|
Reference in New Issue
Block a user