mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Port sort
to structured kernels.
Tracking Issue: #55070. This PR relands #67016 with the modifications discussed in https://github.com/pytorch/pytorch/pull/67015#issuecomment-982004500. In summary, we call `infer_dense_strides` on the input's strides and pass the result to `set_output`. This means that if one of the outputs is resized (by a `resize_output` call), that output is also restrided using the dense strides of the input. Pull Request resolved: https://github.com/pytorch/pytorch/pull/72058 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
74454bdb46
commit
8c37a056df
@ -163,6 +163,23 @@ class TestSortAndSelect(TestCase):
|
||||
self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
|
||||
self.assertEqual(im, t0.sort().indices)
|
||||
|
||||
|
||||
@dtypes(torch.float32)
def test_sort_restride(self, device, dtype):
    """Check that ``torch.sort`` restrides ``out`` tensors it has to resize.

    The input is a strided view and both ``out`` tensors start as 0-dim
    tensors, so the sort must resize them. After the call the outputs are
    expected to have dense strides, and gathering the input with the
    returned indices must reproduce the returned values.
    """
    # Column 0 of a (3, 5) matrix: a 3-element view whose stride is 5.
    src = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]

    # 0-dim outputs cannot hold the result, which forces a resize; the
    # resize is expected to restride them based on the input's strides.
    out_values = torch.tensor(0, dtype=dtype, device=device)
    out_indices = torch.tensor(0, dtype=torch.long, device=device)
    torch.sort(src, out=(out_values, out_indices))

    # Both outputs must end up with dense strides.
    dense_stride = (1,)
    for out in (out_values, out_indices):
        self.assertEqual(out.stride(), dense_stride)

    # Indexing the input with the returned indices must yield the values.
    self.assertEqual(src[out_indices], out_values)
|
||||
|
||||
def _test_sort_discontiguous(self, device, dtype):
|
||||
# on CUDA 2048 vs >2048 have different code path for the dim being sorted
|
||||
sizes = (5, 7, 2049)
|
||||
|
Reference in New Issue
Block a user