Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Implement kthvalue in ATen (#17544)
Summary: The CPU version is based on the TH version; the GPU version is based on #8406 by Pararth Shah (thank you). The CPU quickselect is based on the one in TH's THTensorMoreMath.cpp, but written in C++ (quickselectnoindex will be achieved by a different swap). CPU kthvalue is based on the THTensor function in the same file. The dim_apply function is a C++ replacement for the TH_TENSOR_DIM_APPLYx macros. The CUDA kernel uses functions adapted from the THCTensorSortK implementation; in particular, radixSelect is from THCTensorTopK.cuh. The CUDA launcher code replaces a number of macros with C++ and will be re-used in one of the following patches.

Plan for further PRs:
- This
- Sort
- TopK + Mode + Median, in any order
- Rip out the THC implementations. There may be utility functions / structs in SortingCommon.cuh that only become relevant with sort.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17544
Differential Revision: D14286934
Pulled By: ezyang
fbshipit-source-id: 35dbea050b097e88777ac5fa5c0f499d5e23c738
committed by Facebook Github Bot
parent 43f94077d8, commit c6715eda06
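For orientation, here is a minimal sketch of the operator this diff implements, using only the public at::kthvalue entry point (k is 1-based, so k == 1 is the minimum along the chosen dimension). This is an illustrative standalone program, not part of the diff, and assumes a working ATen/LibTorch build.

```cpp
#include <ATen/ATen.h>
#include <tuple>

int main() {
  // kthvalue returns the k-th smallest value along `dim` together with the
  // index that value had in that dimension (k is 1-based).
  at::Tensor t = at::randn({3, 5});
  at::Tensor values, indices;
  std::tie(values, indices) =
      at::kthvalue(t, /*k=*/2, /*dim=*/1, /*keepdim=*/false);
  // values:  shape {3}, the 2nd-smallest entry of each row
  // indices: shape {3}, the position of that entry within dim 1
  AT_ASSERT(values.sizes() == indices.sizes());
  return 0;
}
```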
@@ -800,27 +800,6 @@
|
||||
- arg: bool keepdim
|
||||
default: "false"
|
||||
]]
|
||||
[[
|
||||
name: _th_kthvalue
|
||||
backends:
|
||||
- CPU
|
||||
variants: function
|
||||
cname: kthvalue
|
||||
return: argument 0,1
|
||||
scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1)
|
||||
arguments:
|
||||
- arg: THTensor* values
|
||||
output: True
|
||||
- arg: THIndexTensor* indices
|
||||
output: True
|
||||
- THTensor* self
|
||||
- long k
|
||||
- arg: long dim
|
||||
wrap_dim: self
|
||||
default: __last_dim
|
||||
- arg: bool keepdim
|
||||
default: "false"
|
||||
]]
|
||||
[[
|
||||
name: _th_mode
|
||||
variants: function
|
||||
|
aten/src/ATen/native/Sorting.cpp (new file, 195 lines)
@@ -0,0 +1,195 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/native/SortingUtils.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
// maybe these days, one should define a random access iterator and use
|
||||
// std::sort...
|
||||
/* Note from TH:
|
||||
|
||||
I cut and pasted (slightly adapted) the quicksort code from
|
||||
Sedgewick's 1978 "Implementing Quicksort Programs" article
|
||||
http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf
|
||||
|
||||
It is the state of the art existing implementation. The macros
|
||||
are here to make as close a match as possible to the pseudocode of
|
||||
Program 2 p.851
|
||||
|
||||
Note that other partition schemes exist, and are typically presented
|
||||
in textbooks, but those are less efficient. See e.g.
|
||||
http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto
|
||||
|
||||
Julien, November 12th 2013
|
||||
*/
|
||||
|
||||
constexpr int64_t MAX_LEVELS = 300;
|
||||
constexpr int64_t M_SMALL = 10; // Limit for small subfiles
|
||||
|
||||
template <typename Fn>
|
||||
void dim_apply(TensorList tensors, int64_t dim, Fn f) {
|
||||
AT_ASSERT(tensors.size() > 0);
|
||||
auto t = tensors[0];
|
||||
auto sizes = t.sizes();
|
||||
int64_t ndim = t.dim();
|
||||
int64_t itersize = 1;
|
||||
for (int64_t i = 0; i < ndim; i++) {
|
||||
if (i != dim) {
|
||||
itersize *= t.size(i);
|
||||
}
|
||||
}
|
||||
parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) {
|
||||
std::vector<Tensor> narrowed_tensors;
|
||||
narrowed_tensors.reserve(tensors.size());
|
||||
for (int64_t it = i_begin; it < i_end; it++) {
|
||||
narrowed_tensors.clear();
|
||||
for (auto ti : tensors) {
|
||||
int64_t i = it;
|
||||
Tensor nt = ti;
|
||||
for (size_t d = 0; d < ndim; d++) {
|
||||
if (d != dim) {
|
||||
// this could be avoided for slower-changing dimensions if done
|
||||
// better
|
||||
nt = nt.select((d > dim ? 1 : 0), i % sizes[d]);
|
||||
i = i / sizes[d];
|
||||
}
|
||||
}
|
||||
narrowed_tensors.emplace_back(nt);
|
||||
}
|
||||
f(it, narrowed_tensors);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename Fn>
|
||||
void quick_select_template(
|
||||
TensorAccessor<scalar_t, 1> arr,
|
||||
int64_t k,
|
||||
Fn swap_fn) {
|
||||
int64_t P, L, R, i, j, swap;
|
||||
scalar_t rswap, piv;
|
||||
L = 0;
|
||||
R = arr.size(0) - 1;
|
||||
|
||||
do {
|
||||
if (R <= L) // One element only
|
||||
return;
|
||||
|
||||
if (R == L + 1) { // Two elements only
|
||||
if (arr[L] > arr[R]) {
|
||||
swap_fn(L, R);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use median of three for pivot choice
|
||||
P = (L + R) >> 1;
|
||||
swap_fn(P, L + 1);
|
||||
if (arr[L + 1] > arr[R]) {
|
||||
swap_fn(L + 1, R);
|
||||
}
|
||||
if (arr[L] > arr[R]) {
|
||||
swap_fn(L, R);
|
||||
}
|
||||
if (arr[L + 1] > arr[L]) {
|
||||
swap_fn(L + 1, L);
|
||||
}
|
||||
|
||||
i = L + 1;
|
||||
j = R;
|
||||
piv = arr[L];
|
||||
do {
|
||||
do
|
||||
i++;
|
||||
while (arr[i] < piv);
|
||||
do
|
||||
j--;
|
||||
while (arr[j] > piv);
|
||||
if (j < i)
|
||||
break;
|
||||
swap_fn(i, j);
|
||||
} while (1);
|
||||
swap_fn(L, j);
|
||||
|
||||
// Re-set active partition
|
||||
if (j <= k)
|
||||
L = i;
|
||||
if (j >= k)
|
||||
R = j - 1;
|
||||
} while (1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
|
||||
Tensor& values,
|
||||
Tensor& indices,
|
||||
const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim_,
|
||||
bool keepdim) {
|
||||
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
||||
// FIXME: This seems bogus, I only do this because it was the old behaviour.
|
||||
// The reductions are fine, as long as the axis being reduced along
|
||||
// isn't of 0 elements (and the output has elements).
|
||||
AT_CHECK(
|
||||
self.numel() > 0,
|
||||
"cannot perform reduction function kthvalue",
|
||||
" on tensor with no elements because the operation does not have an identity");
|
||||
AT_CHECK(
|
||||
k > 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
|
||||
"selected index k out of range");
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
values, indices, self, dim_, keepdim);
|
||||
if (self.dim() == 0 && self.numel() == 1) {
|
||||
values.copy_(self);
|
||||
indices.zero_();
|
||||
return std::forward_as_tuple(values, indices);
|
||||
}
|
||||
auto tmp_values = self.clone();
|
||||
auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));
|
||||
AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] {
|
||||
dim_apply(
|
||||
{tmp_values, tmp_indices, values, indices},
|
||||
dim,
|
||||
[&](int64_t i, TensorList tl) {
|
||||
auto tmp_values = tl[0].accessor<scalar_t, 1>();
|
||||
auto tmp_indices = tl[1].accessor<int64_t, 1>();
|
||||
scalar_t* mode_value = tl[2].data<scalar_t>();
|
||||
int64_t* mode_index = tl[3].data<int64_t>();
|
||||
for (int64_t j = 0; j < tmp_indices.size(0); j++) {
|
||||
tmp_indices[j] = j;
|
||||
}
|
||||
quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) {
|
||||
std::swap(tmp_values[i], tmp_values[j]);
|
||||
std::swap(tmp_indices[i], tmp_indices[j]);
|
||||
});
|
||||
*mode_value = tmp_values[k - 1];
|
||||
*mode_index = tmp_indices[k - 1];
|
||||
});
|
||||
});
|
||||
if (!keepdim) {
|
||||
values.squeeze_(dim);
|
||||
indices.squeeze_(dim);
|
||||
}
|
||||
return std::forward_as_tuple(values, indices);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> kthvalue(
|
||||
const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim,
|
||||
bool keepdim) {
|
||||
Tensor values = at::empty({0}, self.options());
|
||||
Tensor indices = at::empty({0}, self.options().dtype(kLong));
|
||||
at::kthvalue_out(values, indices, self, k, dim, keepdim);
|
||||
return std::make_tuple(values, indices);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
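As a reference for what quick_select_template plus dim_apply compute on each 1-d slice, here is a standalone sketch using std::nth_element over an index buffer. It is not the ATen code above (and may report a different index when values tie); it only illustrates the same value/index contract. kth_value_slice is a made-up name for illustration.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Reference sketch (not the ATen code): what the CPU path computes for one
// 1-d slice. kthvalue returns the k-th smallest value (k is 1-based) and the
// index it originally had in the slice. std::nth_element stands in for the
// hand-written quickselect; the selected value is the same.
std::pair<double, int64_t> kth_value_slice(std::vector<double> slice, int64_t k) {
  std::vector<int64_t> idx(slice.size());
  for (int64_t j = 0; j < static_cast<int64_t>(idx.size()); ++j) {
    idx[j] = j;  // analogous to tmp_indices in kthvalue_out_cpu
  }
  auto nth = idx.begin() + (k - 1);
  std::nth_element(idx.begin(), nth, idx.end(),
                   [&](int64_t a, int64_t b) { return slice[a] < slice[b]; });
  return {slice[*nth], *nth};
}
// e.g. kth_value_slice({5, 1, 3, 2, 4}, 2) yields {2, 3}:
// the 2nd-smallest value is 2, originally at index 3.
```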
aten/src/ATen/native/SortingUtils.h (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
// ensure we get good values and indices for kthvalue, mode, median
|
||||
// this will always be with the reducing dim as 1-d
|
||||
static void _reduction_with_indices_allocate_or_resize_output(
|
||||
Tensor& values,
|
||||
Tensor& indices,
|
||||
const Tensor& self,
|
||||
int64_t dim_,
|
||||
bool keepdim) {
|
||||
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
|
||||
auto result_sizes = self.sizes().vec();
|
||||
if (result_sizes.size() > 0) {
|
||||
result_sizes[dim] = 1;
|
||||
}
|
||||
if (values.defined()) {
|
||||
AT_CHECK(
|
||||
self.type() == values.type(),
|
||||
"output values must be of same type as input");
|
||||
if (!keepdim && values.dim() == self.dim() - 1) {
|
||||
// unsqueeze to preserve passed in noncontiguous tensor in resize
|
||||
values.unsqueeze_(dim);
|
||||
}
|
||||
values.resize_(result_sizes);
|
||||
} else {
|
||||
values = at::empty(result_sizes, self.options());
|
||||
}
|
||||
if (indices.defined()) {
|
||||
AT_CHECK(
|
||||
indices.dtype() == kLong, "output indices must be of scalar type Long");
|
||||
AT_CHECK(
|
||||
indices.device() == self.device(),
|
||||
"output indices must be on same device as input");
|
||||
if (!keepdim && indices.dim() == self.dim() - 1) {
|
||||
// unsqueeze to preserve passed in noncontiguous tensor in resize
|
||||
indices.unsqueeze_(dim);
|
||||
}
|
||||
indices.resize_(result_sizes);
|
||||
} else {
|
||||
indices = at::empty(result_sizes, self.options().dtype(kLong));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
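A shape-only sketch of the contract of _reduction_with_indices_allocate_or_resize_output: the outputs are first given size 1 along the reduced dim, and the callers squeeze that dim away afterwards when keepdim is false. reduced_sizes below is an illustrative stand-in, not a function from the diff.

```cpp
#include <cstdint>
#include <vector>

// Sketch of the output shapes produced by the helper above plus the
// squeeze_ performed by its callers.
std::vector<int64_t> reduced_sizes(std::vector<int64_t> sizes, int64_t dim,
                                   bool keepdim) {
  if (!sizes.empty()) {
    sizes[dim] = 1;                      // what the helper sets up
    if (!keepdim) {
      sizes.erase(sizes.begin() + dim);  // what squeeze_(dim) does afterwards
    }
  }
  return sizes;
}
// e.g. reduced_sizes({4, 5, 6}, 1, /*keepdim=*/true)  == {4, 1, 6}
//      reduced_sizes({4, 5, 6}, 1, /*keepdim=*/false) == {4, 6}
```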
@@ -97,26 +97,6 @@ Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& o
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> kthvalue(const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
|
||||
Tensor values = at::empty({0}, self.options());
|
||||
Tensor indices = at::empty({0}, self.options().dtype(kLong));
|
||||
return at::native::kthvalue_out(values, indices, self, k, dim, keepdim);
|
||||
}
|
||||
|
||||
std::tuple<Tensor &,Tensor &> kthvalue_out(Tensor& values, Tensor& indices,
|
||||
const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
|
||||
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
|
||||
"kthvalue only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
|
||||
dim = maybe_wrap_dim(dim, self.dim());
|
||||
if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "kthvalue")) {
|
||||
AT_ASSERT(values.dim() == 0);
|
||||
indices.resize_({}).fill_(0);
|
||||
return std::forward_as_tuple(values, indices);
|
||||
} else {
|
||||
return at::legacy::th::_th_kthvalue_out(values, indices, self, k, dim, keepdim);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> median(const Tensor& self, int64_t dim, bool keepdim) {
|
||||
Tensor values = at::empty({0}, self.options());
|
||||
Tensor indices = at::empty({0}, self.options().dtype(kLong));
|
||||
|
aten/src/ATen/native/cuda/SortingCommon.cuh (new file, 226 lines)
@@ -0,0 +1,226 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/SortingUtils.h>
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <stdlib.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/detail/TensorInfo.cuh>
|
||||
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
|
||||
#include <THC/THCNumerics.cuh>
|
||||
#include <THC/THCScanUtils.cuh>
|
||||
#include <THC/THCTensorMathReduce.cuh> // AddOp
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__)
|
||||
constexpr int WARP_SIZE = 64;
|
||||
constexpr int MAX_BLOCK_SIZE = 256;
|
||||
|
||||
#else
|
||||
constexpr int WARP_SIZE = 32;
|
||||
constexpr int MAX_BLOCK_SIZE = 1024;
|
||||
#endif
|
||||
|
||||
// Maximum size per grid dimension that we assume (compute capability >= 2.0)
|
||||
constexpr int64_t MAX_GRID_SIZE = 65535LL;
|
||||
|
||||
static bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
|
||||
if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
|
||||
int64_t gridY = 1;
|
||||
int64_t gridZ = 1;
|
||||
|
||||
if (gridTiles > MAX_GRID_SIZE) {
|
||||
gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
|
||||
gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
|
||||
|
||||
if (gridTiles > MAX_GRID_SIZE) {
|
||||
gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
|
||||
gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
|
||||
}
|
||||
}
|
||||
|
||||
grid = dim3(gridX, gridY, gridZ);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool handleNaN = false>
|
||||
struct ThrustGTOp {
|
||||
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
|
||||
return (handleNaN && THCNumerics<scalar_t>::isnan(lhs) &&
|
||||
!THCNumerics<scalar_t>::isnan(rhs)) ||
|
||||
THCNumerics<scalar_t>::gt(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool handleNaN = false>
|
||||
struct ThrustLTOp {
|
||||
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
|
||||
return (handleNaN && THCNumerics<scalar_t>::isnan(rhs) &&
|
||||
!THCNumerics<scalar_t>::isnan(lhs)) ||
|
||||
THCNumerics<scalar_t>::lt(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename index_t>
|
||||
__device__ __forceinline__ index_t getLinearBlockId() {
|
||||
return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
|
||||
blockIdx.x;
|
||||
}
|
||||
|
||||
// `base` is the base address of a tensor
|
||||
// For each slice (defined as a linear point of `out`, from 0 ->
|
||||
// (sliceSize - 1) * sliceStride), we fill that slice from `0` to
|
||||
// `sliceSize - 1`.
|
||||
template <typename index_t, int Dim>
|
||||
__global__ void fillSliceWithIndex_kernel(
|
||||
cuda::detail::TensorInfo<int64_t, index_t> out,
|
||||
index_t totalSlices,
|
||||
index_t sliceSize,
|
||||
index_t sliceStride) {
|
||||
index_t slice = getLinearBlockId<index_t>();
|
||||
|
||||
if (slice >= totalSlices) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint64_t offset =
|
||||
cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, out);
|
||||
int64_t* base = &out.data[offset];
|
||||
|
||||
for (int64_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
|
||||
// ATen indices are 0-based, so we write the raw index i
|
||||
base[i * sliceStride] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// For slice sorting in Thrust; extracts a slice index from a linear
|
||||
// index and uses that for comparison
|
||||
struct SliceComp {
|
||||
SliceComp(int64_t size) : sliceSize(size) {}
|
||||
|
||||
__device__ bool operator()(const int64_t& a, const int64_t& b) const {
|
||||
// Since the slices are guaranteed to be innermost,
|
||||
// the segment is just via int64_t division
|
||||
int64_t segA = a / sliceSize;
|
||||
int64_t segB = b / sliceSize;
|
||||
return segA < segB;
|
||||
}
|
||||
|
||||
const int64_t sliceSize;
|
||||
};
|
||||
|
||||
// For sorting in Thrust; extracts a within-slice index from a linear index
|
||||
struct GlobalIndexToPerSliceIndex {
|
||||
GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
|
||||
|
||||
__device__ inline void operator()(int64_t& v) const {
|
||||
v = v % sliceSize;
|
||||
}
|
||||
|
||||
const int64_t sliceSize;
|
||||
};
|
||||
|
||||
// Returns 2^(ceil(lg(n))) from the Stanford bit twiddling hacks
|
||||
static uint64_t nextHighestPowerOf2(uint64_t n) {
|
||||
n--;
|
||||
n |= n >> 1;
|
||||
n |= n >> 2;
|
||||
n |= n >> 4;
|
||||
n |= n >> 8;
|
||||
n |= n >> 16;
|
||||
#ifndef _MSC_VER
|
||||
n |= n >> 32;
|
||||
#endif
|
||||
n++;
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, typename index_t, typename Launcher>
|
||||
void run_launcher(
|
||||
Tensor& values,
|
||||
Tensor& indices,
|
||||
const Tensor& self,
|
||||
int64_t dim,
|
||||
Launcher l) {
|
||||
auto self_info = cuda::detail::getTensorInfo<scalar_t, index_t>(self);
|
||||
auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
|
||||
auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);
|
||||
|
||||
int64_t slice_size = self.size(dim);
|
||||
/* We use these structures solely to find the offset to */
|
||||
/* each slice we are operating on */
|
||||
self_info.reduceDim(dim);
|
||||
values_info.reduceDim(dim);
|
||||
indices_info.reduceDim(dim);
|
||||
|
||||
/* Collapse all other dims */
|
||||
int collapse_self_dim = self_info.collapseDims(dim);
|
||||
int collapse_values_dim = values_info.collapseDims(dim);
|
||||
int collapse_indices_dim = indices_info.collapseDims(dim);
|
||||
|
||||
int64_t num_slices = 1;
|
||||
for (int i = 0; i < self_info.dims; ++i) {
|
||||
num_slices *= self_info.sizes[i];
|
||||
}
|
||||
|
||||
/* This is used as a template parameter to calculate indices. */
|
||||
/* We only specialize it if all collapsed dim sizes are the */
|
||||
/* same; otherwise, we use -1 which is the specialization */
|
||||
/* parameter for arbitrary dimensions */
|
||||
int all_dims = self_info.dims;
|
||||
if (values_info.dims != all_dims || indices_info.dims != all_dims) {
|
||||
all_dims = -1;
|
||||
}
|
||||
|
||||
if (all_dims == 1) {
|
||||
l.template launch<scalar_t, index_t, 1>(
|
||||
values_info,
|
||||
collapse_values_dim,
|
||||
indices_info,
|
||||
collapse_indices_dim,
|
||||
self_info,
|
||||
collapse_self_dim,
|
||||
num_slices,
|
||||
slice_size);
|
||||
} else if (all_dims == 2) {
|
||||
l.template launch<scalar_t, index_t, 2>(
|
||||
values_info,
|
||||
collapse_values_dim,
|
||||
indices_info,
|
||||
collapse_indices_dim,
|
||||
self_info,
|
||||
collapse_self_dim,
|
||||
num_slices,
|
||||
slice_size);
|
||||
} else if (all_dims == 3) {
|
||||
l.template launch<scalar_t, index_t, 3>(
|
||||
values_info,
|
||||
collapse_values_dim,
|
||||
indices_info,
|
||||
collapse_indices_dim,
|
||||
self_info,
|
||||
collapse_self_dim,
|
||||
num_slices,
|
||||
slice_size);
|
||||
} else {
|
||||
l.template launch<scalar_t, index_t, -1>(
|
||||
values_info,
|
||||
collapse_values_dim,
|
||||
indices_info,
|
||||
collapse_indices_dim,
|
||||
self_info,
|
||||
collapse_self_dim,
|
||||
num_slices,
|
||||
slice_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
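The grid mapping used by run_launcher/getGridFromTiles can be exercised on the host; the sketch below mirrors that logic with plain integers, assuming the same 65535-per-dimension cap. Names such as grid_from_tiles are illustrative, not from the diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <tuple>

// Host-side sketch of the grid mapping above: a flat count of slices is
// split across up to three grid dimensions, each capped at 65535.
constexpr int64_t kMaxGridSize = 65535;

inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

bool grid_from_tiles(int64_t tiles, std::tuple<int64_t, int64_t, int64_t>& grid) {
  if (tiles > kMaxGridSize * kMaxGridSize * kMaxGridSize) return false;
  int64_t x = std::min(tiles, kMaxGridSize), y = 1, z = 1;
  if (tiles > kMaxGridSize) {
    tiles = ceil_div(tiles, kMaxGridSize);
    y = std::min(tiles, kMaxGridSize);
    if (tiles > kMaxGridSize) {
      tiles = ceil_div(tiles, kMaxGridSize);
      z = std::min(tiles, kMaxGridSize);
    }
  }
  grid = std::make_tuple(x, y, z);
  return true;
}
// e.g. 70000 slices -> grid (65535, 2, 1): 65535 * 2 >= 70000 blocks, and each
// block recovers its slice via getLinearBlockId() and bails out if it is
// past numInputSlices.
```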
aten/src/ATen/native/cuda/SortingKthValue.cu (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/SortingUtils.h>
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <stdlib.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/detail/TensorInfo.cuh>
|
||||
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
|
||||
#include <THC/THCNumerics.cuh>
|
||||
#include <THC/THCScanUtils.cuh>
|
||||
#include <THC/THCTensorMathReduce.cuh> // AddOp
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/sort.h>
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/inner_product.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
#include <ATen/native/cuda/SortingCommon.cuh>
|
||||
#include <ATen/native/cuda/SortingRadixSelect.cuh>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
template <typename scalar_t, typename index_t, int Dim>
|
||||
__global__ void gatherKthValue(
|
||||
cuda::detail::TensorInfo<scalar_t, index_t> input,
|
||||
index_t inputSliceSize,
|
||||
index_t k,
|
||||
|
||||
index_t numInputSlices,
|
||||
index_t inputWithinSliceStride,
|
||||
|
||||
cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
|
||||
cuda::detail::TensorInfo<int64_t, index_t> indices) {
|
||||
// Indices are limited to integer fp precision, so counts can fit in
|
||||
// int32, regardless of index_t
|
||||
__shared__ int smem[WARP_SIZE]; // one per each warp, up to warp limit
|
||||
|
||||
index_t slice = getLinearBlockId<index_t>();
|
||||
if (slice >= numInputSlices) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find the start offset for our slice
|
||||
index_t sliceStartIndex =
|
||||
cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);
|
||||
index_t kthValueSliceStartIndex =
|
||||
cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, kthValue);
|
||||
index_t indicesSliceStartIndex =
|
||||
cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);
|
||||
|
||||
scalar_t* inputSliceStart = &input.data[sliceStartIndex];
|
||||
scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex];
|
||||
int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];
|
||||
|
||||
// Find the k-th highest element in our input
|
||||
scalar_t kValue = static_cast<scalar_t>(0);
|
||||
radixSelect<
|
||||
scalar_t,
|
||||
typename TopKTypeConfig<scalar_t>::RadixType,
|
||||
index_t,
|
||||
false>(
|
||||
inputSliceStart,
|
||||
k,
|
||||
inputSliceSize,
|
||||
inputWithinSliceStride,
|
||||
smem,
|
||||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange && (THCNumerics<scalar_t>::eq(v, kValue));
|
||||
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
}
|
||||
}
|
||||
|
||||
struct KthValueLauncher {
|
||||
int64_t k;
|
||||
|
||||
KthValueLauncher(int64_t k) : k(k) {}
|
||||
|
||||
template <typename scalar_t, typename index_t, int all_dims>
|
||||
inline void launch(
|
||||
cuda::detail::TensorInfo<scalar_t, index_t> values_info,
|
||||
int collapse_values_dim,
|
||||
cuda::detail::TensorInfo<int64_t, index_t> indices_info,
|
||||
int collapse_indices_dim,
|
||||
cuda::detail::TensorInfo<scalar_t, index_t> self_info,
|
||||
int collapse_self_dim,
|
||||
int64_t num_slices,
|
||||
int64_t slice_size) {
|
||||
dim3 grid;
|
||||
if (!getGridFromTiles(num_slices, grid)) {
|
||||
AT_ERROR("slices are too many");
|
||||
}
|
||||
|
||||
dim3 block(
|
||||
std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
|
||||
self_info,
|
||||
slice_size,
|
||||
k,
|
||||
num_slices,
|
||||
/* The actual dimension that the k-selection is running in */
|
||||
/* may have changed from collapseDims() */
|
||||
self_info.strides[collapse_self_dim],
|
||||
values_info,
|
||||
indices_info);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
void kthvalue_cuda_template(
|
||||
Tensor& values,
|
||||
Tensor& indices,
|
||||
const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim_,
|
||||
bool keepdim) {
|
||||
int64_t dim = maybe_wrap_dim(dim_, self.dim());
|
||||
int64_t slicesize = self.size(dim);
|
||||
// FIXME: This seems bogus, I only do this because it was the old behaviour.
|
||||
// The reductions are fine, as long as the axis being reduced along
|
||||
// isn't of 0 elements (and the output has elements).
|
||||
AT_CHECK(
|
||||
self.numel() > 0,
|
||||
"cannot perform reduction function kthvalue",
|
||||
" on tensor with no elements because the operation does not have an identity");
|
||||
AT_CHECK(k >= 1 && k <= slicesize, "selected number k out of range");
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
values, indices, self, dim, keepdim);
|
||||
if (self.dim() == 0 && self.numel() == 1) {
|
||||
values.copy_(self);
|
||||
indices.zero_();
|
||||
return;
|
||||
}
|
||||
|
||||
AT_CHECK(
|
||||
self.dim() <= MAX_TENSORINFO_DIMS,
|
||||
"cannot operate on more than ",
|
||||
MAX_TENSORINFO_DIMS,
|
||||
" dimensions");
|
||||
|
||||
// Based on required index size, run the algorithm with the
|
||||
// appropriate index type
|
||||
if (cuda::detail::canUse32BitIndexMath(self) &&
|
||||
cuda::detail::canUse32BitIndexMath(values) &&
|
||||
cuda::detail::canUse32BitIndexMath(indices)) {
|
||||
run_launcher<scalar_t, uint32_t>(
|
||||
values, indices, self, dim, KthValueLauncher(k));
|
||||
} else {
|
||||
run_launcher<scalar_t, uint64_t>(
|
||||
values, indices, self, dim, KthValueLauncher(k));
|
||||
}
|
||||
|
||||
if (!keepdim) {
|
||||
values.squeeze_(dim);
|
||||
indices.squeeze_(dim);
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
// this does not reduce to median with dim because we don't want to copy twice
|
||||
template <typename scalar_t>
|
||||
Tensor median_cuda_template(const Tensor& self) {
|
||||
AT_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
|
||||
if (self.dim() == 0 && self.numel() == 1) {
|
||||
return self.clone();
|
||||
}
|
||||
auto self_copy = self.clone().view(-1);
|
||||
auto values = at::empty({1}, self.options());
|
||||
auto indices = at::empty({1}, self.options().dtype(kLong));
|
||||
AT_CHECK(
|
||||
self.dim() <= MAX_TENSORINFO_DIMS,
|
||||
"cannot operate on more than ",
|
||||
MAX_TENSORINFO_DIMS,
|
||||
" dimensions");
|
||||
|
||||
// Based on required index size, run the algorithm with the
|
||||
// appropriate index type
|
||||
if (cuda::detail::canUse32BitIndexMath(self) &&
|
||||
cuda::detail::canUse32BitIndexMath(values) &&
|
||||
cuda::detail::canUse32BitIndexMath(indices)) {
|
||||
run_launcher<scalar_t, uint32_t>(
|
||||
values,
|
||||
indices,
|
||||
self_copy,
|
||||
0,
|
||||
KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
|
||||
} else {
|
||||
run_launcher<scalar_t, uint64_t>(
|
||||
values,
|
||||
indices,
|
||||
self_copy,
|
||||
0,
|
||||
KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
|
||||
}
|
||||
return values.view({});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
||||
Tensor& values,
|
||||
Tensor& indices,
|
||||
const Tensor& self,
|
||||
int64_t k,
|
||||
int64_t dim,
|
||||
bool keepdim) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] {
|
||||
kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
|
||||
});
|
||||
return std::forward_as_tuple(values, indices);
|
||||
}
|
||||
|
||||
Tensor median_cuda(const Tensor& self) {
|
||||
return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] {
|
||||
return median_cuda_template<scalar_t>(self);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
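The median path above reuses KthValueLauncher with k = (size + 1) / 2, where k is 1-based just like kthvalue's k. A small host-side sketch of that rank arithmetic follows; median_of is an illustrative helper, not part of the diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// For n elements, the (lower) median is the k-th smallest with
// k = (n + 1) / 2 under integer division, k being 1-based.
double median_of(std::vector<double> v) {
  const int64_t n = static_cast<int64_t>(v.size());
  const int64_t k = (n + 1) / 2;  // 1-based rank of the median
  std::nth_element(v.begin(), v.begin() + (k - 1), v.end());
  return v[k - 1];
}
// n = 5 -> k = 3 (the middle element); n = 4 -> k = 2 (the lower median),
// matching the lower-median convention for even-length inputs.
```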
aten/src/ATen/native/cuda/SortingRadixSelect.cuh (new file, 392 lines)
@@ -0,0 +1,392 @@
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
template <typename scalar_t>
|
||||
struct TopKTypeConfig {};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<float> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
// Converts a float to an integer representation with the same
|
||||
// sorting; i.e., for floats f1, f2:
|
||||
// if f1 < f2 then convert(f1) < convert(f2)
|
||||
// We use this to enable radix selection of floating-point values.
|
||||
// This also gives a relative order for NaNs, but that's ok, as they
|
||||
// will all be adjacent
|
||||
static inline __device__ RadixType convert(float v) {
|
||||
RadixType x = __float_as_int(v);
|
||||
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
|
||||
|
||||
return (x ^ mask);
|
||||
}
|
||||
|
||||
static inline __device__ float deconvert(RadixType v) {
|
||||
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
|
||||
|
||||
return __int_as_float(v ^ mask);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<uint8_t> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(uint8_t v) {
|
||||
return v;
|
||||
}
|
||||
|
||||
static inline __device__ uint8_t deconvert(RadixType v) {
|
||||
return v;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<int8_t> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(int8_t v) {
|
||||
return 128u + v;
|
||||
}
|
||||
|
||||
static inline __device__ int8_t deconvert(RadixType v) {
|
||||
return v - 128;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<int16_t> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(int16_t v) {
|
||||
assert(sizeof(short) == 2);
|
||||
return 32768u + v;
|
||||
}
|
||||
|
||||
static inline __device__ int16_t deconvert(RadixType v) {
|
||||
return v - 32768;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<int32_t> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(int32_t v) {
|
||||
assert(sizeof(int) == 4);
|
||||
return 2147483648u + v;
|
||||
}
|
||||
|
||||
static inline __device__ int32_t deconvert(RadixType v) {
|
||||
return v - 2147483648u;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<int64_t> {
|
||||
typedef uint64_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(int64_t v) {
|
||||
assert(sizeof(int64_t) == 8);
|
||||
return 9223372036854775808ull + v;
|
||||
}
|
||||
|
||||
static inline __device__ int64_t deconvert(RadixType v) {
|
||||
return v - 9223372036854775808ull;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<double> {
|
||||
typedef uint64_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(double v) {
|
||||
RadixType x = __double_as_longlong(v);
|
||||
RadixType mask = -((x >> 63)) | 0x8000000000000000;
|
||||
return (x ^ mask);
|
||||
}
|
||||
|
||||
static inline __device__ double deconvert(RadixType v) {
|
||||
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
|
||||
return __longlong_as_double(v ^ mask);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TopKTypeConfig<at::Half> {
|
||||
typedef uint32_t RadixType;
|
||||
|
||||
static inline __device__ RadixType convert(at::Half v) {
|
||||
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
|
||||
RadixType x = __half_as_ushort(v);
|
||||
RadixType mask = -((x >> 15)) | 0x8000;
|
||||
return (x ^ mask);
|
||||
#else
|
||||
assert(false);
|
||||
return 0u;
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline __device__ at::Half deconvert(RadixType v) {
|
||||
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
|
||||
RadixType mask = ((v >> 15) - 1) | 0x8000;
|
||||
return __ushort_as_half(v ^ mask);
|
||||
#else
|
||||
assert(false);
|
||||
return static_cast<at::Half>(0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// This function counts the distribution of all input values in a
|
||||
// slice we are selecting by radix digit at `radixDigitPos`, but only
|
||||
// those that pass the filter `((v & desiredMask) == desired)`.
|
||||
// This produces and broadcasts the seen counts for a single block only.
|
||||
// `smem` must have at least `RadixSize` elements.
|
||||
template <
|
||||
typename scalar_t,
|
||||
typename bitwise_t,
|
||||
typename index_t,
|
||||
typename CountType,
|
||||
int RadixSize,
|
||||
int RadixBits>
|
||||
__device__ void countRadixUsingMask(
|
||||
CountType counts[RadixSize],
|
||||
CountType* smem,
|
||||
bitwise_t desired,
|
||||
bitwise_t desiredMask,
|
||||
int radixDigitPos,
|
||||
index_t sliceSize,
|
||||
index_t withinSliceStride,
|
||||
scalar_t* data) {
|
||||
// Clear out per-thread counts from a previous round
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RadixSize; ++i) {
|
||||
counts[i] = 0;
|
||||
}
|
||||
|
||||
if (threadIdx.x < RadixSize) {
|
||||
smem[threadIdx.x] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Scan over all the data. Upon a read, the warp will accumulate
|
||||
// counts per each digit in the radix using warp voting.
|
||||
for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
|
||||
bitwise_t val =
|
||||
TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));
|
||||
|
||||
bool hasVal = ((val & desiredMask) == desired);
|
||||
bitwise_t digitInRadix =
|
||||
Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < RadixSize; ++j) {
|
||||
bool vote = hasVal && (digitInRadix == j);
|
||||
#if defined(__HIP_PLATFORM_HCC__)
|
||||
counts[j] += __popcll(WARP_BALLOT(vote));
|
||||
#else
|
||||
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// Now, for each warp, sum values
|
||||
if (getLaneId() == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < RadixSize; ++i) {
|
||||
atomicAdd(&smem[i], counts[i]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each thread, read in the total counts
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < RadixSize; ++i) {
|
||||
counts[i] = smem[i];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Over what radix we are selecting values
|
||||
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
|
||||
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
|
||||
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
|
||||
|
||||
// This finds the unique value `v` that matches the pattern
|
||||
// ((v & desiredMask) == desired) in our sorted int format
|
||||
template <typename scalar_t, typename bitwise_t, typename index_t>
|
||||
__device__ scalar_t findPattern(
|
||||
scalar_t* smem,
|
||||
scalar_t* data,
|
||||
index_t sliceSize,
|
||||
index_t withinSliceStride,
|
||||
bitwise_t desired,
|
||||
bitwise_t desiredMask) {
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
smem[threadIdx.x] = static_cast<scalar_t>(0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// All threads participate in the loop, in order to sync on the flag
|
||||
index_t numIterations =
|
||||
THCRoundUp(sliceSize, static_cast<index_t>(blockDim.x));
|
||||
for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
|
||||
bool inRange = (i < sliceSize);
|
||||
scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
|
||||
if (inRange &&
|
||||
((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
|
||||
// There should not be conflicts if we are using findPattern,
|
||||
// since the result is unique
|
||||
smem[0] = static_cast<scalar_t>(1);
|
||||
smem[1] = v; // can't use val as the flag, since it could be 0
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
scalar_t found = smem[0];
|
||||
scalar_t val = smem[1];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Check to see if a thread found the value
|
||||
if (THCNumerics<scalar_t>::ne(found, static_cast<scalar_t>(0))) {
|
||||
// all threads return this value
|
||||
return val;
|
||||
}
|
||||
}
|
||||
|
||||
// should not get here
|
||||
assert(false);
|
||||
return static_cast<scalar_t>(0);
|
||||
}
|
||||
|
||||
// Returns the top-Kth element found in the data using radix selection
|
||||
template <typename scalar_t, typename bitwise_t, typename index_t, bool Order>
|
||||
__device__ void radixSelect(
|
||||
scalar_t* data,
|
||||
index_t k,
|
||||
index_t sliceSize,
|
||||
index_t withinSliceStride,
|
||||
int* smem,
|
||||
scalar_t* topK) {
|
||||
// Per-thread buckets into which we accumulate digit counts in our
|
||||
// radix
|
||||
int counts[RADIX_SIZE];
|
||||
|
||||
// We only consider elements x such that (x & desiredMask) == desired
|
||||
// Initially, we consider all elements of the array, so the above
|
||||
// statement is true regardless of input.
|
||||
bitwise_t desired = 0;
|
||||
bitwise_t desiredMask = 0;
|
||||
|
||||
// We are looking for the top kToFind-th element when iterating over
|
||||
// digits; this count gets reduced by elimination when counting
|
||||
// successive digits
|
||||
int kToFind = k;
|
||||
|
||||
// We start at the most significant digit in our radix, scanning
|
||||
// through to the least significant digit
|
||||
#pragma unroll
|
||||
for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
|
||||
digitPos -= RADIX_BITS) {
|
||||
// Count radix distribution for the current position and reduce
|
||||
// across all threads
|
||||
countRadixUsingMask<
|
||||
scalar_t,
|
||||
bitwise_t,
|
||||
index_t,
|
||||
int,
|
||||
RADIX_SIZE,
|
||||
RADIX_BITS>(
|
||||
counts,
|
||||
smem,
|
||||
desired,
|
||||
desiredMask,
|
||||
digitPos,
|
||||
sliceSize,
|
||||
withinSliceStride,
|
||||
data);
|
||||
|
||||
auto found_unique = [&](int i, int count) -> bool {
|
||||
/* All threads have the same value in counts here, so all */
|
||||
/* threads will return from the function. */
|
||||
if (count == 1 && kToFind == 1) {
|
||||
/* There is a unique answer. */
|
||||
desired =
|
||||
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
|
||||
desiredMask = Bitfield<bitwise_t>::setBitfield(
|
||||
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
|
||||
|
||||
/* The answer is now the unique element v such that: */
|
||||
/* (v & desiredMask) == desired */
|
||||
/* However, we do not yet know what the actual element is. We */
|
||||
/* need to perform a search through the data to find the */
|
||||
/* element that matches this pattern. */
|
||||
*topK = findPattern<scalar_t, bitwise_t, index_t>(
|
||||
(scalar_t*)smem,
|
||||
data,
|
||||
sliceSize,
|
||||
withinSliceStride,
|
||||
desired,
|
||||
desiredMask);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
auto found_non_unique = [&](int i, int count) -> bool {
|
||||
if (count >= kToFind) {
|
||||
desired =
|
||||
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
|
||||
desiredMask = Bitfield<bitwise_t>::setBitfield(
|
||||
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
|
||||
|
||||
/* The top-Kth element v must now be one such that: */
|
||||
/* (v & desiredMask == desired) */
|
||||
/* but we haven't narrowed it down; we must check the next */
|
||||
/* least-significant digit */
|
||||
return true;
|
||||
}
|
||||
kToFind -= count;
|
||||
return false; // continue the loop
|
||||
};
|
||||
|
||||
// All threads participate in the comparisons below to know the
|
||||
// final result
|
||||
if (Order) {
|
||||
// Process in descending order
|
||||
#pragma unroll
|
||||
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
|
||||
int count = counts[i];
|
||||
if (found_unique(i, count)) {
|
||||
return;
|
||||
}
|
||||
if (found_non_unique(i, count)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Process in ascending order
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RADIX_SIZE; ++i) {
|
||||
int count = counts[i];
|
||||
if (found_unique(i, count)) {
|
||||
return;
|
||||
}
|
||||
if (found_non_unique(i, count)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end digitPos for
|
||||
|
||||
// There is no unique result, but there is a non-unique result
|
||||
// matching `desired` exactly
|
||||
*topK = TopKTypeConfig<scalar_t>::deconvert(desired);
|
||||
}
|
||||
} // namespace native
|
||||
} // namespace at
|
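radixSelect relies on TopKTypeConfig mapping each scalar type to unsigned bits whose ordering matches the scalar ordering. The host-side sketch below reproduces the float mapping with memcpy in place of __float_as_int and checks the monotonicity the radix scan depends on; float_to_radix/radix_to_float are illustrative names, not from the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Order-preserving float -> uint32 mapping: flip all bits of negatives,
// flip only the sign bit of non-negatives (same transform as
// TopKTypeConfig<float>::convert, done on the host with memcpy).
static uint32_t float_to_radix(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
  return x ^ mask;
}

static float radix_to_float(uint32_t v) {
  uint32_t mask = (v & 0x80000000u) ? 0x80000000u : 0xffffffffu;
  uint32_t x = v ^ mask;
  float f;
  std::memcpy(&f, &x, sizeof(f));
  return f;
}

int main() {
  // Ordering of the converted bits matches ordering of the floats, which is
  // what lets radixSelect pick the k-th element by scanning radix digits
  // from most to least significant.
  assert(float_to_radix(-2.0f) < float_to_radix(-1.0f));
  assert(float_to_radix(-1.0f) < float_to_radix(0.0f));
  assert(float_to_radix(0.0f) < float_to_radix(3.5f));
  // The mapping is invertible.
  assert(radix_to_float(float_to_radix(3.5f)) == 3.5f);
  return 0;
}
```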
@@ -1196,6 +1196,9 @@
|
||||
variants: function, method
|
||||
|
||||
- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices)
|
||||
dispatch:
|
||||
CPU: kthvalue_out_cpu
|
||||
CUDA: kthvalue_out_cuda
|
||||
|
||||
- func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
|
||||
matches_jit_signature: True
|
||||
|
@@ -2,7 +2,6 @@
|
||||
# These are skipped by test_cuda.py
|
||||
torch.ByteTensor.dist
|
||||
torch.ByteTensor.dot
|
||||
torch.ByteTensor.kthvalue
|
||||
torch.ByteTensor.lerp
|
||||
torch.ByteTensor.lerp_
|
||||
torch.ByteTensor.mean
|
||||
@@ -13,7 +12,6 @@ torch.ByteTensor.std
|
||||
torch.ByteTensor.var
|
||||
torch.CharTensor.dist
|
||||
torch.CharTensor.dot
|
||||
torch.CharTensor.kthvalue
|
||||
torch.CharTensor.lerp
|
||||
torch.CharTensor.lerp_
|
||||
torch.CharTensor.mean
|
||||
@@ -22,8 +20,6 @@ torch.CharTensor.renorm
|
||||
torch.CharTensor.renorm_
|
||||
torch.CharTensor.std
|
||||
torch.CharTensor.var
|
||||
torch.DoubleTensor.kthvalue
|
||||
torch.FloatTensor.kthvalue
|
||||
torch.HalfTensor.chunk_
|
||||
torch.HalfTensor.clone_
|
||||
torch.HalfTensor.contiguous_
|
||||
@@ -47,7 +43,6 @@ torch.HalfTensor.inverse_
|
||||
torch.HalfTensor.is_contiguous_
|
||||
torch.HalfTensor.is_same_size_
|
||||
torch.HalfTensor.is_set_to_
|
||||
torch.HalfTensor.kthvalue
|
||||
torch.HalfTensor.kthvalue_
|
||||
torch.HalfTensor.max_
|
||||
torch.HalfTensor.mean_
|
||||
@@ -87,7 +82,6 @@ torch.HalfTensor.zeros
|
||||
torch.HalfTensor.zeros_
|
||||
torch.IntTensor.dist
|
||||
torch.IntTensor.dot
|
||||
torch.IntTensor.kthvalue
|
||||
torch.IntTensor.lerp
|
||||
torch.IntTensor.lerp_
|
||||
torch.IntTensor.mean
|
||||
@@ -98,7 +92,6 @@ torch.IntTensor.std
|
||||
torch.IntTensor.var
|
||||
torch.LongTensor.dist
|
||||
torch.LongTensor.dot
|
||||
torch.LongTensor.kthvalue
|
||||
torch.LongTensor.lerp
|
||||
torch.LongTensor.lerp_
|
||||
torch.LongTensor.mean
|
||||
@@ -109,7 +102,6 @@ torch.LongTensor.std
|
||||
torch.LongTensor.var
|
||||
torch.ShortTensor.dist
|
||||
torch.ShortTensor.dot
|
||||
torch.ShortTensor.kthvalue
|
||||
torch.ShortTensor.lerp
|
||||
torch.ShortTensor.lerp_
|
||||
torch.ShortTensor.mean
|
||||
|