# TLDR
This PR removes the regression in torch.topk introduced from torch 2.7.0 and delivers much better performance for large inputs.
The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.
| Input Shape | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B) | 36.6 | 1564.1 | 25.6 |
| (1, 100M) | 3.56 | 17.4 | 2.54 |
| (1, 1000,000) | 0.135 | 0.145 | 0.098 |
| (512, 128000) | 1.33 | 1.33 | 1.32 |
| (8192, 128000) | 19.6 | 19.6 | 19.4 |
# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**.
Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR #145536](https://github.com/pytorch/pytorch/pull/145536).
# Analysis
`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:
1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.
Before [PR #145536](https://github.com/pytorch/pytorch/pull/145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.
In PR #145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:
- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.
# This PR
To address this inefficiency, we introduce the following optimizations:
1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads.
# Performance
We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.
| Input Shape | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B) | 36.6 | 1564.1 | 25.6 |
| (1, 100M) | 3.56 | 17.4 | 2.54 |
| (1, 1000,000) | 0.135 | 0.145 | 0.098 |
| (512, 128000) | 1.33 | 1.33 | 1.32 |
| (8192, 128000) | 19.6 | 19.6 | 19.4 |
Besides, I have verified the correctness of this PR with different inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
This PR applies clang-tidy readability checks to jit sources and all headers in the code base.
`readability-redundant-inline-specifier` is suppressed because it incurs too many changes. `readability-redundant-inline-specifier` is used to detect redundant inline specifiers on function and variable declarations. There are many in-class method definitions that are marked inline.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164652
Approved by: https://github.com/Skylion007
This PR applies clang-tidy readability checks to jit sources and all headers in the code base.
`readability-redundant-inline-specifier` is suppressed because it incurs too many changes. `readability-redundant-inline-specifier` is used to detect redundant inline specifiers on function and variable declarations. There are many in-class method definitions that are marked inline.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164652
Approved by: https://github.com/Skylion007
```bash
DEBUG /data/vllm-community-homes/vllm-user-6/pytorch/aten/src/ATen/cuda/CUDAGraph.h(59): warning #68-D: integer conversion resulted in a change of sign
DEBUG CaptureId_t capture_id_ = -1;
DEBUG ^
DEBUG
DEBUG Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
DEBUG
DEBUG /data/vllm-community-homes/vllm-user-6/pytorch/aten/src/ATen/cuda/CUDAGraph.h(59): warning #68-D: integer conversion resulted in a change of sign
DEBUG CaptureId_t capture_id_ = -1;
DEBUG ^
DEBUG
DEBUG Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
DEBUG
DEBUG /data/vllm-community-homes/vllm-user-6/pytorch/aten/src/ATen/cuda/CUDAGraph.h(59): warning #68-D: integer conversion resulted in a change of sign
DEBUG CaptureId_t capture_id_ = -1;
DEBUG ^
```
Cuda won't use 0 as a capture id, so it is safe to initialize with 0, which also matches the initialization in `pytorch/aten/src/ATen/native/cudnn/RNN.cpp:2362`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163898
Approved by: https://github.com/houseroad
Even though `offset_intragraph_` only tracks RNG consumption within a single graph replay, we have observed that the 32bit storage for these offsets is easy to overshoot, especially for cases with big CUDA graph captures including kernels that are generating a large amount of random numbers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164515
Approved by: https://github.com/eee4017, https://github.com/eqy
Summary:
This diff adds the feature of allocating a large pinned memory segment upfront based on the provided config. This large segment is then used to serve all the small pinned memory requests to avoid expensive device level APIs (slow paths).
Example:
PYTORCH_CUDA_ALLOC_CONF=pinned_reserve_segment_size_mb:2048
This reserves a 2GB pinned memory segment for the process and then all incoming small requests are just served from this segment and no cudaHostAlloc/cudaHostRegister apis are being called.
Differential Revision: D83779074
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164501
Approved by: https://github.com/yangw-dev
`grad_dtype` is a new attribute on Tensor to control gradient dtype:
- Access/setting is leaf-only.
- grad_dtype is respected when (1) when assigning to .grad, and (2) in the engine after the previous node produces incoming gradients for AccumulateGrad. (See table below for details)
- Not setting grad_dtype preserves the current behavior. Accessing it returns `t.dtype`
- `grad_dtype` cannot be set when there is already a `.grad` present and the dtypes conflict.
| `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine |
|-----------------------|--------------------------|-----------------------------------------|
| **Default (tensor’s dtype)** | `.grad` must match tensor’s dtype | Engine casts incoming grad to tensor’s dtype |
| **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype |
| **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815
Approved by: https://github.com/albanD
Shared memory is allocated by creating a file in /dev/shm (by default) that can run out of space. Pytorch reserves the file size by calling ftruncate() that creates a sparse file, so it succeeds even if sufficient disk space is not available.
This could lead to a situation when a shared memory region is successfully created but a subsequent access to a shared memory page results in SIGBUS due to the disk being full.
Using posix_fallocate() instead of ftruncate() eliminates this problem because the former syscall always allocates space and it returns an error if the disk is full.
Related to https://github.com/pytorch/pytorch/issues/5040
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161910
Approved by: https://github.com/mikaylagawarecki
We do some fixes in pinned memory allocation stats collection and better differentiate between active vs allocated bytes.
Reviewed By: bbus, sayitmemory
Differential Revision: D83162346
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164412
Approved by: https://github.com/mradmila
In #163455 , the `reshape` was not a pure view op.
The `permute` before it created an non-contiguous tensor, which would trigger a data copy during the reshape.
This PR improved the implementation by remove the `urtensor` intermediate tensor completely.
By simply expanding the `xtensor` would achieve the `repeat` effect.
Before this PR, there were two data copies (in `urtensor.copy_` and `urtensor.reshape`).
Now, there is only one data copy in the `.copy_()`.
Reshape would not copy data because it is on a contiguous tensor.
One more note is that we do want at one copy because we want to duplicate the elements for the repeats.
User can inplace modify single elements without afffecting others.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163842
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Summary: This diffs add an API to query expandable segment size for each stream so that we can use this info to warmup the segment in advance, so we dont incur any performance penalty during steady state inference for new CUDA memory allocations.
Differential Revision: D76447308
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163771
Approved by: https://github.com/bbus
Summary:
As titled.
sdpa will select backend based on hardware check, and it fails when exporting with cuda under fake mode on a cuda-less machine.
We guard `at::cuda::is_available()` check before `at::cuda::getCurrentDeviceProperties()` and give warnings.
Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r nn_functional_scaled_dot_product_attention
Differential Revision: D83496154
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164162
Approved by: https://github.com/SherlockNoMad
Before returning a comand buffer, as subsequent calle are very likely to allocate their own encoder, which results in the following runtime error
```
tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1090: failed assertion `A command encoder is already encoding to this command buffer'
```
Added regression test to `test_mps_extension`
Please note, that `torch::mps::get_command_buffer()` should be called with dispatch_queue held, both before and after this change, but many implementations skip that
Fixes https://github.com/pytorch/pytorch/issues/163721
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164093
Approved by: https://github.com/atalman, https://github.com/Skylion007
Summary: Relax stride check for block-wise scaling (1x128, 128x128) when a dimension of the scaling factor is 1. When the scaling tensor has a dimension of size 1, the stride is effectively "meaningless" to PyTorch, i.e. PyTorch decides to replace its stride with a default of `[1, 1]`. However, the old stride check required the stride to match one of the scaling dimensions. Here, we relax the stride check when the effective stride is 1 in order to allow for cases in which `K <= 128` and `N <= 128`.
Test Plan:
```
pytest -s -v test/test_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_block_wise_float32_lhs_block_1_rhs_block_128_cuda 2>&1 | tee ~/personal/stride_check.log
```
Differential Revision: D83023706
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163829
Approved by: https://github.com/lw, https://github.com/eqy