Support torch.bool in torch.sort + CUDA (#139409)
Summary: This restriction might be outdated, so I'm adding bool support back to see if we pass all the tests. I'm pretty sure CUDA 12 is OK.

Test Plan: CI

Differential Revision: D65282650

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139409
Approved by: https://github.com/zou3519, https://github.com/ngimel, https://github.com/eqy
Committed by: PyTorch MergeBot
Parent: 06f619d999
Commit: e7cf7d00be
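For context, a minimal usage sketch of what the change enables (assumes a CUDA-enabled build; not part of the commit):

    import torch

    # sorting a bool tensor directly on the GPU, previously rejected by a dtype check
    x = torch.tensor([True, False, True, False], device="cuda")
    values, indices = torch.sort(x)
    print(values)                  # tensor([False, False,  True,  True], device='cuda:0')
    print(torch.argsort(x).dtype)  # torch.int64; argsort gains the same dtype coverage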
@@ -63,9 +63,6 @@ void sort_cuda_kernel(
       "The dimension being sorted can not have more than INT_MAX elements.");

   const auto self_dtype = self.dtype();
-  // FIXME: remove this check once cub sort supports bool
-  TORCH_CHECK(self_dtype != ScalarType::Bool,
-      "Sort currently does not support bool dtype on CUDA.");
   TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
       "Sort currently does not support complex dtypes on CUDA.");

@@ -193,8 +193,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(res1val, res1val_cpu.cuda())
             self.assertEqual(res1ind, res1ind_cpu.cuda())

-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):
         sizes = (100, 1000, 10000)
         for ncopies in sizes:
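With the decorator swap, test_stable_sort now also runs for torch.bool on CUDA. A minimal sketch (not the actual test body) of the stability property it exercises:

    import torch

    # stability: equal keys keep their original relative order
    x = torch.tensor([True, False, True, False], device="cuda")
    values, indices = x.sort(stable=True)
    assert indices.tolist() == [1, 3, 0, 2]  # both False elements first, in input order, then both True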
@@ -323,8 +322,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(indices, indices_cont)
             self.assertEqual(values, values_cont)

-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort_against_numpy(self, device, dtype):
         if dtype in floating_types_and(torch.float16, torch.bfloat16):
             inf = float("inf")
@@ -3320,6 +3320,9 @@ def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
     flag = [True, False]
     for dim, descending, stable in product(dims, flag, flag):
         # default schema without stable sort
-        yield SampleInput(small_3d_unique(), dim, descending)
+        if not (dtype == torch.bool and torch.device(device).type == 'cuda'):
+            # bool and cuda requires stable sort for stable results, at least
+            # for the return index
+            yield SampleInput(small_3d_unique(), dim, descending)
         # schema with stable sort, no CUDA support yet
         if torch.device(device).type == 'cpu':
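The new guard reflects that a bool tensor is almost all ties, so without stable=True only the sorted values, not the returned indices, are deterministic. An illustrative sketch (not from the commit):

    import torch

    x = torch.tensor([True, False, True, False])
    v1, i1 = x.sort()              # indices may be any permutation that sorts x
    v2, i2 = x.sort(stable=True)   # indices are uniquely [1, 3, 0, 2]
    assert torch.equal(v1, v2)     # values always agree; i1 == i2 is not guaranteed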
@@ -18477,11 +18480,13 @@ op_db: List[OpInfo] = [
            )),
     OpInfo('sort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_sort,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('unique',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
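The new expectedFailure is scoped to TestCommon.test_non_standard_bool_values, which, as I read it, feeds the op bool tensors whose underlying bytes are not just 0 and 1; a byte-wise (radix-style) key sort can then rank such values differently from a logical comparison. A hypothetical illustration of how such a tensor can arise:

    import torch

    raw = torch.tensor([2, 0, 1], dtype=torch.uint8)
    weird = raw.view(torch.bool)  # bitwise reinterpretation keeps the byte value 2
    print(weird)                  # tensor([ True, False,  True]), but storage holds 2, 0, 1
    # logically 2 and 1 are both True, yet a byte-keyed sort may order them
    # differently, which is presumably why this case is marked expectedFailure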
@@ -19506,12 +19511,14 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_unfold),
     OpInfo('msort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_msort,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('movedim',
            aliases=('moveaxis',),
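torch.msort(x) is documented as equivalent to torch.sort(x, dim=0).values, so it inherits the same bool support on CUDA. A quick sanity check (assumes a CUDA build; not part of the commit):

    import torch

    x = torch.tensor([[True, False], [False, True]], device="cuda")
    assert torch.equal(torch.msort(x), torch.sort(x, dim=0).values)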
@@ -21324,7 +21331,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "argsort",
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_sort,
         supports_out=False,
         supports_autograd=False,
@@ -21335,6 +21342,13 @@ op_db: List[OpInfo] = [
                 "test_variant_consistency_jit",
                 dtypes=(torch.float32,),
             ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_non_standard_bool_values",
+                dtypes=[torch.bool],
+                device_type='cuda',
+            ),
         ),
     ),
     OpInfo(