# TLDR

This PR removes the `torch.topk` regression introduced in torch 2.7.0 and delivers much better performance for large inputs. The table below reports execution times on an NVIDIA H20 for various input sizes with float32 data, extracting the top-100 values. The results indicate that this PR restores and improves performance, especially on large inputs.

| Input Shape    | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | ---------------- | ---------------- | -------------------- |
| (1, 1B)        | 36.6             | 1564.1           | 25.6                 |
| (1, 100M)      | 3.56             | 17.4             | 2.54                 |
| (1, 1M)        | 0.135            | 0.145            | 0.098                |
| (512, 128000)  | 1.33             | 1.33             | 1.32                 |
| (8192, 128000) | 19.6             | 19.6             | 19.4                 |

# Background

After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 went from **36 ms** to **1.6 s**. Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR #145536](https://github.com/pytorch/pytorch/pull/145536).

# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before [PR #145536](https://github.com/pytorch/pytorch/pull/145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were complete before the reduction. In PR #145536, the global histogram computation was merged with the subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.

# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor the global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up the local histogram reads.

Simplified CUDA sketches of the original two-stage scheme and of a `computeDigitCumSum`-style kernel are included at the end of this description.

# Performance

We benchmarked `torch.topk` on an NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below show that this PR eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape    | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | ---------------- | ---------------- | -------------------- |
| (1, 1B)        | 36.6             | 1564.1           | 25.6                 |
| (1, 100M)      | 3.56             | 17.4             | 2.54                 |
| (1, 1M)        | 0.135            | 0.145            | 0.098                |
| (512, 128000)  | 1.33             | 1.33             | 1.32                 |
| (8192, 128000) | 19.6             | 19.6             | 19.4                 |

In addition, I have verified the correctness of this PR with a variety of inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
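
To make the two-stage histogram described in the Analysis section concrete, here is a minimal, self-contained CUDA sketch for a single slice. It is illustrative only, not PyTorch's actual code: the kernel names, the 8-bit radix, the `digitOf()` helper, and the counts layout are assumptions made for this example.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

constexpr int RADIX_DIGITS  = 256;   // one 8-bit digit per radix pass
constexpr int BLOCK_THREADS = 256;

// Hypothetical digit extractor: take the 8 bits of the key starting at `bit`.
__device__ __forceinline__ int digitOf(uint32_t key, int bit) {
  return (key >> bit) & (RADIX_DIGITS - 1);
}

// Stage 1: every block histograms its chunk of the slice and writes its local
// histogram to global memory (one RADIX_DIGITS-wide row per block).
__global__ void localHistogram(const uint32_t* keys, int64_t n, int bit,
                               uint32_t* blockCounts /* [gridDim.x][RADIX_DIGITS] */) {
  __shared__ uint32_t smem[RADIX_DIGITS];
  for (int d = threadIdx.x; d < RADIX_DIGITS; d += blockDim.x) smem[d] = 0;
  __syncthreads();

  for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x; i < n;
       i += (int64_t)gridDim.x * blockDim.x) {
    atomicAdd(&smem[digitOf(keys[i], bit)], 1u);
  }
  __syncthreads();

  for (int d = threadIdx.x; d < RADIX_DIGITS; d += blockDim.x) {
    blockCounts[blockIdx.x * RADIX_DIGITS + d] = smem[d];
  }
}

// Stage 2: a single block reads every local histogram once and reduces them
// into the global histogram. If this reduction instead ran inside each of the
// `blocksPerSlice` blocks of a later kernel, the same counters would be
// reloaded from global memory `blocksPerSlice` times per slice, which is the
// redundancy this PR removes.
__global__ void globalHistogram(const uint32_t* blockCounts, int blocksPerSlice,
                                uint32_t* globalCounts /* [RADIX_DIGITS] */) {
  for (int d = threadIdx.x; d < RADIX_DIGITS; d += blockDim.x) {
    uint32_t sum = 0;
    for (int b = 0; b < blocksPerSlice; ++b) {
      sum += blockCounts[b * RADIX_DIGITS + d];
    }
    globalCounts[d] = sum;
  }
}

int main() {
  const int64_t n = 1 << 20;       // one slice of 1M keys
  const int blocksPerSlice = 128;
  const int bit = 24;              // most significant 8-bit digit

  uint32_t *keys, *blockCounts, *globalCounts;
  cudaMalloc(&keys, n * sizeof(uint32_t));
  cudaMalloc(&blockCounts, blocksPerSlice * RADIX_DIGITS * sizeof(uint32_t));
  cudaMalloc(&globalCounts, RADIX_DIGITS * sizeof(uint32_t));
  cudaMemset(keys, 0xAB, n * sizeof(uint32_t));   // every key becomes 0xABABABAB

  localHistogram<<<blocksPerSlice, BLOCK_THREADS>>>(keys, n, bit, blockCounts);
  globalHistogram<<<1, BLOCK_THREADS>>>(blockCounts, blocksPerSlice, globalCounts);

  uint32_t h[RADIX_DIGITS];
  cudaMemcpy(h, globalCounts, sizeof(h), cudaMemcpyDeviceToHost);
  printf("count for digit 0xAB: %u (expected %lld)\n", h[0xAB], (long long)n);

  cudaFree(keys); cudaFree(blockCounts); cudaFree(globalCounts);
  return 0;
}
```

Built with `nvcc`, this prints the full element count for the single digit present in the synthetic input, confirming the two stages agree.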
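And here is a hedged sketch of the direction this PR takes: a dedicated kernel, in the spirit of `computeDigitCumSum`, that reads each local histogram exactly once per slice, unrolls the read loop, and produces the per-digit cumulative sum consumed by the later top-k steps. The signature, the counts layout, and the use of `cub::BlockScan` are assumptions for illustration, not the real kernel.

```cuda
#include <cub/block/block_scan.cuh>
#include <cstdint>

constexpr int RADIX_DIGITS = 256;   // 8-bit radix pass -> 256 digit bins

// Launch with gridDim.x == numInputSlices and blockDim.x == RADIX_DIGITS, so
// each thread owns one digit of one slice.
// counts layout (assumed): [numInputSlices][blocksPerSlice][RADIX_DIGITS]
__global__ void computeDigitCumSumSketch(
    const uint32_t* counts, int blocksPerSlice,
    uint32_t* digitCumSum /* [numInputSlices][RADIX_DIGITS] */) {
  using BlockScan = cub::BlockScan<uint32_t, RADIX_DIGITS>;
  __shared__ typename BlockScan::TempStorage temp;

  const int slice = blockIdx.x;
  const int digit = threadIdx.x;
  const uint32_t* sliceCounts =
      counts + (int64_t)slice * blocksPerSlice * RADIX_DIGITS;

  // Reduce the local histograms for this digit. Every counter is read exactly
  // once per slice; partial unrolling lets the compiler issue independent loads.
  uint32_t sum = 0;
  #pragma unroll 4
  for (int b = 0; b < blocksPerSlice; ++b) {
    sum += sliceCounts[b * RADIX_DIGITS + digit];
  }

  // Inclusive prefix sum across the 256 digits of this slice, so later stages
  // can locate the digit bucket containing the k-th value without touching the
  // per-block histograms again.
  uint32_t cumsum;
  BlockScan(temp).InclusiveSum(sum, cumsum);
  digitCumSum[(int64_t)slice * RADIX_DIGITS + digit] = cumsum;
}
```

Launching one block per slice means the local histograms are traversed once in total, rather than `blocks_per_slice` times per slice as in `computeBlockwiseKthCounts`, which is where the large-input speedup in the tables above comes from.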