Implement topk with sort for some cases (#68632)

Summary:
Benchmark comparing the original implementation with the sort-based implementation (this code should be run on a branch without this patch):
```python
import torch
import timeit

def tune_dtype(f):
    def ret(*args, **kwargs):
        for dtype in [torch.int8, torch.half, torch.float, torch.double]:
            f(*args, **kwargs, dtype=dtype)
    return ret

def tune_slice(f):
    def ret(*args, **kwargs):
        slice = 1
        while slice <= 256:
            f(*args, **kwargs, slice=slice)
            slice *= 2
    return ret

def tune_slice_size(f):
    def ret(*args, **kwargs):
        slice_size = 1
        while slice_size <= 1_000_000:
            f(*args, **kwargs, slice_size=slice_size)
            slice_size *= 10
    return ret

def tune_k(f):
    def ret(*args, slice_size, **kwargs):
        k = 1
        while k <= slice_size:
            f(*args, **kwargs, k=k, slice_size=slice_size)
            k *= 10
    return ret

def topk_with_sort(tensor, k, dim=-1, largest=True):
    values, indices = tensor.sort(dim=dim, descending=largest)
    return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)

def run50sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

def warmup():
    N = 1000000
    for i in range(1, N // 10000):
        torch.randn(i, device='cuda')

def benchmark_one(slice, slice_size, k, dtype):
    input_ = torch.empty((slice, slice_size), dtype=dtype, device="cuda").random_()
    torch.cuda.synchronize()
    time = timeit.timeit(lambda: run50sync(lambda: torch.topk(input_, k, dim=1)), number=1)
    torch.cuda.synchronize()
    time_sort = timeit.timeit(lambda: run50sync(lambda: topk_with_sort(input_, k, dim=1)), number=1)
    method = "orig" if time < time_sort else "sort"
    speedup = time / time_sort
    print(f"(dtype={dtype}, slice={slice}, slice_size={slice_size}, k={k}) -> (method={method}, speedup={speedup})")

if __name__ == "__main__":
    warmup()
    tune_dtype(tune_slice(tune_slice_size(tune_k(benchmark_one))))()

```
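As a quick sanity check (not part of the original script), the sort-based helper above should produce the same result as `torch.topk`; a minimal sketch, assuming a CUDA device and the `topk_with_sort` function defined in the benchmark:
```python
x = torch.randn(4, 1000, device='cuda')
v_ref, i_ref = torch.topk(x, 10, dim=1)        # default sorted=True: values in descending order
v_sort, i_sort = topk_with_sort(x, 10, dim=1)  # sort descending, then narrow to the first k
# Values must agree; indices could only differ on ties, which randn makes
# vanishingly unlikely here.
assert torch.equal(v_ref, v_sort)
assert torch.equal(i_ref, i_sort)
```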
For benchmark results, see the next comment.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68632

Reviewed By: dagitses

Differential Revision: D32566233

Pulled By: ngimel

fbshipit-source-id: f7a508176ef3685b491048c4a6562121c60b8b2a
Author: Xiang Gao
Date: 2021-11-19 17:15:10 -08:00
Committed by: Facebook GitHub Bot
Parent: e554d8b89c
Commit: 95f4cd0ba9
2 changed files with 33 additions and 0 deletions

@@ -7,6 +7,28 @@
namespace at {
namespace native {
void topk_out_with_sort(
const Tensor& self,
int64_t k, int64_t dim, bool largest,
const Tensor& values,
const Tensor& indices
) {
Tensor sorted_values, sorted_indices;
std::tie(sorted_values, sorted_indices) = at::native::sort_cuda(self, dim, largest);
values.copy_(sorted_values.narrow(dim, 0, k));
indices.copy_(sorted_indices.narrow(dim, 0, k));
}
bool should_use_sort(const Tensor& self, int64_t dim) {
// This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
if (self.dim() == 0) return false;
if (self.dtype() == kBool) return false; // Bool is not supported by topk
int64_t slice_size = self.size(dim);
if (slice_size == 0) return false;
int64_t num_slices = self.numel() / slice_size;
return num_slices <= 16 && slice_size >= 100000;
}
TORCH_IMPL_FUNC(topk_out_cuda)
(const Tensor& self,
int64_t k, int64_t dim, bool largest, bool sorted,
@@ -14,8 +36,14 @@ TORCH_IMPL_FUNC(topk_out_cuda)
const Tensor& indices) {
TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3};
checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});
dim = at::maybe_wrap_dim(dim, self);
if (should_use_sort(self, dim)) {
topk_out_with_sort(self, k, dim, largest, values, indices);
return;
}
// If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
if (k == 0) {
return;


@@ -370,6 +370,11 @@ class TestSortAndSelect(TestCase):
k = random.randint(1, testTensor.size(dim))
compare(testTensor, k, dim, dir)
# This tests the code path where, on CUDA, topk is implemented with sort.
t = torch.randn((2, 100000), device=device)
compare(t, 2000, 1, True)
compare(t, 2000, 1, False)
def test_topk_arguments(self, device):
q = torch.randn(10, 2, 10, device=device)
# Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)