Implement kthvalue in ATen (#17544)

Summary:
The CPU version is based on the TH implementation.
The GPU version is based on #8406 by Pararth Shah (thank you).

The CPU quickselect is based on the one in TH's THTensorMoreMath.cpp, but written in C++ (the index-free variant, quickselectnoindex, will be obtained with a different swap function).
CPU kthvalue is based on the THTensor function in the same file.
The dim_apply function is a C++ replacement for the TH_TENSOR_DIM_APPLYx macros.
The CUDA kernel uses functions adapted from the THCTensorSortK implementation; in particular, radixSelect comes from THCTensorTopK.cuh.
The CUDA launcher code replaces a bunch of macros with C++ and will be reused in one of the follow-up patches.
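For reference, a minimal illustration of the semantics being implemented (not part of this patch; the shapes are arbitrary):

    #include <ATen/ATen.h>
    #include <tuple>

    int main() {
      auto t = at::randn({3, 5});
      at::Tensor values, indices;
      // k is 1-based: k == 2 selects the 2nd-smallest entry along dim 1
      std::tie(values, indices) = at::kthvalue(t, /*k=*/2, /*dim=*/1);
      // with keepdim == false (the default), values and indices have shape {3}:
      // the 2nd-smallest entry of each row and its position within that row
      return 0;
    }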

Plan for further PRs:
- This
- Sort
- TopK + Mode + Median in any order
- Rip out THC stuff.

There may be utility functions / structs in SortingCommon.cuh that only become
relevant with sort.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17544

Differential Revision: D14286934

Pulled By: ezyang

fbshipit-source-id: 35dbea050b097e88777ac5fa5c0f499d5e23c738
Author: Thomas Viehmann
Date: 2019-03-01 18:57:02 -08:00
Committed by: Facebook Github Bot
Parent: 43f94077d8
Commit: c6715eda06
9 changed files with 1113 additions and 49 deletions

View File

@ -800,27 +800,6 @@
- arg: bool keepdim
default: "false"
]]
[[
name: _th_kthvalue
backends:
- CPU
variants: function
cname: kthvalue
return: argument 0,1
scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1)
arguments:
- arg: THTensor* values
output: True
- arg: THIndexTensor* indices
output: True
- THTensor* self
- long k
- arg: long dim
wrap_dim: self
default: __last_dim
- arg: bool keepdim
default: "false"
]]
[[
name: _th_mode
variants: function

View File

@ -0,0 +1,195 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/SortingUtils.h>
namespace at {
namespace native {
namespace {
// maybe these days, one should define a random access iterator and use
// std::sort...
/* Note from TH:
I cut and pasted (slightly adapted) the quicksort code from
Sedgewick's 1978 "Implementing Quicksort Programs" article
http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf
It is the state of the art existing implementation. The macros
are here to make as close a match as possible to the pseudocode of
Program 2 p.851
Note that other partition schemes exist, and are typically presented
in textbook, but those are less efficient. See e.g.
http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto
Julien, November 12th 2013
*/
constexpr int64_t MAX_LEVELS = 300;
constexpr int64_t M_SMALL = 10; // Limit for small subfiles
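// A sketch of how dim_apply (below) is meant to be used; it stands in for the
// TH_TENSOR_DIM_APPLY* macros. It iterates over every 1-d slice taken along
// `dim` of the given tensors (which must agree in the remaining dimensions)
// and hands the narrowed 1-d views to `f`, e.g. (illustrative only):
//   dim_apply({t, out}, /*dim=*/1, [&](int64_t i, TensorList tl) {
//     // tl[0] is the i-th 1-d slice of t along dim 1,
//     // tl[1] is the matching slice of out
//   });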
template <typename Fn>
void dim_apply(TensorList tensors, int64_t dim, Fn f) {
AT_ASSERT(tensors.size() > 0);
auto t = tensors[0];
auto sizes = t.sizes();
int64_t ndim = t.dim();
int64_t itersize = 1;
for (int64_t i = 0; i < ndim; i++) {
if (i != dim) {
itersize *= t.size(i);
}
}
parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) {
std::vector<Tensor> narrowed_tensors;
narrowed_tensors.reserve(tensors.size());
for (int64_t it = i_begin; it < i_end; it++) {
narrowed_tensors.clear();
for (auto ti : tensors) {
int64_t i = it;
Tensor nt = ti;
for (int64_t d = 0; d < ndim; d++) {
if (d != dim) {
// this could be avoided for slower-changing dimensions if done
// better
nt = nt.select((d > dim ? 1 : 0), i % sizes[d]);
i = i / sizes[d];
}
}
narrowed_tensors.emplace_back(nt);
}
f(it, narrowed_tensors);
}
});
}
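// quick_select_template (below) partially sorts a 1-d accessor in place
// through the supplied swap function, with the usual quickselect
// postcondition: for 0-based k, arr[k] ends up holding the value an ascending
// sort would place at index k. kthvalue_out_cpu relies on this by passing
// k - 1 and then reading index k - 1.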
template <typename scalar_t, typename Fn>
void quick_select_template(
TensorAccessor<scalar_t, 1> arr,
int64_t k,
Fn swap_fn) {
int64_t P, L, R, i, j;
scalar_t piv;
L = 0;
R = arr.size(0) - 1;
do {
if (R <= L) // One element only
return;
if (R == L + 1) { // Two elements only
if (arr[L] > arr[R]) {
swap_fn(L, R);
}
return;
}
// Use median of three for pivot choice
P = (L + R) >> 1;
swap_fn(P, L + 1);
if (arr[L + 1] > arr[R]) {
swap_fn(L + 1, R);
}
if (arr[L] > arr[R]) {
swap_fn(L, R);
}
if (arr[L + 1] > arr[L]) {
swap_fn(L + 1, L);
}
i = L + 1;
j = R;
piv = arr[L];
do {
do
i++;
while (arr[i] < piv);
do
j--;
while (arr[j] > piv);
if (j < i)
break;
swap_fn(i, j);
} while (1);
swap_fn(L, j);
// Re-set active partition
if (j <= k)
L = i;
if (j >= k)
R = j - 1;
} while (1);
}
} // namespace
std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
// FIXME: This seems bogus, I only do this because it was the old behaviour.
// The reductions are fine, as long as the axis being reduced along
// isn't of 0 elements (and the output has elements).
AT_CHECK(
self.numel() > 0,
"cannot perform reduction function kthvalue",
" on tensor with no elements because the operation does not have an identity");
AT_CHECK(
k > 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
_reduction_with_indices_allocate_or_resize_output(
values, indices, self, dim_, keepdim);
if (self.dim() == 0 && self.numel() == 1) {
values.copy_(self);
indices.zero_();
return std::forward_as_tuple(values, indices);
}
auto tmp_values = self.clone();
auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));
AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] {
dim_apply(
{tmp_values, tmp_indices, values, indices},
dim,
[&](int64_t i, TensorList tl) {
auto tmp_values = tl[0].accessor<scalar_t, 1>();
auto tmp_indices = tl[1].accessor<int64_t, 1>();
scalar_t* mode_value = tl[2].data<scalar_t>();
int64_t* mode_index = tl[3].data<int64_t>();
for (int64_t j = 0; j < tmp_indices.size(0); j++) {
tmp_indices[j] = j;
}
quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) {
std::swap(tmp_values[i], tmp_values[j]);
std::swap(tmp_indices[i], tmp_indices[j]);
});
*mode_value = tmp_values[k - 1];
*mode_index = tmp_indices[k - 1];
});
});
if (!keepdim) {
values.squeeze_(dim);
indices.squeeze_(dim);
}
return std::forward_as_tuple(values, indices);
}
std::tuple<Tensor, Tensor> kthvalue(
const Tensor& self,
int64_t k,
int64_t dim,
bool keepdim) {
Tensor values = at::empty({0}, self.options());
Tensor indices = at::empty({0}, self.options().dtype(kLong));
at::kthvalue_out(values, indices, self, k, dim, keepdim);
return std::make_tuple(values, indices);
}
} // namespace native
} // namespace at

View File

@ -0,0 +1,48 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
namespace at {
namespace native {
// ensure properly sized values and indices for kthvalue, mode, and median;
// the reducing dim is always kept here with size 1 and is squeezed by the
// caller if keepdim is false
static void _reduction_with_indices_allocate_or_resize_output(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
auto result_sizes = self.sizes().vec();
if (result_sizes.size() > 0) {
result_sizes[dim] = 1;
}
if (values.defined()) {
AT_CHECK(
self.type() == values.type(),
"output values must be of same type as input");
if (!keepdim && values.dim() == self.dim() - 1) {
// unsqueeze to preserve passed in noncontiguous tensor in resize
values.unsqueeze_(dim);
}
values.resize_(result_sizes);
} else {
values = at::empty(result_sizes, self.options());
}
if (indices.defined()) {
AT_CHECK(
indices.dtype() == kLong, "output indices must be of scalar type Long");
AT_CHECK(
indices.device() == self.device(),
"output indices must be on same device as input");
if (!keepdim && indices.dim() == self.dim() - 1) {
// unsqueeze to preserve passed in noncontiguous tensor in resize
indices.unsqueeze_(dim);
}
indices.resize_(result_sizes);
} else {
indices = at::empty(result_sizes, self.options().dtype(kLong));
}
}
} // namespace native
} // namespace at

View File

@ -97,26 +97,6 @@ Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& o
return ret;
}
std::tuple<Tensor, Tensor> kthvalue(const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
Tensor values = at::empty({0}, self.options());
Tensor indices = at::empty({0}, self.options().dtype(kLong));
return at::native::kthvalue_out(values, indices, self, k, dim, keepdim);
}
std::tuple<Tensor &,Tensor &> kthvalue_out(Tensor& values, Tensor& indices,
const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"kthvalue only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "kthvalue")) {
AT_ASSERT(values.dim() == 0);
indices.resize_({}).fill_(0);
return std::forward_as_tuple(values, indices);
} else {
return at::legacy::th::_th_kthvalue_out(values, indices, self, k, dim, keepdim);
}
}
std::tuple<Tensor, Tensor> median(const Tensor& self, int64_t dim, bool keepdim) {
Tensor values = at::empty({0}, self.options());
Tensor indices = at::empty({0}, self.options().dtype(kLong));

View File

@ -0,0 +1,226 @@
#include <ATen/ATen.h>
#include <ATen/native/SortingUtils.h>
#include <assert.h>
#include <c10/macros/Macros.h>
#include <stdlib.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
#include <THC/THCNumerics.cuh>
#include <THC/THCScanUtils.cuh>
#include <THC/THCTensorMathReduce.cuh> // AddOp
namespace at {
namespace native {
#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int WARP_SIZE = 32;
constexpr int MAX_BLOCK_SIZE = 1024;
#endif
// Maximum size per grid dimension that we assume (compute capability >= 2.0)
constexpr int64_t MAX_GRID_SIZE = 65535LL;
static bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
return false;
}
int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
int64_t gridY = 1;
int64_t gridZ = 1;
if (gridTiles > MAX_GRID_SIZE) {
gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
if (gridTiles > MAX_GRID_SIZE) {
gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
}
}
grid = dim3(gridX, gridY, gridZ);
return true;
}
template <typename scalar_t, bool handleNaN = false>
struct ThrustGTOp {
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
return (handleNaN && THCNumerics<scalar_t>::isnan(lhs) &&
!THCNumerics<scalar_t>::isnan(rhs)) ||
THCNumerics<scalar_t>::gt(lhs, rhs);
}
};
template <typename scalar_t, bool handleNaN = false>
struct ThrustLTOp {
__device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
return (handleNaN && THCNumerics<scalar_t>::isnan(rhs) &&
!THCNumerics<scalar_t>::isnan(lhs)) ||
THCNumerics<scalar_t>::lt(lhs, rhs);
}
};
template <typename index_t>
__device__ __forceinline__ index_t getLinearBlockId() {
return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
blockIdx.x;
}
// `base` is the base address of a tensor.
// For each slice (defined as a linear point of `out`, from 0 to
// (sliceSize - 1) * sliceStride), we fill that slice with the indices
// 0 to sliceSize - 1.
template <typename index_t, int Dim>
__global__ void fillSliceWithIndex_kernel(
cuda::detail::TensorInfo<int64_t, index_t> out,
index_t totalSlices,
index_t sliceSize,
index_t sliceStride) {
index_t slice = getLinearBlockId<index_t>();
if (slice >= totalSlices) {
return;
}
const uint64_t offset =
cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, out);
int64_t* base = &out.data[offset];
for (int64_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
// Fill the slice with 0-based indices
base[i * sliceStride] = i;
}
}
// For slice sorting in Thrust; extracts a slice index from a linear
// index and uses that for comparison
struct SliceComp {
SliceComp(int64_t size) : sliceSize(size) {}
__device__ bool operator()(const int64_t& a, const int64_t& b) const {
// Since the slices are guaranteed to be innermost,
// the segment index is obtained by integer division.
int64_t segA = a / sliceSize;
int64_t segB = b / sliceSize;
return segA < segB;
}
const int64_t sliceSize;
};
// For sorting in Thrust; extracts a within-slice index from a linear index
struct GlobalIndexToPerSliceIndex {
GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
__device__ inline void operator()(int64_t& v) const {
v = v % sliceSize;
}
const int64_t sliceSize;
};
// Returns 2^(ceil(lg(n))), from the Stanford bit twiddling hacks
static uint64_t nextHighestPowerOf2(uint64_t n) {
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
#ifndef _MSC_VER
n |= n >> 32;
#endif
n++;
return n;
}
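// run_launcher (below) factors out the boilerplate the THC macros used to
// generate: it builds TensorInfo structs for self/values/indices, collapses
// the non-reduced dimensions, and dispatches to
//   l.template launch<scalar_t, index_t, Dim>(...)
// with Dim specialized to 1, 2, 3, or -1 (arbitrary dimensionality). Any
// struct providing that member template can serve as the launcher;
// KthValueLauncher in the kthvalue kernel is one such launcher, and the
// planned sort / top-k ports are expected to add their own.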
template <typename scalar_t, typename index_t, typename Launcher>
void run_launcher(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim,
Launcher l) {
auto self_info = cuda::detail::getTensorInfo<scalar_t, index_t>(self);
auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);
int64_t slice_size = self.size(dim);
/* We use these structures solely to find the offset to */
/* each slice we are operating on */
self_info.reduceDim(dim);
values_info.reduceDim(dim);
indices_info.reduceDim(dim);
/* Collapse all other dims */
int collapse_self_dim = self_info.collapseDims(dim);
int collapse_values_dim = values_info.collapseDims(dim);
int collapse_indices_dim = indices_info.collapseDims(dim);
int64_t num_slices = 1;
for (int i = 0; i < self_info.dims; ++i) {
num_slices *= self_info.sizes[i];
}
/* This is used as a template parameter to calculate indices. */
/* We only specialize it if all collapsed dim sizes are the */
/* same; otherwise, we use -1 which is the specialization */
/* parameter for arbitrary dimensions */
int all_dims = self_info.dims;
if (values_info.dims != all_dims || indices_info.dims != all_dims) {
all_dims = -1;
}
if (all_dims == 1) {
l.template launch<scalar_t, index_t, 1>(
values_info,
collapse_values_dim,
indices_info,
collapse_indices_dim,
self_info,
collapse_self_dim,
num_slices,
slice_size);
} else if (all_dims == 2) {
l.template launch<scalar_t, index_t, 2>(
values_info,
collapse_values_dim,
indices_info,
collapse_indices_dim,
self_info,
collapse_self_dim,
num_slices,
slice_size);
} else if (all_dims == 3) {
l.template launch<scalar_t, index_t, 3>(
values_info,
collapse_values_dim,
indices_info,
collapse_indices_dim,
self_info,
collapse_self_dim,
num_slices,
slice_size);
} else {
l.template launch<scalar_t, index_t, -1>(
values_info,
collapse_values_dim,
indices_info,
collapse_indices_dim,
self_info,
collapse_self_dim,
num_slices,
slice_size);
}
}
} // namespace native
} // namespace at

View File

@ -0,0 +1,249 @@
#include <ATen/ATen.h>
#include <ATen/native/SortingUtils.h>
#include <assert.h>
#include <c10/macros/Macros.h>
#include <stdlib.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
#include <THC/THCNumerics.cuh>
#include <THC/THCScanUtils.cuh>
#include <THC/THCTensorMathReduce.cuh> // AddOp
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>
#include <thrust/inner_product.h>
#include <thrust/sequence.h>
#include <THC/THCThrustAllocator.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
namespace at {
namespace native {
namespace {
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherKthValue(
cuda::detail::TensorInfo<scalar_t, index_t> input,
index_t inputSliceSize,
index_t k,
index_t numInputSlices,
index_t inputWithinSliceStride,
cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
cuda::detail::TensorInfo<int64_t, index_t> indices) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of index_t
__shared__ int smem[WARP_SIZE]; // one per warp, up to the warp limit
index_t slice = getLinearBlockId<index_t>();
if (slice >= numInputSlices) {
return;
}
// Find the start offset for our slice
index_t sliceStartIndex =
cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);
index_t kthValueSliceStartIndex =
cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, kthValue);
index_t indicesSliceStartIndex =
cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);
scalar_t* inputSliceStart = &input.data[sliceStartIndex];
scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex];
int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];
// Find the k-th smallest element in our input
scalar_t kValue = static_cast<scalar_t>(0);
radixSelect<
scalar_t,
typename TopKTypeConfig<scalar_t>::RadixType,
index_t,
false>(
inputSliceStart,
k,
inputSliceSize,
inputWithinSliceStride,
smem,
&kValue);
// Find the index of the k-th smallest element
index_t kValueIndex = 0;
bool foundKValue = false;
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
: static_cast<scalar_t>(0);
bool isKValue = inRange && (THCNumerics<scalar_t>::eq(v, kValue));
if (isKValue) {
kValueIndex = i;
foundKValue = true;
break;
}
}
if (foundKValue) {
kthValueSliceStart[0] = kValue;
indicesSliceStart[0] = kValueIndex;
}
}
struct KthValueLauncher {
int64_t k;
KthValueLauncher(int64_t k) : k(k) {}
template <typename scalar_t, typename index_t, int all_dims>
inline void launch(
cuda::detail::TensorInfo<scalar_t, index_t> values_info,
int collapse_values_dim,
cuda::detail::TensorInfo<int64_t, index_t> indices_info,
int collapse_indices_dim,
cuda::detail::TensorInfo<scalar_t, index_t> self_info,
int collapse_self_dim,
int64_t num_slices,
int64_t slice_size) {
dim3 grid;
if (!getGridFromTiles(num_slices, grid)) {
AT_ERROR("slices are too many");
}
dim3 block(
std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024));
auto stream = at::cuda::getCurrentCUDAStream();
gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
self_info,
slice_size,
k,
num_slices,
/* The actual dimension that the k-selection is running in */
/* may have changed from collapseDims() */
self_info.strides[collapse_self_dim],
values_info,
indices_info);
}
};
template <typename scalar_t>
void kthvalue_cuda_template(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
int64_t slicesize = self.size(dim);
// FIXME: This seems bogus, I only do this because it was the old behaviour.
// The reductions are fine, as long as the axis being reduced along
// isn't of 0 elements (and the output has elements).
AT_CHECK(
self.numel() > 0,
"cannot perform reduction function kthvalue",
" on tensor with no elements because the operation does not have an identity");
AT_CHECK(k >= 1 && k <= slicesize, "selected number k out of range");
_reduction_with_indices_allocate_or_resize_output(
values, indices, self, dim, keepdim);
if (self.dim() == 0 && self.numel() == 1) {
values.copy_(self);
indices.zero_();
return;
}
AT_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"cannot operate on more than ",
MAX_TENSORINFO_DIMS,
" dimensions");
// Based on required index size, run the algorithm with the
// appropriate index type
if (cuda::detail::canUse32BitIndexMath(self) &&
cuda::detail::canUse32BitIndexMath(values) &&
cuda::detail::canUse32BitIndexMath(indices)) {
run_launcher<scalar_t, uint32_t>(
values, indices, self, dim, KthValueLauncher(k));
} else {
run_launcher<scalar_t, uint64_t>(
values, indices, self, dim, KthValueLauncher(k));
}
if (!keepdim) {
values.squeeze_(dim);
indices.squeeze_(dim);
}
AT_CUDA_CHECK(cudaGetLastError());
}
// this does not reduce to median with dim because we don't want to copy twice
template <typename scalar_t>
Tensor median_cuda_template(const Tensor& self) {
AT_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
if (self.dim() == 0 && self.numel() == 1) {
return self.clone();
}
auto self_copy = self.clone().view(-1);
auto values = at::empty({1}, self.options());
auto indices = at::empty({1}, self.options().dtype(kLong));
AT_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"cannot operate on more than ",
MAX_TENSORINFO_DIMS,
" dimensions");
// Based on required index size, run the algorithm with the
// appropriate index type
if (cuda::detail::canUse32BitIndexMath(self) &&
cuda::detail::canUse32BitIndexMath(values) &&
cuda::detail::canUse32BitIndexMath(indices)) {
run_launcher<scalar_t, uint32_t>(
values,
indices,
self_copy,
0,
KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
} else {
run_launcher<scalar_t, uint64_t>(
values,
indices,
self_copy,
0,
KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
}
return values.view({});
}
} // namespace
std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim,
bool keepdim) {
AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] {
kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
});
return std::forward_as_tuple(values, indices);
}
Tensor median_cuda(const Tensor& self) {
return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] {
return median_cuda_template<scalar_t>(self);
});
}
} // namespace native
} // namespace at

View File

@ -0,0 +1,392 @@
namespace at {
namespace native {
template <typename scalar_t>
struct TopKTypeConfig {};
template <>
struct TopKTypeConfig<float> {
typedef uint32_t RadixType;
// Converts a float to an integer representation with the same
// sorting; i.e., for floats f1, f2:
// if f1 < f2 then convert(f1) < convert(f2)
// We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent
static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
static inline __device__ float deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
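// Worked example for the float mapping above (values can be checked by hand):
//   __float_as_int(1.0f)  == 0x3f800000; sign bit clear, so convert(1.0f)  == 0xbf800000
//   __float_as_int(-1.0f) == 0xbf800000; sign bit set,   so convert(-1.0f) == 0x407fffff
// Hence convert(-1.0f) < convert(1.0f) as unsigned integers, matching
// -1.0f < 1.0f, and deconvert() inverts the mapping exactly.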
template <>
struct TopKTypeConfig<uint8_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(uint8_t v) {
return v;
}
static inline __device__ uint8_t deconvert(RadixType v) {
return v;
}
};
template <>
struct TopKTypeConfig<int8_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(int8_t v) {
return 128u + v;
}
static inline __device__ int8_t deconvert(RadixType v) {
return v - 128;
}
};
template <>
struct TopKTypeConfig<int16_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(int16_t v) {
assert(sizeof(short) == 2);
return 32768u + v;
}
static inline __device__ int16_t deconvert(RadixType v) {
return v - 32768;
}
};
template <>
struct TopKTypeConfig<int32_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(int32_t v) {
assert(sizeof(int) == 4);
return 2147483648u + v;
}
static inline __device__ int32_t deconvert(RadixType v) {
return v - 2147483648u;
}
};
template <>
struct TopKTypeConfig<int64_t> {
typedef uint64_t RadixType;
static inline __device__ RadixType convert(int64_t v) {
assert(sizeof(int64_t) == 8);
return 9223372036854775808ull + v;
}
static inline __device__ int64_t deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct TopKTypeConfig<double> {
typedef uint64_t RadixType;
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (x ^ mask);
}
static inline __device__ double deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
};
template <>
struct TopKTypeConfig<at::Half> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(at::Half v) {
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
RadixType x = __half_as_ushort(v);
RadixType mask = -((x >> 15)) | 0x8000;
return (x ^ mask);
#else
assert(false);
return 0u;
#endif
}
static inline __device__ at::Half deconvert(RadixType v) {
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
RadixType mask = ((v >> 15) - 1) | 0x8000;
return __ushort_as_half(v ^ mask);
#else
assert(false);
return static_cast<at::Half>(0);
#endif
}
};
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radixDigitPos`, but only
// those that pass the filter `((v & desiredMask) == desired)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RadixSize` elements.
template <
typename scalar_t,
typename bitwise_t,
typename index_t,
typename CountType,
int RadixSize,
int RadixBits>
__device__ void countRadixUsingMask(
CountType counts[RadixSize],
CountType* smem,
bitwise_t desired,
bitwise_t desiredMask,
int radixDigitPos,
index_t sliceSize,
index_t withinSliceStride,
scalar_t* data) {
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RadixSize; ++i) {
counts[i] = 0;
}
if (threadIdx.x < RadixSize) {
smem[threadIdx.x] = 0;
}
__syncthreads();
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
bitwise_t val =
TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));
bool hasVal = ((val & desiredMask) == desired);
bitwise_t digitInRadix =
Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);
#pragma unroll
for (uint32_t j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
#if defined(__HIP_PLATFORM_HCC__)
counts[j] += __popcll(WARP_BALLOT(vote));
#else
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
#endif
}
}
// Now, for each warp, sum values
if (getLaneId() == 0) {
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
atomicAdd(&smem[i], counts[i]);
}
}
__syncthreads();
// For each thread, read in the total counts
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
counts[i] = smem[i];
}
__syncthreads();
}
// Over what radix we are selecting values
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
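// Worked example of the digit-by-digit scan in radixSelect below (with
// Order == false, i.e. ascending selection as used for kthvalue), on the
// 4-bit keys {0101, 0110, 0111, 1001} with k == 2:
//   bits 3..2: bucket counts 00:0, 01:3, 10:1, 11:0; bucket 01 already holds
//     >= kToFind elements, so desired/desiredMask pin bits 3..2 to 01.
//   bits 1..0 (filtered to prefix 01): counts 00:0, 01:1, 10:1, 11:1; bucket
//     01 holds 1 < kToFind, so kToFind drops to 1; bucket 10 then holds
//     exactly kToFind == 1 element, so the answer is unique and findPattern
//     recovers 0110, which is indeed the 2nd-smallest key.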
// This finds the unique value `v` that matches the pattern
// ((v & desiredMask) == desired) in our sorted int format
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ scalar_t findPattern(
scalar_t* smem,
scalar_t* data,
index_t sliceSize,
index_t withinSliceStride,
bitwise_t desired,
bitwise_t desiredMask) {
if (threadIdx.x < WARP_SIZE) {
smem[threadIdx.x] = static_cast<scalar_t>(0);
}
__syncthreads();
// All threads participate in the loop, in order to sync on the flag
index_t numIterations =
THCRoundUp(sliceSize, static_cast<index_t>(blockDim.x));
for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
bool inRange = (i < sliceSize);
scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
: static_cast<scalar_t>(0);
if (inRange &&
((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
// There should not be conflicts if we are using findPattern,
// since the result is unique
smem[0] = static_cast<scalar_t>(1);
smem[1] = v; // can't use val as the flag, since it could be 0
}
__syncthreads();
scalar_t found = smem[0];
scalar_t val = smem[1];
__syncthreads();
// Check to see if a thread found the value
if (THCNumerics<scalar_t>::ne(found, static_cast<scalar_t>(0))) {
// all threads return this value
return val;
}
}
// should not get here
assert(false);
return static_cast<scalar_t>(0);
}
// Finds the top-Kth element in the data using radix selection and writes it to
// *topK; Order == true selects the k-th largest, Order == false the k-th smallest
template <typename scalar_t, typename bitwise_t, typename index_t, bool Order>
__device__ void radixSelect(
scalar_t* data,
index_t k,
index_t sliceSize,
index_t withinSliceStride,
int* smem,
scalar_t* topK) {
// Per-thread buckets into which we accumulate digit counts in our
// radix
int counts[RADIX_SIZE];
// We only consider elements x such that (x & desiredMask) == desired
// Initially, we consider all elements of the array, so the above
// statement is true regardless of input.
bitwise_t desired = 0;
bitwise_t desiredMask = 0;
// We are looking for the top kToFind-th element when iterating over
// digits; this count gets reduced by elimination when counting
// successive digits
int kToFind = k;
// We start at the most significant digit in our radix, scanning
// through to the least significant digit
#pragma unroll
for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
digitPos -= RADIX_BITS) {
// Count radix distribution for the current position and reduce
// across all threads
countRadixUsingMask<
scalar_t,
bitwise_t,
index_t,
int,
RADIX_SIZE,
RADIX_BITS>(
counts,
smem,
desired,
desiredMask,
digitPos,
sliceSize,
withinSliceStride,
data);
auto found_unique = [&](int i, int count) -> bool {
/* All threads have the same value in counts here, so all */
/* threads will return from the function. */
if (count == 1 && kToFind == 1) {
/* There is a unique answer. */
desired =
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
desiredMask = Bitfield<bitwise_t>::setBitfield(
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
/* The answer is now the unique element v such that: */
/* (v & desiredMask) == desired */
/* However, we do not yet know what the actual element is. We */
/* need to perform a search through the data to find the */
/* element that matches this pattern. */
*topK = findPattern<scalar_t, bitwise_t, index_t>(
(scalar_t*)smem,
data,
sliceSize,
withinSliceStride,
desired,
desiredMask);
return true;
}
return false;
};
auto found_non_unique = [&](int i, int count) -> bool {
if (count >= kToFind) {
desired =
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
desiredMask = Bitfield<bitwise_t>::setBitfield(
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
/* The top-Kth element v must now be one such that: */
/* ((v & desiredMask) == desired) */
/* but we haven't narrowed it down; we must check the next */
/* least-significant digit */
return true;
}
kToFind -= count;
return false; // continue the loop
};
// All threads participate in the comparisons below to know the
// final result
if (Order) {
// Process in descending order
#pragma unroll
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
} else {
// Process in ascending order
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
}
} // end digitPos for
// There is no unique result, but there is a non-unique result
// matching `desired` exactly
*topK = TopKTypeConfig<scalar_t>::deconvert(desired);
}
} // namespace native
} // namespace at

View File

@ -1196,6 +1196,9 @@
variants: function, method
- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
dispatch:
CPU: kthvalue_out_cpu
CUDA: kthvalue_out_cuda
- func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
matches_jit_signature: True

View File

@ -2,7 +2,6 @@
# These are skipped by test_cuda.py
torch.ByteTensor.dist
torch.ByteTensor.dot
torch.ByteTensor.kthvalue
torch.ByteTensor.lerp
torch.ByteTensor.lerp_
torch.ByteTensor.mean
@ -13,7 +12,6 @@ torch.ByteTensor.std
torch.ByteTensor.var
torch.CharTensor.dist
torch.CharTensor.dot
torch.CharTensor.kthvalue
torch.CharTensor.lerp
torch.CharTensor.lerp_
torch.CharTensor.mean
@ -22,8 +20,6 @@ torch.CharTensor.renorm
torch.CharTensor.renorm_
torch.CharTensor.std
torch.CharTensor.var
torch.DoubleTensor.kthvalue
torch.FloatTensor.kthvalue
torch.HalfTensor.chunk_
torch.HalfTensor.clone_
torch.HalfTensor.contiguous_
@ -47,7 +43,6 @@ torch.HalfTensor.inverse_
torch.HalfTensor.is_contiguous_
torch.HalfTensor.is_same_size_
torch.HalfTensor.is_set_to_
torch.HalfTensor.kthvalue
torch.HalfTensor.kthvalue_
torch.HalfTensor.max_
torch.HalfTensor.mean_
@ -87,7 +82,6 @@ torch.HalfTensor.zeros
torch.HalfTensor.zeros_
torch.IntTensor.dist
torch.IntTensor.dot
torch.IntTensor.kthvalue
torch.IntTensor.lerp
torch.IntTensor.lerp_
torch.IntTensor.mean
@ -98,7 +92,6 @@ torch.IntTensor.std
torch.IntTensor.var
torch.LongTensor.dist
torch.LongTensor.dot
torch.LongTensor.kthvalue
torch.LongTensor.lerp
torch.LongTensor.lerp_
torch.LongTensor.mean
@ -109,7 +102,6 @@ torch.LongTensor.std
torch.LongTensor.var
torch.ShortTensor.dist
torch.ShortTensor.dot
torch.ShortTensor.kthvalue
torch.ShortTensor.lerp
torch.ShortTensor.lerp_
torch.ShortTensor.mean