Removed an internal assertion for the optional stable value and inste… (#117414)

…ad defaulted to the standard (=false).

Fixes #117255.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117414
Approved by: https://github.com/ezyang
This commit is contained in:
Tobias Ringwald
2024-01-17 02:25:21 +00:00
committed by PyTorch MergeBot
parent 1872834247
commit 4a54ab328c
2 changed files with 8 additions and 4 deletions

View File

@ -72,9 +72,6 @@ TORCH_META_FUNC(topk)
TORCH_META_FUNC2(sort, stable)
(const Tensor& self, c10::optional<bool> stable, int64_t dim, bool descending) {
TORCH_INTERNAL_ASSERT(
stable.has_value(),
"sort(): c10::optional<bool> for stable has to have value.");
maybe_wrap_dim(dim, self.dim());
// See issue: https://github.com/pytorch/pytorch/issues/65863
@ -953,7 +950,7 @@ TORCH_IMPL_FUNC(sort_stable_out)
indices.zero_();
} else {
dim = maybe_wrap_dim(dim, self.dim());
sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value());
sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
}
}

View File

@ -137,6 +137,13 @@ class TestSortAndSelect(TestCase):
self.assertIsOrdered('descending', x, res2val, res2ind,
'random with NaNs')
def test_sort_stable_none(self):
# Called sort with stable=None used to trigger an assertion
# See https://github.com/pytorch/pytorch/issues/117255
x = torch.ones(10)
y = x.sort(stable=None).values
self.assertTrue(torch.all(y == torch.ones(10)).item())
@onlyCUDA
def test_sort_large_slice(self, device):
# tests direct cub path