Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Removed an internal assertion for the optional stable value and instead defaulted to false (#117414)

Fixes #117255.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117414
Approved by: https://github.com/ezyang
commit 4a54ab328c (parent 1872834247), committed by PyTorch MergeBot
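For context, a minimal reproduction of the behavior this change addresses (an illustrative sketch, not part of the commit): before this patch, passing stable=None to sort tripped the internal assertion that is removed below; afterwards the missing value defaults to False, i.e. an ordinary unstable sort.

import torch

x = torch.tensor([3.0, 1.0, 2.0])

# Previously raised an internal assert; with this change the missing
# optional falls back to False (unstable sort).
values, indices = torch.sort(x, stable=None)
print(values)   # tensor([1., 2., 3.])
print(indices)  # tensor([1, 2, 0])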
@@ -72,9 +72,6 @@ TORCH_META_FUNC(topk)
 
 TORCH_META_FUNC2(sort, stable)
 (const Tensor& self, c10::optional<bool> stable, int64_t dim, bool descending) {
-  TORCH_INTERNAL_ASSERT(
-      stable.has_value(),
-      "sort(): c10::optional<bool> for stable has to have value.");
   maybe_wrap_dim(dim, self.dim());
 
   // See issue: https://github.com/pytorch/pytorch/issues/65863
@@ -953,7 +950,7 @@ TORCH_IMPL_FUNC(sort_stable_out)
     indices.zero_();
   } else {
     dim = maybe_wrap_dim(dim, self.dim());
-    sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value());
+    sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
   }
 }
 
@@ -137,6 +137,13 @@ class TestSortAndSelect(TestCase):
         self.assertIsOrdered('descending', x, res2val, res2ind,
                              'random with NaNs')
 
+    def test_sort_stable_none(self):
+        # Called sort with stable=None used to trigger an assertion
+        # See https://github.com/pytorch/pytorch/issues/117255
+        x = torch.ones(10)
+        y = x.sort(stable=None).values
+        self.assertTrue(torch.all(y == torch.ones(10)).item())
+
     @onlyCUDA
     def test_sort_large_slice(self, device):
         # tests direct cub path