Port sort to structured kernels.

Tracking Issue: #55070

This PR relands #67016, with the modifications discussed in https://github.com/pytorch/pytorch/pull/67015#issuecomment-982004500.

In summary, we pass dense strides to `set_output`: the input's own strides if it is already non-overlapping and dense, otherwise strides computed by `infer_dense_strides` from the input's sizes and strides. As a consequence, if one of the outputs has to be resized (by a `resize_output` call), that output is also restrided using the dense strides derived from the input.
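For illustration, here is a minimal Python sketch of the resulting behavior, adapted from the `test_sort_restride` test added in this PR; the expected strides reflect the restriding described above:

```python
import torch

# Non-contiguous input: first column of a (3, 5) tensor, so stride is (5,).
tensor = torch.randn(3, 5)[:, 0]

# 0-dim outputs: sort() has to resize them, and with this change it also
# restrides them with dense strides inferred from the input.
values = torch.tensor(0.0)
indices = torch.tensor(0, dtype=torch.long)

torch.sort(tensor, out=(values, indices))

print(values.stride(), indices.stride())    # expected: (1,) (1,)
print(torch.equal(tensor[indices], values))  # expected: True
```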
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72058
Approved by: https://github.com/bdhirsh
Author:    Yukio Siraichi
Date:      2022-04-21 01:48:46 +00:00
Committer: PyTorch MergeBot
Parent:    45bbc4c028
Commit:    3e10fe3231

9 changed files with 110 additions and 167 deletions

View File

@@ -14,30 +14,46 @@
namespace at {
namespace meta {
using namespace native;
TORCH_META_FUNC(topk) (
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");
TORCH_META_FUNC(topk)
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");
// Build the output size, which is the dim being selected set to
// size k
DimVector topKSize(self.sizes().vec());
if (topKSize.size() > 0) {
topKSize[dim] = k;
}
set_output(0, topKSize, self.options());
set_output(1, topKSize, self.options().dtype(at::kLong));
// Build the output size, which is the dim being selected set to
// size k
DimVector topKSize(self.sizes().vec());
if (topKSize.size() > 0) {
topKSize[dim] = k;
}
set_output(0, topKSize, self.options());
set_output(1, topKSize, self.options().dtype(at::kLong));
}
TORCH_META_FUNC2(sort, stable)
(const Tensor& self, c10::optional<bool> stable, int64_t dim, bool descending) {
TORCH_INTERNAL_ASSERT(
stable.has_value(),
"sort(): c10::optional<bool> for stable has to have value.");
maybe_wrap_dim(dim, self.dim());
// See issue: https://github.com/pytorch/pytorch/issues/65863
// Strides should be dense, so as not to allocate too much memory.
// We either use 'self' strides, or infer dense strides from them.
std::vector<int64_t> strides = (self.is_non_overlapping_and_dense())
? self.strides().vec()
: at::infer_dense_strides(self.sizes(), self.strides());
set_output(0, self.sizes(), strides, self.options(), {});
set_output(1, self.sizes(), strides, self.options().dtype(kLong), {});
}
} // namespace meta
namespace native {
@@ -865,52 +881,37 @@ Tensor nanmedian_cpu(const Tensor& self) {
return median_impl(self, /*ignore_nan=*/true);
}
std::tuple<Tensor&, Tensor&> sort_out_cpu_stable(const Tensor& self,
c10::optional<bool> stable,
int64_t dim,
bool descending,
Tensor& values,
Tensor& indices) {
values.resize_(self.sizes()).copy_(self);
indices.resize_(self.sizes());
TORCH_IMPL_FUNC(sort_stable_out)
(const Tensor& self,
c10::optional<bool> stable,
int64_t dim,
bool descending,
const Tensor& values,
const Tensor& indices) {
values.copy_(self);
// check if self is scalar
if (self.dim() == 0 && self.numel() == 1) {
indices.zero_();
return std::forward_as_tuple(values, indices);
} else {
dim = maybe_wrap_dim(dim, self.dim());
sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value());
}
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
sort_stub(kCPU, values, indices, dim, descending, stable.value());
return std::forward_as_tuple(values, indices);
}
std::tuple<Tensor&, Tensor&> sort_out_cpu(const Tensor& self,
std::tuple<Tensor&, Tensor&> sort_out(
const Tensor& self,
int64_t dim,
bool descending,
Tensor& values,
Tensor& indices) {
return at::native::sort_out_cpu_stable(
self, /*stable=*/false, dim, descending, values, indices);
return at::sort_out(values, indices, self, false, dim, descending);
}
std::tuple<Tensor, Tensor> sort_cpu_stable(
const Tensor& self,
c10::optional<bool> stable,
int64_t dim,
bool descending) {
TORCH_CHECK(!self.is_complex(), "sort(): input tensor must be of non-complex type");
Tensor values = at::empty({0}, self.options());
Tensor indices = at::empty({0}, self.options().dtype(kLong));
return at::native::sort_out_cpu_stable(self, stable, dim, descending, values, indices);
}
std::tuple<Tensor, Tensor> sort_cpu(
std::tuple<Tensor, Tensor> sort(
const Tensor& self,
int64_t dim,
bool descending) {
return sort_cpu_stable(self, /*stable=*/false, dim, descending);
return at::sort(self, false, dim, descending);
}
Tensor& msort_out(const Tensor& self, Tensor& values) {

View File

@@ -18,7 +18,7 @@ enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
NEAREST
};
using sort_fn = void(*)(const TensorBase &values, const TensorBase &indices, int64_t dim, bool descending, bool stable);
using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);
DECLARE_DISPATCH(sort_fn, sort_stub);

View File

@@ -22,12 +22,6 @@ void _dim_apply(
int64_t dim,
const std::string& method_name,
const func_t& f) {
dim = maybe_wrap_dim(dim, values.dim());
TORCH_CHECK(
dim >= 0 && dim < values.dim(),
method_name, "(): invalid dimension parameter ", dim
);
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.resize_outputs(false)
@@ -90,8 +84,9 @@ struct KeyValueCompDesc {
};
static void sort_kernel(
const TensorBase &values,
const TensorBase &indices,
const TensorBase& self,
const TensorBase& values,
const TensorBase& indices,
int64_t dim,
bool descending,
bool stable) {

View File

@@ -5,6 +5,7 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/Resize.h>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -38,7 +39,7 @@ bool should_use_small_sort(const TensorBase &self, int64_t dim) {
std::vector<int64_t> infer_dense_strides_dim_last(const Tensor & self, int64_t dim);
void fillSliceWithIndex(Tensor& t,int dim) {
void fillSliceWithIndex(const Tensor& t, int dim) {
if (t.numel()) {
auto sizes = DimVector(t.dim(), 1);
sizes[dim] = t.sizes()[dim];
@@ -51,18 +52,28 @@ void fillSliceWithIndex(Tensor& t,int dim) {
// We perform a segmented sort in cub with inputs that have
// more than 1024/2048 elements along the selected dimension.
// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices) {
void sort_cuda_kernel(
const TensorBase& self_base,
const TensorBase& values_base,
const TensorBase& indices_base,
int64_t dim,
bool descending,
bool stable) {
// this algorithm is always stable
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
TensorArg self_arg{self, "self", 1}, values_arg{values, "values", 2}, indices_arg{indices, "indices", 3};
checkAllSameGPU(__func__, {self_arg, values_arg, indices_arg});
bool is_non_overlapping_and_dense = self.is_non_overlapping_and_dense();
int64_t ndim = self.dim();
dim = maybe_wrap_dim(dim, ndim);
int64_t nsort = self.sizes()[dim];
// Macro for converting `TensorBase` -> `Tensor` without
// reference count bumps.
#define TOTENSOR(BASE, VAR) \
OptionalTensorRef opt_##BASE(BASE); \
const Tensor& VAR = *opt_##BASE;
TORCH_CHECK(nsort <= std::numeric_limits<int>::max(),
// Converting TensorBase into Tensor.
// We will need Tensor's methods from this point onwards.
TOTENSOR(self_base, self);
TOTENSOR(values_base, values);
TOTENSOR(indices_base, indices);
TORCH_CHECK(self.sizes()[dim] <= std::numeric_limits<int>::max(),
"The dimension being sorted can not have more than INT_MAX elements.");
const auto self_dtype = self.dtype();
@@ -72,37 +83,9 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
"Sort currently does not support complex dtypes on CUDA.");
if (ndim == 0) {
if (!values.defined()) {
values = self.clone();
} else {
values.resize_as_(self);
values.copy_(self);
}
if (!indices.defined()) {
indices = at::zeros({}, self.options().dtype(kLong));
} else {
indices.resize_as_(self);
indices.zero_();
}
return std::forward_as_tuple(values, indices);
}
// use inplace algorithm for smaller input sizes without stable=True
if (should_use_small_sort(self, dim) && !stable.value()) {
if (should_use_small_sort(self, dim) && !stable) {
// from thc: sorted->values, indices->indices, input->self
if (!values.defined()) {
values = at::empty_like(self);
}
if (!indices.defined()) {
indices = at::empty_like(self, self.options().dtype(kLong));
}
// Make sure sufficient output space is allocated
auto self_size = self.sizes();
at::native::resize_output(values, self_size);
at::native::resize_output(indices, self_size);
fillSliceWithIndex(indices, dim);
// We sort k/v pairs in-place; copy unsorted input to output
@@ -111,12 +94,12 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
// Sort using our in-place k/v kernel that supports arbitrary
// layout
sortKeyValueInplace(values, indices, dim, descending);
return std::forward_as_tuple(values, indices);
return;
}
Tensor self_;
bool newself = false;
if (is_non_overlapping_and_dense && self.stride(dim) == 1) {
if (self.is_non_overlapping_and_dense() && self.stride(dim) == 1) {
self_ = self;
} else {
auto new_strides_unsort = infer_dense_strides_dim_last(self, dim);
@@ -126,19 +109,6 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
}
c10::MaybeOwned<Tensor> values_tmp, indices_tmp;
if (!values.defined()) {
if (is_non_overlapping_and_dense) {
values = at::empty_strided(self.sizes(), self.strides(), self.options());
} else {
auto strides = at::infer_dense_strides(self.sizes(), self.strides());
values = at::empty_strided(self.sizes(), strides, self.options());
}
} else {
TORCH_CHECK(self_.scalar_type() == values.scalar_type(),
"Unexpected dtype for values, expect ", self_.scalar_type(), ", got ", values.scalar_type());
values.resize_as_(self);
}
if (values.strides() == self_.strides() && (newself || get_overlap_status(self, values) == MemOverlapStatus::NO)) {
values_tmp = c10::MaybeOwned<Tensor>::borrowed(values);
} else {
@@ -146,18 +116,6 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
at::empty_strided(self_.sizes(), self_.strides(), self_.options()));
}
if (!indices.defined()) {
if (is_non_overlapping_and_dense) {
indices = at::empty_strided(self.sizes(), self.strides(), self.options().dtype(kLong));
} else {
auto strides = at::infer_dense_strides(self.sizes(), self.strides());
indices = at::empty_strided(self.sizes(), strides, self.options().dtype(kLong));
}
} else {
TORCH_CHECK(kLong == indices.scalar_type(),
"Unexpected dtype for values, expect torch.long, got ", indices.scalar_type());
indices.resize_as_(self);
}
if (indices.strides() != self_.strides()) {
indices_tmp = c10::MaybeOwned<Tensor>::owned(
at::empty_strided(self_.sizes(), self_.strides(), self_.options().dtype(kLong)));
@@ -173,20 +131,11 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
if (!indices_tmp->is_same(indices)) {
indices.copy_(*indices_tmp);
}
return std::forward_as_tuple(values, indices);
}
std::tuple<Tensor &,Tensor &> sort_out_cuda(const Tensor & self, int64_t dim, bool descending, Tensor & values, Tensor & indices) {
return sort_out_stable_cuda(self, /*stable=*/false, dim, descending, values, indices);
}
std::tuple<Tensor,Tensor> sort_stable_cuda(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending) {
Tensor values, indices;
return sort_out_stable_cuda(self, stable, dim, descending, values, indices);
}
std::tuple<Tensor,Tensor> sort_cuda(const Tensor & self, int64_t dim, bool descending) {
return sort_stable_cuda(self, /*stable=*/false, dim, descending);
}
// TODO: we should handle this accordingly when we start using REGISTER_HIP_DISPATCH,
// since REGISTER_DISPATCH won't work in this cpp file.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CUDA_DISPATCH(sort_stub, &sort_cuda_kernel);
}} // namespace at::native

View File

@@ -10,9 +10,10 @@
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/CUDAFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/sort_cuda_dispatch.h>
#include <ATen/ops/topk_native.h>
#endif
@@ -26,7 +27,7 @@ void topk_out_with_sort(
const Tensor& indices
) {
Tensor sorted_values, sorted_indices;
std::tie(sorted_values, sorted_indices) = at::native::sort_cuda(self, dim, largest);
std::tie(sorted_values, sorted_indices) = at::cuda::sort(self, /* stable= */false, dim, largest);
values.copy_(sorted_values.narrow(dim, 0, k));
indices.copy_(sorted_indices.narrow(dim, 0, k));
}
@@ -83,7 +84,7 @@ TORCH_IMPL_FUNC(topk_out_cuda)
Tensor sortedIndices = at::empty_like(indices);
Tensor sortedValues = at::empty_like(values);
sort_out_cuda(values, dim, largest, sortedValues, sortedIndices);
at::cuda::sort_outf(values, /* stable= */ false, dim, largest, sortedValues, sortedIndices);
indices.copy_(indices.gather(dim, sortedIndices));
values.copy_(sortedValues);
}

View File

@@ -7755,27 +7755,23 @@
- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
dispatch:
CPU: sort_out_cpu
CUDA: sort_out_cuda
CompositeExplicitAutograd: sort_out
- func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
structured: True
dispatch:
CPU: sort_out_cpu_stable
CUDA: sort_out_stable_cuda
CPU, CUDA: sort_stable_out
- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: method, function
dispatch:
CPU: sort_cpu
CUDA: sort_cuda
QuantizedCPU: sort_quantized_cpu
CompositeExplicitAutograd: sort
- func: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
structured_delegate: sort.values_stable
variants: method, function
dispatch:
CPU: sort_cpu_stable
CUDA: sort_stable_cuda
QuantizedCPU: sort_quantized_cpu_stable
- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)

View File

@@ -35,12 +35,5 @@ std::tuple<Tensor, Tensor> sort_quantized_cpu_stable(
sort_indicies);
}
std::tuple<Tensor, Tensor> sort_quantized_cpu(
const Tensor& self,
int64_t dim,
bool descending) {
return sort_quantized_cpu_stable(self, /*stable=*/false, dim, descending);
}
} // namespace native
} // namespace at

View File

@@ -163,6 +163,23 @@ class TestSortAndSelect(TestCase):
self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
self.assertEqual(im, t0.sort().indices)
@dtypes(torch.float32)
def test_sort_restride(self, device, dtype):
# Input: non-contiguous (stride: 5) 3-element array
tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
# Outputs: 0-dim tensors
# They will need to be resized, which means they will also be
# restrided with the input tensor's strides as base.
values = torch.tensor(0, dtype=dtype, device=device)
indices = torch.tensor(0, dtype=torch.long, device=device)
torch.sort(tensor, out=(values, indices))
# Check: outputs were restrided to dense strides
self.assertEqual(values.stride(), (1,))
self.assertEqual(indices.stride(), (1,))
# Check: 'tensor' indexed by 'indices' is equal to 'values'
self.assertEqual(tensor[indices], values)
def _test_sort_discontiguous(self, device, dtype):
# on CUDA 2048 vs >2048 have different code path for the dim being sorted
sizes = (5, 7, 2049)

View File

@@ -14313,10 +14313,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# sort does not correctly warn when resizing out= inputs
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
# Allows unsafe cast
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
)),
OpInfo('unique',
@@ -14900,11 +14896,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# msort does not correctly warn when resizing out= inputs.
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
# Expected RuntimeError when doing an unsafe cast from a result of dtype
# torch.float32 into an out= with dtype torch.long
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
),
sample_inputs_func=sample_inputs_msort),