Support all unsigned int sizes on unique (#123643)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123643
Approved by: https://github.com/albanD, https://github.com/kit1980
Authored by Edward Z. Yang on 2024-04-10 18:05:40 -07:00
Committed by PyTorch MergeBot
parent 416f532753
commit 8aad72b0d3
11 changed files with 56 additions and 36 deletions
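For context (not part of the commit itself): with the dispatch and instantiation changes below in place, torch.unique is expected to accept the barebones unsigned dtypes directly. A minimal usage sketch, assuming a build that includes this change:

import torch

# Sketch only: uint16/uint32/uint64 inputs should now dispatch through the
# same unique kernels as the other integer types.
x = torch.tensor([3, 1, 3, 2, 1], dtype=torch.uint16)
values, counts = torch.unique(x, return_counts=True)
print(values.dtype)  # expected: torch.uint16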

View File

@@ -51,5 +51,8 @@ void radix_sort_keys(
int64_t end_bit);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)
AT_INSTATIATE_CUB_TEMPLATES(uint16_t, UInt16)
AT_INSTATIATE_CUB_TEMPLATES(uint32_t, UInt32)
AT_INSTATIATE_CUB_TEMPLATES(uint64_t, UInt64)
} // namespace at::cuda::cub

View File

@@ -77,6 +77,9 @@ AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8)
// BFloat16 Radix sort is supported from ROCm 4.5 onwards
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)

View File

@@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
@@ -2255,7 +2256,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
.promote_inputs_to_common_dtype(true)
.build();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "equal_cpu", [&] {
AT_DISPATCH_V2(iter.input_dtype(), "equal_cpu", AT_WRAP([&] {
iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
if (!result) {
return;
@@ -2271,7 +2272,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
other_data += strides[1];
}
});
});
}), kBool, kBFloat16, kHalf, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
return result.load();
}
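Since cpu_equal now dispatches over AT_BAREBONES_UNSIGNED_TYPES as well, an equality check on these dtypes should no longer fall through to an unsupported-dtype error. A hedged sketch of the expected behavior:

import torch

# Sketch only: the CPU equal kernel should now cover uint16/uint32/uint64.
a = torch.tensor([1, 2, 3], dtype=torch.uint32)
b = torch.tensor([1, 2, 3], dtype=torch.uint32)
print(torch.equal(a, b))  # expected: True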

View File

@@ -2,7 +2,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/native/TensorIterator.h>
#include <c10/util/irange.h>
@@ -446,13 +446,13 @@ _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
self, return_inverse, /* return_counts */false);
return std::make_tuple(output, inverse);
}
return AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", [&] AT_WRAP({
// The current CPU implementation of unique always sort due to
// this is faster than hash table
auto [output, inverse, _] = unique_cpu_sorted_template<scalar_t>(
self, return_inverse, /* return_counts */false, IsUnique<scalar_t, /* equal_nan */false>());
return std::make_tuple(output, inverse);
});
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
@@ -460,35 +460,35 @@ _unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, c
if (self.scalar_type() == kBool) {
return unique_cpu_bool_template(self, return_inverse, return_counts);
}
return AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
// The current CPU implementation of unique always sort due to
// this is faster than hash table
return unique_cpu_sorted_template<scalar_t>(
self, return_inverse, return_counts, IsUnique<scalar_t, /* equal_nan */ false>());
});
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
// The current implementation using `dim` always sorts due to unhashable tensors
return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_dim_consecutive_cpu(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
return _unique_dim_cpu_template<scalar_t>(self, dim, true, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_cpu(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
if (!dim.has_value() || (dim.value() == 0 && self.dim() == 1)) {
return AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kBool, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
return unique_consecutive_cpu_template<scalar_t>(self, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
return unique_dim_consecutive_cpu(self, dim.value(), return_inverse, return_counts);
}
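The same V2 dispatch conversion covers the dim-wise and consecutive CPU variants, so those entry points should accept the new dtypes too. A sketch, assuming a build with this commit:

import torch

x = torch.tensor([5, 5, 2, 2, 2, 7], dtype=torch.uint64)
# unique_consecutive goes through the same dispatch list as plain unique.
vals, counts = torch.unique_consecutive(x, return_counts=True)

m = torch.tensor([[1, 2], [1, 2], [3, 4]], dtype=torch.uint16)
# The dim overload routes through unique_dim_cpu, converted above as well.
rows = torch.unique(m, dim=0)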

View File

@@ -5,6 +5,7 @@
#include <ATen/native/Sorting.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
@@ -42,9 +43,8 @@ void _dim_apply(
auto indices_dim_stride = indices.stride(dim);
auto dim_size = values.size(dim);
AT_DISPATCH_ALL_TYPES_AND3(
ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, iter.dtype(),
"sorting_kernel_method_name", [&] {
AT_DISPATCH_V2(
iter.dtype(), "sorting_kernel_method_name", AT_WRAP([&] {
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
auto* values_data_bytes = data[0];
auto* indices_data_bytes = data[1];
@@ -69,7 +69,7 @@ void _dim_apply(
int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
iter.for_each(loop, /*grain_size=*/grain_size);
}
}), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
);
}

View File

@@ -1,6 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/ThrustAllocator.h>
@@ -186,45 +186,45 @@ std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
std::tuple<Tensor, Tensor>
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
auto [output, inverse, _] = internal::unique_cuda_template<scalar_t>(self, false, return_inverse, false);
return std::make_tuple(output, inverse);
});
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
_unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
return internal::unique_cuda_template<scalar_t>(self, false, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] {
return unique_dim_cuda_template<scalar_t>(self, dim, true, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
std::tuple<Tensor, Tensor, Tensor>
unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional<int64_t> dim) {
if (!dim.has_value()) {
return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique", [&] {
return AT_DISPATCH_V2(self.scalar_type(), "unique", AT_WRAP([&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
return internal::unique_cuda_template<scalar_t>(self, true, return_inverse, return_counts);
});
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts);
}
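The CUDA variants receive the same treatment, backed by the new CUB radix-sort instantiations for uint16/uint32/uint64 earlier in the diff. A sketch, guarded on device availability:

import torch

if torch.cuda.is_available():
    x = torch.tensor([9, 9, 1, 4, 4], dtype=torch.uint32, device="cuda")
    # The CUDA implementation always sorts, so the unsigned radix-sort
    # instantiations are what make this call possible.
    vals, inverse = torch.unique(x, return_inverse=True)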

View File

@@ -335,6 +335,9 @@ INSTANTIATE_UNIQUE_CUDA_TEMPLATE(float);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(int16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint32_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint64_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint16_t);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(bool);
INSTANTIATE_UNIQUE_CUDA_TEMPLATE(at::Half);

View File

@@ -84,6 +84,13 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
// - We must not promote uint64 to int64 because this will overflow.
//
// It'll be a bit of work to fix it, so we're punting on it for now.
// However, float promotion is fine, so we handle that.
if (isFloatingType(a)) {
return a;
}
if (isFloatingType(b)) {
return b;
}
TORCH_CHECK(
false,
"Promotion for uint16, uint32, uint64 types is not supported, attempted to promote ",

View File

@@ -64,6 +64,9 @@ i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8
u16 = torch.uint16
u32 = torch.uint32
u64 = torch.uint64
foreach_op_db = (
foreach_unary_op_db +
@@ -659,8 +662,8 @@ meta_function_expected_failures = {
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32, u16, u32, u64},
torch.histc : {f64, f16, bf16, f32},
torch.histogram : {f64, f32},
torch.histogramdd : {f64, f32},
@@ -832,7 +835,7 @@ meta_dispatch_expected_failures = {
aten._histogramdd_from_bin_cts.default : {f32, f64},
aten._histogramdd_from_bin_tensors.default : {f32, f64},
aten._local_scalar_dense.default : {c32, c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten._unique2.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
aten.bincount.default : {i64, i8, i32, i16, u8},
aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.histc.default : {bf16, f32, f64},
@@ -840,8 +843,8 @@ meta_dispatch_expected_failures = {
aten.histogram.bin_ct : {f32, f64},
aten.histogram.bins_tensor : {f32, f64},
aten.kthvalue.default : {i8, f64, i64, f16, bf16, f32, i32, i16, u8},
aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8, u16, u32, u64},
aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
}
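These dtypes are added to the expected-failure sets because the meta backend still has no kernels for the unique family; calling them on meta tensors is expected to keep failing. A sketch of what the test expects:

import torch

x = torch.empty(4, dtype=torch.uint16, device="meta")
try:
    torch.unique(x)
except Exception as e:
    # Expected: a NotImplementedError-style failure from meta dispatch,
    # matching the entries added above.
    print(type(e).__name__)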

View File

@@ -2055,7 +2055,7 @@ class TestTestParametrizationDeviceType(TestCase):
for test_func, name in _get_test_funcs_for_test_class(device_cls):
should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
('test_other' in name and 'y_5' in name) or
('test_three' in name and name.endswith('int16')))
('test_three' in name and name.endswith('_int16')))
self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):

View File

@@ -16793,8 +16793,8 @@ op_db: List[OpInfo] = [
skips=(
)),
OpInfo('unique',
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.uint16, torch.uint32, torch.uint64),
sample_inputs_func=sample_inputs_unique,
supports_out=False,
supports_autograd=False,