YyWangCS 3cc8af2d67 torch.topk: refactor global histogram/cumsum into a dedicated kernel to eliminate redundant memory access (#164459)
# TLDR
This PR removes the `torch.topk` regression introduced in torch 2.7.0 and delivers substantially better performance on large inputs.

The table below reports execution times on an NVIDIA H20 for various float32 input sizes when extracting the top-100 values. The results show that this PR restores the pre-regression performance and improves on it, especially for large inputs.
| Input Shape    | torch 2.6.0 (ms) | torch 2.8.0 (ms) | torch 2.8.0 + this PR (ms) |
| -------------- | ---------------- | ---------------- | -------------------------- |
| (1, 1B)        | 36.6             | 1564.1           | 25.6                       |
| (1, 100M)      | 3.56             | 17.4             | 2.54                       |
| (1, 1M)        | 0.135            | 0.145            | 0.098                      |
| (512, 128000)  | 1.33             | 1.33             | 1.32                       |
| (8192, 128000) | 19.6             | 19.6             | 19.4                       |
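The timings above can also be read as relative speedups. The short Python sketch below copies the table's numbers and computes how much faster the patched build is than torch 2.8.0 and torch 2.6.0. The helper name `speedups` is illustrative only.

```python
# Timings in ms copied from the table: (torch 2.6.0, torch 2.8.0, torch 2.8.0 + this PR)
timings = {
    "(1, 1B)": (36.6, 1564.1, 25.6),
    "(1, 100M)": (3.56, 17.4, 2.54),
    "(1, 1,000,000)": (0.135, 0.145, 0.098),
    "(512, 128000)": (1.33, 1.33, 1.32),
    "(8192, 128000)": (19.6, 19.6, 19.4),
}

def speedups(t26, t28, t_pr):
    """Return (speedup over 2.8.0, speedup over 2.6.0) of the patched build."""
    return t28 / t_pr, t26 / t_pr

for shape, times in timings.items():
    vs28, vs26 = speedups(*times)
    print(f"{shape:>16}: {vs28:6.2f}x vs 2.8.0, {vs26:5.2f}x vs 2.6.0")
```

For the (1, 1B) case this works out to roughly 61x over 2.8.0 and about 1.43x over 2.6.0. For the (512, 128000) and (8192, 128000) rows, the ratios stay close to 1.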

# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**.

Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR #145536](https://github.com/pytorch/pytorch/pull/145536).

# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.
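The two stages above can be illustrated with a small CPU simulation of radix select for the k-th largest float32 value. This is a sketch under stated assumptions, not PyTorch's implementation:

- 8-bit digits per pass (`RADIX_BITS = 8` is an assumption).
- A "block" is simply a contiguous chunk of the input.
- NaN handling is omitted.
- All function names are hypothetical.

Each pass builds one local histogram per simulated block (stage 1), then sums them into a global histogram (stage 2).

```python
import struct
from typing import List, Sequence

RADIX_BITS = 8                      # assumption: 8-bit digits, 4 passes over a 32-bit key
RADIX_SIZE = 1 << RADIX_BITS
RADIX_MASK = RADIX_SIZE - 1

def to_ordered_key(x: float) -> int:
    """Map a float32 to a uint32 whose unsigned order matches float order (NaN ignored)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits ^ 0xFFFFFFFF) if (bits & 0x80000000) else (bits | 0x80000000)

def from_ordered_key(key: int) -> float:
    """Invert to_ordered_key."""
    bits = (key & 0x7FFFFFFF) if (key & 0x80000000) else (key ^ 0xFFFFFFFF)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def local_histogram(chunk: List[int], desired: int, desired_mask: int, shift: int) -> List[int]:
    """Stage 1: one simulated block counts digits of its keys that match the prefix found so far."""
    hist = [0] * RADIX_SIZE
    for key in chunk:
        if (key & desired_mask) == desired:
            hist[(key >> shift) & RADIX_MASK] += 1
    return hist

def kth_largest(values: Sequence[float], k: int, blocks_per_slice: int = 4) -> float:
    """Return the k-th largest value (1-based k) using radix select with two-stage histograms."""
    keys = [to_ordered_key(v) for v in values]
    size = -(-len(keys) // blocks_per_slice)           # ceiling division: keys per block
    chunks = [keys[i:i + size] for i in range(0, len(keys), size)]
    desired, desired_mask, k_left = 0, 0, k
    for shift in range(32 - RADIX_BITS, -1, -RADIX_BITS):
        # Stage 1: per-block local histograms (written to global memory on a GPU).
        local = [local_histogram(c, desired, desired_mask, shift) for c in chunks]
        # Stage 2: reduce all local histograms of this slice into the global histogram.
        global_hist = [sum(col) for col in zip(*local)]
        # Scan digits from largest to smallest to find the bucket holding the k-th largest key.
        for digit in range(RADIX_MASK, -1, -1):
            if global_hist[digit] >= k_left:
                desired |= digit << shift
                desired_mask |= RADIX_MASK << shift
                break
            k_left -= global_hist[digit]
    return from_ordered_key(desired)
```

For example, `kth_largest([3.25, -2.0, 7.0, 0.5], 2)` selects `3.25`. Each pass narrows the candidate set to one digit bucket, so after four 8-bit passes the full 32-bit key of the k-th largest value is known.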

Before [PR #145536](https://github.com/pytorch/pytorch/pull/145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.

In PR #145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), every one of its `blocks_per_slice` CUDA blocks reloads the same local histograms from global memory. As a result, each slice's local histograms are read `blocks_per_slice` times instead of once.
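The cost of this redundancy can be estimated with simple arithmetic. The sketch below counts local-histogram entries loaded from global memory in one radix pass. It assumes each of a slice's `blocks_per_slice` blocks writes one `radix_size`-entry histogram. The function name and the example sizes are hypothetical and do not reflect what PyTorch actually launches for a given input.

```python
def histogram_reads(num_slices: int, blocks_per_slice: int, radix_size: int,
                    per_block_reduction: bool) -> int:
    """Count local-histogram entries read from global memory in one radix pass.

    per_block_reduction=True models every block of a slice reducing all of that
    slice's local histograms itself; False models a single reduction per slice.
    """
    entries_per_slice = blocks_per_slice * radix_size        # all local histograms of one slice
    readers_per_slice = blocks_per_slice if per_block_reduction else 1
    return num_slices * readers_per_slice * entries_per_slice

# Hypothetical block counts, for illustration only.
for bps in (16, 256, 1024):
    old = histogram_reads(1, bps, 256, per_block_reduction=True)
    new = histogram_reads(1, bps, 256, per_block_reduction=False)
    print(f"blocks_per_slice={bps:5d}: {old:>12,d} vs {new:>9,d} entries ({old // new}x)")
```

Under this model, the redundant read volume grows quadratically with `blocks_per_slice`, while a single per-slice reduction grows linearly. This is consistent with the regression being largest for very long single-row inputs such as (1, 1B).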

# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads.
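The dedicated reduction pass can be modeled on the CPU as follows. This is a sketch, not the CUDA kernel. The Python functions, the nested-list layout `local_hists[slice][block][digit]`, and the choice of an inclusive prefix sum over ascending digits are all assumptions made for illustration. The point is that each local histogram is read exactly once per slice. Downstream per-block work then needs only the small cumsum array. The loop unrolling the PR applies to the histogram reads on the GPU has no meaningful Python equivalent and is not modeled.

```python
from typing import List, Tuple

def compute_digit_cumsum(local_hists: List[List[List[int]]]) -> List[List[int]]:
    """Hypothetical model of a per-slice reduction + cumsum pass.

    local_hists[s][b][d] is the count of digit d seen by block b of slice s.
    Returns cumsum[s][d] = number of keys in slice s whose digit is <= d.
    """
    result = []
    for slice_hists in local_hists:
        running, cumsum = 0, []
        for d in range(len(slice_hists[0])):
            # Reduce digit d across all blocks of this slice: each entry is read once.
            running += sum(block[d] for block in slice_hists)
            cumsum.append(running)
        result.append(cumsum)
    return result

def select_digit(cumsum: List[int], k: int) -> Tuple[int, int]:
    """Return (digit bucket holding the k-th largest key, rank of that key inside the bucket)."""
    total = cumsum[-1]
    for d in range(len(cumsum) - 1, -1, -1):
        below = cumsum[d - 1] if d > 0 else 0        # keys with a digit smaller than d
        if total - below >= k:                       # keys with a digit >= d
            return d, k - (total - cumsum[d])        # subtract keys with a digit > d
    raise ValueError("k exceeds the number of keys in the slice")
```

For one slice with two blocks and 4-entry histograms `[1, 0, 2, 0]` and `[0, 3, 1, 1]`, the cumsum is `[1, 4, 7, 8]`. `select_digit([1, 4, 7, 8], 2)` returns `(2, 1)`: one key lies in bucket 3, so the 2nd largest key is the largest key in bucket 2.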

# Performance
We benchmarked `torch.topk` on an NVIDIA H20 with float32 inputs, extracting the top-100 values for the input sizes listed in the TLDR table. The results demonstrate that this PR eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

In addition, I verified the correctness of this PR across a range of inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
2025-10-07 11:04:03 +00:00