diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index ad2a0e2f67a1..2118d8f4f871 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -800,27 +800,6 @@ - arg: bool keepdim default: "false" ]] -[[ - name: _th_kthvalue - backends: - - CPU - variants: function - cname: kthvalue - return: argument 0,1 - scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1) - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - long k - - arg: long dim - wrap_dim: self - default: __last_dim - - arg: bool keepdim - default: "false" -]] [[ name: _th_mode variants: function diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp new file mode 100644 index 000000000000..4bc3c2c7e8ba --- /dev/null +++ b/aten/src/ATen/native/Sorting.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +namespace at { +namespace native { + +namespace { + +// maybe these days, one should define a random access iterator and use +// std::sort... +/* Note from TH: + + I cut and pasted (slightly adapted) the quicksort code from + Sedgewick's 1978 "Implementing Quicksort Programs" article + http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf + + It is the state of the art existing implementation. The macros + are here to make as close a match as possible to the pseudocode of + Program 2 p.851 + + Note that other partition schemes exist, and are typically presented + in textbook, but those are less efficient. See e.g. + http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto + + Julien, November 12th 2013 +*/ + +constexpr int64_t MAX_LEVELS = 300; +constexpr int64_t M_SMALL = 10; // Limit for small subfiles + +template +void dim_apply(TensorList tensors, int64_t dim, Fn f) { + AT_ASSERT(tensors.size() > 0); + auto t = tensors[0]; + auto sizes = t.sizes(); + int64_t ndim = t.dim(); + int64_t itersize = 1; + for (int64_t i = 0; i < ndim; i++) { + if (i != dim) { + itersize *= t.size(i); + } + } + parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) { + std::vector narrowed_tensors; + narrowed_tensors.reserve(tensors.size()); + for (int64_t it = i_begin; it < i_end; it++) { + narrowed_tensors.clear(); + for (auto ti : tensors) { + int64_t i = it; + Tensor nt = ti; + for (size_t d = 0; d < ndim; d++) { + if (d != dim) { + // this could be avoided for slower-changing dimensions if done + // better + nt = nt.select((d > dim ? 
1 : 0), i % sizes[d]); + i = i / sizes[d]; + } + } + narrowed_tensors.emplace_back(nt); + } + f(it, narrowed_tensors); + } + }); +} + +template +void quick_select_template( + TensorAccessor arr, + int64_t k, + Fn swap_fn) { + int64_t P, L, R, i, j, swap; + scalar_t rswap, piv; + L = 0; + R = arr.size(0) - 1; + + do { + if (R <= L) // One element only + return; + + if (R == L + 1) { // Two elements only + if (arr[L] > arr[R]) { + swap_fn(L, R); + } + return; + } + + // Use median of three for pivot choice + P = (L + R) >> 1; + swap_fn(P, L + 1); + if (arr[L + 1] > arr[R]) { + swap_fn(L + 1, R); + } + if (arr[L] > arr[R]) { + swap_fn(L, R); + } + if (arr[L + 1] > arr[L]) { + swap_fn(L + 1, L); + } + + i = L + 1; + j = R; + piv = arr[L]; + do { + do + i++; + while (arr[i] < piv); + do + j--; + while (arr[j] > piv); + if (j < i) + break; + swap_fn(i, j); + } while (1); + swap_fn(L, j); + + // Re-set active partition + if (j <= k) + L = i; + if (j >= k) + R = j - 1; + } while (1); +} + +} // namespace + +std::tuple kthvalue_out_cpu( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + // FIXME: This seems bogus, I only do this because it was the old behaviour. + // The reductions are fine, as long as the axis being reduced along + // isn't of 0 elements (and the output has elements). + AT_CHECK( + self.numel() > 0, + "cannot perform reduction function kthvalue", + " on tensor with no elements because the operation does not have an identity"); + AT_CHECK( + k > 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), + "selected index k out of range"); + + _reduction_with_indices_allocate_or_resize_output( + values, indices, self, dim_, keepdim); + if (self.dim() == 0 && self.numel() == 1) { + values.copy_(self); + indices.zero_(); + return std::forward_as_tuple(values, indices); + } + auto tmp_values = self.clone(); + auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong)); + AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] { + dim_apply( + {tmp_values, tmp_indices, values, indices}, + dim, + [&](int64_t i, TensorList tl) { + auto tmp_values = tl[0].accessor(); + auto tmp_indices = tl[1].accessor(); + scalar_t* mode_value = tl[2].data(); + int64_t* mode_index = tl[3].data(); + for (int64_t j = 0; j < tmp_indices.size(0); j++) { + tmp_indices[j] = j; + } + quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) { + std::swap(tmp_values[i], tmp_values[j]); + std::swap(tmp_indices[i], tmp_indices[j]); + }); + *mode_value = tmp_values[k - 1]; + *mode_index = tmp_indices[k - 1]; + }); + }); + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } + return std::forward_as_tuple(values, indices); +} + +std::tuple kthvalue( + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + Tensor values = at::empty({0}, self.options()); + Tensor indices = at::empty({0}, self.options().dtype(kLong)); + at::kthvalue_out(values, indices, self, k, dim, keepdim); + return std::make_tuple(values, indices); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h new file mode 100644 index 000000000000..9cc7afb8b4c1 --- /dev/null +++ b/aten/src/ATen/native/SortingUtils.h @@ -0,0 +1,48 @@ +#pragma once + +namespace at { +namespace native { + +// ensure we get good values and indices for kthvalue, mode, median +// this will always be with the reducing dim as 1-d +static 
void _reduction_with_indices_allocate_or_resize_output( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); + auto result_sizes = self.sizes().vec(); + if (result_sizes.size() > 0) { + result_sizes[dim] = 1; + } + if (values.defined()) { + AT_CHECK( + self.type() == values.type(), + "output values must be of same type as input"); + if (!keepdim && values.dim() == self.dim() - 1) { + // unsqueeze to preserve passed in noncontiguous tensor in resize + values.unsqueeze_(dim); + } + values.resize_(result_sizes); + } else { + values = at::empty(result_sizes, self.options()); + } + if (indices.defined()) { + AT_CHECK( + indices.dtype() == kLong, "output indices must be of scalar type Long"); + AT_CHECK( + indices.device() == self.device(), + "output indices must be on same device as input"); + if (!keepdim && indices.dim() == self.dim() - 1) { + // unsqueeze to preserve passed in noncontiguous tensor in resize + indices.unsqueeze_(dim); + } + indices.resize_(result_sizes); + } else { + indices = at::empty(result_sizes, self.options().dtype(kLong)); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 0f58f4f76c75..eaac69b8b319 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -97,26 +97,6 @@ Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& o return ret; } -std::tuple kthvalue(const Tensor& self, int64_t k, int64_t dim, bool keepdim) { - Tensor values = at::empty({0}, self.options()); - Tensor indices = at::empty({0}, self.options().dtype(kLong)); - return at::native::kthvalue_out(values, indices, self, k, dim, keepdim); -} - -std::tuple kthvalue_out(Tensor& values, Tensor& indices, - const Tensor& self, int64_t k, int64_t dim, bool keepdim) { - AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "kthvalue only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); - dim = maybe_wrap_dim(dim, self.dim()); - if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "kthvalue")) { - AT_ASSERT(values.dim() == 0); - indices.resize_({}).fill_(0); - return std::forward_as_tuple(values, indices); - } else { - return at::legacy::th::_th_kthvalue_out(values, indices, self, k, dim, keepdim); - } -} - std::tuple median(const Tensor& self, int64_t dim, bool keepdim) { Tensor values = at::empty({0}, self.options()); Tensor indices = at::empty({0}, self.options().dtype(kLong)); diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh new file mode 100644 index 000000000000..1f7988d3955a --- /dev/null +++ b/aten/src/ATen/native/cuda/SortingCommon.cuh @@ -0,0 +1,226 @@ +#include +#include +#include +#include +#include +#include +#include +#include // only for THCRoundUp? 
+#include +#include +#include // AddOp + +namespace at { +namespace native { + +#if defined(__HIP_PLATFORM_HCC__) +constexpr int WARP_SIZE = 64; +constexpr int MAX_BLOCK_SIZE = 256; + +#else +constexpr int WARP_SIZE = 32; +constexpr int MAX_BLOCK_SIZE = 1024; +#endif + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +static bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +template +struct ThrustGTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && THCNumerics::isnan(lhs) && + !THCNumerics::isnan(rhs)) || + THCNumerics::gt(lhs, rhs); + } +}; + +template +struct ThrustLTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && THCNumerics::isnan(rhs) && + !THCNumerics::isnan(lhs)) || + THCNumerics::lt(lhs, rhs); + } +}; + +template +__device__ __forceinline__ index_t getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// `base` is the base address of a tensor +// For each slice (defined as a linear point of `out`, from 0 -> +// (sliceSize - 1) * sliceStride, we fill that slice from `0` to +// `sliceSize - 1`. 
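+// Each block handles one slice: the block's linear id selects the slice, and
+// threads stride through it in steps of blockDim.x, writing the within-slice
+// position i at offset i * sliceStride.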
+template +__global__ void fillSliceWithIndex_kernel( + cuda::detail::TensorInfo out, + index_t totalSlices, + index_t sliceSize, + index_t sliceStride) { + index_t slice = getLinearBlockId(); + + if (slice >= totalSlices) { + return; + } + + const uint64_t offset = + cuda::detail::IndexToOffset::get(slice, out); + int64_t* base = &out.data[offset]; + + for (int64_t i = threadIdx.x; i < sliceSize; i += blockDim.x) { + // Torch indices are 1-based (hence the +1) + base[i * sliceStride] = i; + } +} + +// For slice sorting in Thrust; extracts a slice index from a linear +// index and uses that for comparison +struct SliceComp { + SliceComp(int64_t size) : sliceSize(size) {} + + __device__ bool operator()(const int64_t& a, const int64_t& b) const { + // Since the slices are guaranteed to be innermost, + // the segment is just via int64_t division + int64_t segA = a / sliceSize; + int64_t segB = b / sliceSize; + return segA < segB; + } + + const int64_t sliceSize; +}; + +// For sorting in Thurst; extracts a within-slice index from a linear index +struct GlobalIndexToPerSliceIndex { + GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} + + __device__ inline void operator()(int64_t& v) const { + v = v % sliceSize; + } + + const int64_t sliceSize; +}; + +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +static uint64_t nextHighestPowerOf2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + + return n; +} + + +template +void run_launcher( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + Launcher l) { + auto self_info = cuda::detail::getTensorInfo(self); + auto values_info = cuda::detail::getTensorInfo(values); + auto indices_info = cuda::detail::getTensorInfo(indices); + + int64_t slice_size = self.size(dim); + /* We use these structures solely to find the offset to */ + /* each slice we are operating on */ + self_info.reduceDim(dim); + values_info.reduceDim(dim); + indices_info.reduceDim(dim); + + /* Collapse all other dims */ + int collapse_self_dim = self_info.collapseDims(dim); + int collapse_values_dim = values_info.collapseDims(dim); + int collapse_indices_dim = indices_info.collapseDims(dim); + + int64_t num_slices = 1; + for (int i = 0; i < self_info.dims; ++i) { + num_slices *= self_info.sizes[i]; + } + + /* This is used as a template parameter to calculate indices. 
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu new file mode 100644 index 000000000000..d3c488ee3768 --- /dev/null +++ b/aten/src/ATen/native/cuda/SortingKthValue.cu @@ -0,0 +1,249 @@ +#include +#include +#include +#include +#include +#include +#include +#include // only for THCRoundUp? +#include +#include +#include // AddOp + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +namespace { + + +template +__global__ void gatherKthValue( + cuda::detail::TensorInfo input, + index_t inputSliceSize, + index_t k, + + index_t numInputSlices, + index_t inputWithinSliceStride, + + cuda::detail::TensorInfo kthValue, + cuda::detail::TensorInfo indices) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of index_t + __shared__ int smem[WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + index_t sliceStartIndex = + cuda::detail::IndexToOffset::get(slice, input); + index_t kthValueSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, kthValue); + index_t indicesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, indices); + + scalar_t* inputSliceStart = &input.data[sliceStartIndex]; + scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + scalar_t kValue = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t, + false>( + inputSliceStart, + k, + inputSliceSize, + inputWithinSliceStride, + smem, + &kValue); + + // Find the index of the k-th highest element + index_t kValueIndex = 0; + bool foundKValue = false; + + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + scalar_t v = inRange ? 
doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + bool isKValue = inRange && (THCNumerics::eq(v, kValue)); + + if (isKValue) { + kValueIndex = i; + foundKValue = true; + break; + } + } + + if (foundKValue) { + kthValueSliceStart[0] = kValue; + indicesSliceStart[0] = kValueIndex; + } +} + +struct KthValueLauncher { + int64_t k; + + KthValueLauncher(int64_t k) : k(k) {} + + template + inline void launch( + cuda::detail::TensorInfo values_info, + int collapse_values_dim, + cuda::detail::TensorInfo indices_info, + int collapse_indices_dim, + cuda::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block( + std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024)); + auto stream = at::cuda::getCurrentCUDAStream(); + gatherKthValue<<>>( + self_info, + slice_size, + k, + num_slices, + /* The actual dimension that the k-selection is running in */ + /* may have changed from collapseDims() */ + self_info.strides[collapse_self_dim], + values_info, + indices_info); + } +}; + +template +void kthvalue_cuda_template( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim()); + int64_t slicesize = self.size(dim); + // FIXME: This seems bogus, I only do this because it was the old behaviour. + // The reductions are fine, as long as the axis being reduced along + // isn't of 0 elements (and the output has elements). + AT_CHECK( + self.numel() > 0, + "cannot perform reduction function kthvalue", + " on tensor with no elements because the operation does not have an identity"); + AT_CHECK(k >= 1 && k <= slicesize, "selected number k out of range"); + + _reduction_with_indices_allocate_or_resize_output( + values, indices, self, dim, keepdim); + if (self.dim() == 0 && self.numel() == 1) { + values.copy_(self); + indices.zero_(); + return; + } + + AT_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + // Based on required index size, run the algorithm with the + // appropriate index type + if (cuda::detail::canUse32BitIndexMath(self) && + cuda::detail::canUse32BitIndexMath(values) && + cuda::detail::canUse32BitIndexMath(indices)) { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } else { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } + + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +// this does not reduce to median with dim beause we don't want to copy twice +template +Tensor median_cuda_template(const Tensor& self) { + AT_CHECK(self.numel() > 0, "median cannot be called with empty tensor"); + if (self.dim() == 0 && self.numel() == 1) { + return self.clone(); + } + auto self_copy = self.clone().view(-1); + auto values = at::empty({1}, self.options()); + auto indices = at::empty({1}, self.options().dtype(kLong)); + AT_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + // Based on required index size, run the algorithm with the + // appropriate index type + if (cuda::detail::canUse32BitIndexMath(self) && + cuda::detail::canUse32BitIndexMath(values) && + cuda::detail::canUse32BitIndexMath(indices)) { + run_launcher( + values, + indices, + self_copy, + 0, + 
KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based + } else { + run_launcher( + values, + indices, + self_copy, + 0, + KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based + } + return values.view({}); +} + +} // namespace + +std::tuple kthvalue_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] { + kthvalue_cuda_template(values, indices, self, k, dim, keepdim); + }); + return std::forward_as_tuple(values, indices); +} + +Tensor median_cuda(const Tensor& self) { + return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] { + return median_cuda_template(self); + }); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh new file mode 100644 index 000000000000..2c5d81ab21da --- /dev/null +++ b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh @@ -0,0 +1,392 @@ +namespace at { +namespace native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. + // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + static inline __device__ RadixType convert(float v) { + RadixType x = __float_as_int(v); + RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; + + return (x ^ mask); + } + + static inline __device__ float deconvert(RadixType v) { + RadixType mask = (v & 0x80000000) ? 
0x80000000 : 0xffffffff; + + return __int_as_float(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + assert(sizeof(short) == 2); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + assert(sizeof(int) == 4); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + assert(sizeof(int64_t) == 8); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (x ^ mask); + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { +#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ + RadixType x = __half_as_ushort(v); + RadixType mask = -((x >> 15)) | 0x8000; + return (x ^ mask); +#else + assert(false); + return 0u; +#endif + } + + static inline __device__ at::Half deconvert(RadixType v) { +#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ + RadixType mask = ((v >> 15) - 1) | 0x8000; + return __ushort_as_half(v ^ mask); +#else + assert(false); + return static_cast(0); +#endif + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. +template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. 
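+  // Each thread strides over the slice, ballots with its warp on whether its
+  // element passes the desired/desiredMask filter and has digit j at
+  // radixDigitPos, and adds the popcount of the ballot to its per-thread
+  // count; lane 0 of every warp then folds its counts into `smem` with
+  // atomicAdd before all threads read the totals back.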
+ for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = + Bitfield::getBitfield(val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(__HIP_PLATFORM_HCC__) + counts[j] += __popcll(WARP_BALLOT(vote)); +#else + counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK())); +#endif + } + } + + // Now, for each warp, sum values + if (getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + atomicAdd(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < WARP_SIZE) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + THCRoundUp(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (THCNumerics::ne(found, static_cast(0))) { + // all threads return this value + return val; + } + } + + // should not get here + assert(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + scalar_t* data, + index_t k, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. 
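+  // `desired` / `desiredMask` record the radix-digit prefix fixed so far:
+  // each pass over a digit position pins RADIX_BITS more bits once the bucket
+  // that must contain the k-th element has been identified.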
+ bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit +#pragma unroll + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = + Bitfield::setBitfield(desired, i, digitPos, RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. */ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + Bitfield::setBitfield(desired, i, digitPos, RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (Order) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 11d50a0f618d..bb6f8cc380eb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1196,6 +1196,9 @@ variants: function, method - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices) + dispatch: + CPU: kthvalue_out_cpu + CUDA: kthvalue_out_cuda - func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor matches_jit_signature: True diff --git a/test/data/test_cuda_ignores.txt b/test/data/test_cuda_ignores.txt index 46016e9b294c..afad1c57906e 100644 --- a/test/data/test_cuda_ignores.txt +++ b/test/data/test_cuda_ignores.txt @@ -2,7 +2,6 @@ # These are skipped by test_cuda.py torch.ByteTensor.dist torch.ByteTensor.dot -torch.ByteTensor.kthvalue torch.ByteTensor.lerp torch.ByteTensor.lerp_ torch.ByteTensor.mean @@ -13,7 +12,6 @@ torch.ByteTensor.std torch.ByteTensor.var torch.CharTensor.dist torch.CharTensor.dot -torch.CharTensor.kthvalue torch.CharTensor.lerp torch.CharTensor.lerp_ torch.CharTensor.mean @@ -22,8 +20,6 @@ torch.CharTensor.renorm torch.CharTensor.renorm_ torch.CharTensor.std torch.CharTensor.var -torch.DoubleTensor.kthvalue -torch.FloatTensor.kthvalue torch.HalfTensor.chunk_ torch.HalfTensor.clone_ torch.HalfTensor.contiguous_ @@ -47,7 +43,6 @@ torch.HalfTensor.inverse_ torch.HalfTensor.is_contiguous_ torch.HalfTensor.is_same_size_ torch.HalfTensor.is_set_to_ -torch.HalfTensor.kthvalue torch.HalfTensor.kthvalue_ torch.HalfTensor.max_ torch.HalfTensor.mean_ @@ -87,7 +82,6 @@ torch.HalfTensor.zeros torch.HalfTensor.zeros_ torch.IntTensor.dist torch.IntTensor.dot -torch.IntTensor.kthvalue torch.IntTensor.lerp torch.IntTensor.lerp_ torch.IntTensor.mean @@ -98,7 +92,6 @@ torch.IntTensor.std torch.IntTensor.var torch.LongTensor.dist torch.LongTensor.dot -torch.LongTensor.kthvalue torch.LongTensor.lerp torch.LongTensor.lerp_ torch.LongTensor.mean @@ -109,7 +102,6 @@ torch.LongTensor.std torch.LongTensor.var torch.ShortTensor.dist torch.ShortTensor.dot -torch.ShortTensor.kthvalue torch.ShortTensor.lerp torch.ShortTensor.lerp_ torch.ShortTensor.mean
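
Illustration (not part of the patch): a minimal standalone C++ sketch of what the new CPU kthvalue path in Sorting.cpp computes for a single 1-d slice, namely the k-th smallest value (k is 1-based) together with an original index holding that value. It uses std::nth_element in place of the patch's Sedgewick quickselect, so with duplicate values the reported index may differ from the kernel's choice.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Returns {k-th smallest value, an original index of that value}; k is
// 1-based, mirroring the (value, index) pair the CPU kernel writes per slice.
std::pair<float, int64_t> kthvalue_slice(std::vector<float> slice, int64_t k) {
  std::vector<std::pair<float, int64_t>> v;
  v.reserve(slice.size());
  for (int64_t i = 0; i < static_cast<int64_t>(slice.size()); ++i) {
    v.emplace_back(slice[i], i);  // pair values with positions, like tmp_indices = 0..n-1
  }
  // Partial selection to position k - 1, comparing on the value only.
  std::nth_element(
      v.begin(), v.begin() + (k - 1), v.end(),
      [](const std::pair<float, int64_t>& a, const std::pair<float, int64_t>& b) {
        return a.first < b.first;
      });
  return {v[k - 1].first, v[k - 1].second};
}

int main() {
  auto r = kthvalue_slice({3.0f, 1.0f, 2.0f}, 2);
  assert(r.first == 2.0f && r.second == 2);
  return 0;
}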
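
Illustration (not part of the patch): a host-side check, assuming nothing beyond the standard library, that the float -> uint32_t mapping used by TopKTypeConfig<float> in SortingRadixSelect.cuh is order-preserving, which is what lets radixSelect treat floats as unsigned radix keys.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Same transform as TopKTypeConfig<float>::convert, written for the host:
// flip every bit of a negative float, flip only the sign bit of the rest.
static uint32_t convert_float(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
  return x ^ mask;
}

int main() {
  std::vector<float> vals = {7.0f, -3.5f, 0.0f, 1.25f, -100.0f};
  std::sort(vals.begin(), vals.end());
  // Distinct floats in increasing order must map to strictly increasing keys.
  for (size_t i = 1; i < vals.size(); ++i) {
    assert(convert_float(vals[i - 1]) < convert_float(vals[i]));
  }
  return 0;
}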