Port sort to structured kernels.

Tracking Issue: #55070

This PR relands #67016, with the modifications discussed in https://github.com/pytorch/pytorch/pull/67015#issuecomment-982004500.

In summary, we call `infer_dense_strides` on the input's strides and pass the result to `set_output`. This means that if one of the outputs is resized (by a `resize_output` call), that output is also restrided using the dense strides inferred from the input.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72058
Approved by: https://github.com/bdhirsh
This commit is contained in:
Yukio Siraichi
2022-04-19 15:54:52 +00:00
committed by PyTorch MergeBot
parent 74454bdb46
commit 8c37a056df
9 changed files with 109 additions and 167 deletions

View File

@ -163,6 +163,23 @@ class TestSortAndSelect(TestCase):
self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
self.assertEqual(im, t0.sort().indices)
@dtypes(torch.float32)
def test_sort_restride(self, device, dtype):
# Input: non-contiguous (stride: 5) 3-element array
tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
# Outputs: 0-dim tensors
# They will need to be resized, which means they will also be
# restrided with the input tensor's strides as base.
values = torch.tensor(0, dtype=dtype, device=device)
indices = torch.tensor(0, dtype=torch.long, device=device)
torch.sort(tensor, out=(values, indices))
# Check: outputs were restrided to dense strides
self.assertEqual(values.stride(), (1,))
self.assertEqual(indices.stride(), (1,))
# Check: 'tensor' indexed by 'indices' is equal to 'values'
self.assertEqual(tensor[indices], values)
def _test_sort_discontiguous(self, device, dtype):
# on CUDA 2048 vs >2048 have different code path for the dim being sorted
sizes = (5, 7, 2049)