Commit Graph

21 Commits

Author SHA1 Message Date
78f5a1ec60 varlen api (#164502)
**Summary**

Today, the only way to get variable-sequence-length support in PyTorch attention is through nested tensors ([see the SDPA tutorial](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support)). We also want an explicit lower-level API that provides variable-sequence-length support in SDPA without padding/masking.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
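
For context, here is a rough sketch of the two input layouts being compared: the packed, padding-free layout that a varlen-style attention call consumes, versus the padded layout used with today's SDPA. All names below are local to this sketch; the actual `varlen_attn` signature is the one introduced in this PR and is not reproduced here.

``` Python
# Illustrative only: packed ("varlen") layout with cumulative sequence offsets
# vs. the dense padded layout. Shapes and names are placeholders.
import torch

num_heads, head_dim, max_seq_len = 16, 64, 2048
seq_lens = torch.tensor([128, 640, 1536, 2048])  # one length per sequence in the batch

# Packed layout: one (total_tokens, H, D) tensor plus cumulative offsets, so no
# padding is ever materialized or processed.
q_packed = torch.cat([torch.randn(int(s), num_heads, head_dim) for s in seq_lens])
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = seq_lens.cumsum(0)

# Padded layout: a dense (B, H, max_seq_len, D) tensor where shorter sequences
# carry zero padding that still has to be processed (or masked out).
q_padded = torch.zeros(len(seq_lens), num_heads, max_seq_len, head_dim)
for i in range(len(seq_lens)):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    q_padded[i, :, : end - start] = q_packed[start:end].transpose(0, 1)

print(q_packed.shape, cu_seqlens.tolist(), q_padded.shape)
```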

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.2175 ms          | 0.4317 ms |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453, which matches the roughly 50% speedup we get from varlen. TFLOPs remain in the same range, with SDPA's somewhat higher, potentially due to extra overhead and its total FLOPs scaling with the padded sequence length.
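
For reference, the sparsity figure can be estimated directly from the settings above. The sampling below is just one plausible way to draw "random multiples of 64", so the value will not be exactly 0.453 but should land in the same ballpark.

``` Python
# Sketch: estimate padding "sparsity" for batch_size=8, max_seq_len=2048 with
# per-sequence lengths drawn as random multiples of 64 (illustrative sampling).
import torch

torch.manual_seed(0)
batch_size, max_seq_len = 8, 2048
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64

sparsity = 1.0 - seq_lens.sum().item() / (batch_size * max_seq_len)
print(f"sparsity ~= {sparsity:.3f}")  # fraction of padded positions, ~0.5 in expectation
```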

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm that varlen outputs numerically match SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.
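
For readers unfamiliar with the custom-op step, below is a minimal sketch of what registering a function through `torch.library.custom_op` generally looks like. The op name, signature, and math-only body are placeholders for illustration, not the ones this stack will use.

``` Python
# Hedged sketch of custom-op registration; "mylib::toy_varlen_attn" is a placeholder.
import torch
from torch.library import custom_op

@custom_op("mylib::toy_varlen_attn", mutates_args=())
def toy_varlen_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Placeholder math-only body; a real registration would call the fused kernel.
    return torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1) @ v

@toy_varlen_attn.register_fake
def _(q, k, v):
    # Shape/dtype propagation so torch.compile can trace through the op.
    return torch.empty_like(q)

q = k = v = torch.randn(2, 4, 8, 16)
print(torch.ops.mylib.toy_varlen_attn(q, k, v).shape)
```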

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
2025-10-15 19:45:55 +00:00
3044e1a460 Revert "varlen api (#164502)"
This reverts commit 3681312ce03e425e280a110df2153db107616a15.

Reverted https://github.com/pytorch/pytorch/pull/164502 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the doctests failure is legit ([comment](https://github.com/pytorch/pytorch/pull/164502#issuecomment-3404419420))
2025-10-15 03:56:42 +00:00
3681312ce0 varlen api (#164502)
**Summary**

Today, the only way to get variable-sequence-length support in PyTorch attention is through nested tensors ([see the SDPA tutorial](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support)). We also want an explicit lower-level API that provides variable-sequence-length support in SDPA without padding/masking.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.2175 ms          | 0.4317 ms |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453, which matches the roughly 50% speedup we get from varlen. TFLOPs remain in the same range, with SDPA's somewhat higher, potentially due to extra overhead and its total FLOPs scaling with the padded sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm that varlen outputs numerically match SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
2025-10-15 00:45:06 +00:00
db763b1717 [Intel GPU] Support SDPA backend selection and priority setting on XPU (#159464)
Currently, SDPA on XPU uses its own `priority_order` instead of the one from the global context, so it does not support `with sdpa_kernel(order, set_priority=True)`.

This PR enables that feature. To make the default `priority_order` from the global context work for XPU, I also moved the MATH backend to the lowest priority; otherwise `cudnn attention` and `overrideable attention` would never be selected.
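
A minimal usage sketch of what this enables (assumes an XPU build of PyTorch; the particular order below is only an example):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Requires an Intel GPU build; shapes and order are illustrative.
q = torch.randn(2, 8, 512, 64, dtype=torch.half, device="xpu")

order = [SDPBackend.OVERRIDEABLE, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
# With set_priority=True the list is treated as a priority order; after this PR
# the XPU backend selection reads it from the global context like other devices.
with sdpa_kernel(order, set_priority=True):
    out = F.scaled_dot_product_attention(q, q, q)
```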

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159464
Approved by: https://github.com/guangyey, https://github.com/drisspg

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: mayuyuace <qiming1.zhang@intel.com>
2025-08-14 08:55:31 +00:00
ba0d0de5e6 Enable set SDPA backend by torch.nn.attention.sdpa_kernel on XPU (#156669)
Introduces support for a new `OVERRIDEABLE` backend in the SDPA module, improves the backend selection logic, and adds corresponding tests. In addition, a fallback mechanism was added for cases where a specific backend is unavailable, enhancing user configurability.

### Backend Support and Selection Enhancements:
* Added `at::SDPBackend::overrideable` to the list of available SDPA backends in the `Context` class (`aten/src/ATen/Context.h`).
* Updated the backend selection logic in `select_sdp_backend_xpu` to include the `OVERRIDEABLE` backend and added a fallback mechanism for unsupported `FLASH_ATTENTION` on XPU.
* Adjusted error messaging in `_fused_sdp_choice_xpu` to reflect the inclusion of the `OVERRIDEABLE` backend. (`aten/src/ATen/native/mkldnn/xpu/Attention.cpp`)

### Test Additions for Backend Fallback and Selection:
* Added new unit tests to validate fallback behavior from `FLASH_ATTENTION` to `OVERRIDEABLE` and to verify correct backend selection when `MATH` is enabled. (`test/test_transformers.py`)

### Codebase Updates for Backend Integration:
* Introduced `OVERRIDEABLE` as a new member of the `_SDPBackend` enum. (`torch/_C/__init__.pyi.in`)
* Extended `_backend_names` and updated related methods to handle the `OVERRIDEABLE` backend. (`torch/nn/attention/__init__.py`)
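
As a rough illustration of the fallback described above (assumes an XPU build; the backend actually chosen depends on the build and inputs):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# On XPU, requesting FLASH_ATTENTION (which has no XPU implementation) is
# expected to fall back to the OVERRIDEABLE backend rather than error out.
q = torch.randn(2, 8, 256, 64, dtype=torch.half, device="xpu")
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, q, q)
```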

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156669
Approved by: https://github.com/guangyey, https://github.com/drisspg
2025-07-10 06:52:22 +00:00
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
0dcd482e54 [SDPA] Respect sdpa_kernel's priority_order setting in torch.compile (#147768)
[https://github.com/pytorch/pytorch/pull/140467](https://github.com/pytorch/pytorch/pull/140467) added the option to specify a priority order for SDPA, but the `torch.compile` path silently ignored this setting, as I wasn't aware of the separate context-manager handling under `torch.compile`.
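
A sketch of the now-working pattern, with an arbitrary example order (requires a CUDA device):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

@torch.compile
def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(4, 8, 1024, 64, dtype=torch.half, device="cuda")
order = [SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
# Before this fix, the compiled path silently ignored set_priority=True.
with sdpa_kernel(order, set_priority=True):
    out = attn(q, q, q)
```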

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147768
Approved by: https://github.com/drisspg
2025-03-13 18:52:34 +00:00
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
cyy
d87aad6877 [5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
2025-01-15 04:00:47 +00:00
eqy
8fc6d3a5d8 [SDPA] Allow user-specified priority order with context manager (#140467)
TODO: docs changes?
For better debuggability of issues like https://github.com/pytorch/pytorch/issues/139298

Better testing, current sketch:

``` Python
import time

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(64, 1024, 8, 64, dtype=torch.half, device='cuda')
print(torch._C._get_sdp_priority_order())

orders = [[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]]

times = []
for order in orders:
    print(order)
    # Warmup call so kernel selection/compilation is not included in the timing.
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    # Timed call under the same priority order.
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
print(times)
# Relative timings should reflect whichever highest-priority backend actually ran.
assert times[0] < times[1]
assert times[0] > times[2]
assert times[1] > times[2]
print(torch._C._get_sdp_priority_order())
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140467
Approved by: https://github.com/drisspg
2024-12-06 07:56:35 +00:00
be4b7e8131 Param fixes in docstring (#136097)
Fixes wrong param names in docstrings. cc: @kit1980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136097
Approved by: https://github.com/ezyang
2024-09-21 18:56:34 +00:00
63d6cd351a [dynamo] support torch.nn.attention.sdpa_kernel context manager (#135404)
Fixes https://github.com/pytorch/pytorch/issues/134608
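
A minimal sketch of what this enables: using the context manager inside a compiled function rather than only around the compiled call (the backend choice here is arbitrary):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

@torch.compile
def attn_math_only(q, k, v):
    # Dynamo now traces through sdpa_kernel instead of rejecting it.
    with sdpa_kernel([SDPBackend.MATH]):
        return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(2, 4, 128, 64)
out = attn_math_only(q, q, q)
```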

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135404
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-09-12 22:04:48 +00:00
b4b62d3945 update to 2.5.8 (#131684)
# Summary
This stack brings the current fork of FAv2 near the tip of main, which is 2.6.2.

Notably, we need to update CUTLASS to 3.5.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131684
Approved by: https://github.com/jainapurva
2024-07-25 23:15:03 +00:00
dff6342a0b [BE][Easy] enable UFMT for torch/nn/parallel (#128596)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128596
Approved by: https://github.com/mikaylagawarecki
2024-06-17 16:29:22 +00:00
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
04877dc430 Update context manager for cudnn (#126122)
# Summary
Updates the context manager to support the cuDNN backend.

These results were collected using CUDA Toolkit 12.3 and cuDNN 9.0.0.

## H100 Numbers
 _power limited_
 ``` Markdown
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|   Batch Size |   Sequence Length |   Heads |   Head Dim |   Flash Time (µs) |   CUDNN Time (µs) |   Speedup (CUDNN/Flash) |
+==============+===================+=========+============+===================+===================+=========================+
|            1 |              4096 |      32 |         64 |           665.053 |           498.59  |                 1.33387 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |              4096 |      16 |        128 |           591.225 |           323.828 |                 1.82574 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |              8192 |      32 |         64 |          2579.77  |          1933.34  |                 1.33436 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |              8192 |      16 |        128 |          2297.4   |          1211.33  |                 1.89659 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             16384 |      32 |         64 |         10178.2   |          7619.18  |                 1.33587 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             16384 |      16 |        128 |          9093.51  |          4725.03  |                 1.92454 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             32768 |      32 |         64 |         39893.1   |         29850.6   |                 1.33643 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             32768 |      16 |        128 |         36160.9   |         18615.9   |                 1.94247 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             65536 |      32 |         64 |        157965     |        116794     |                 1.35251 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |             65536 |      16 |        128 |        142039     |         73102.1   |                 1.94303 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |            131072 |      32 |         64 |        621100     |        465143     |                 1.33529 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
|            1 |            131072 |      16 |        128 |        556142     |        289776     |                 1.91922 |
+--------------+-------------------+---------+------------+-------------------+-------------------+-------------------------+
```

## A100 Numbers
```Markdown
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|   Batch Size |   Sequence Length |   Heads |   Head Dim |   Flash Time (µs) |   CUDNN Time (µs) |   Flex Time (µs) |   Speedup (CUDNN/Flash) |   Speedup (Flex/Flash) |
+==============+===================+=========+============+===================+===================+==================+=========================+========================+
|            1 |              4096 |      32 |         64 |           799.391 |           836.327 |          981.234 |                0.955836 |               0.814679 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |              4096 |      16 |        128 |           750.131 |           806.964 |          944.766 |                0.929572 |               0.793986 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |              8192 |      32 |         64 |          3211.84  |          3234.41  |         3803.09  |                0.993022 |               0.844534 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |              8192 |      16 |        128 |          2984.2   |          3164.66  |         3626.79  |                0.942979 |               0.822821 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             16384 |      32 |         64 |         12630.6   |         12673.1   |        14900.6   |                0.996643 |               0.847653 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             16384 |      16 |        128 |         11722.7   |         12499.4   |        13763.5   |                0.937862 |               0.851725 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             32768 |      32 |         64 |         50068.3   |         51061.2   |        60094     |                0.980556 |               0.833167 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             32768 |      16 |        128 |         46283.6   |         49708.7   |        55336.7   |                0.931096 |               0.836399 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             65536 |      32 |         64 |        203124     |        203083     |       239618     |                1.0002   |               0.847701 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |             65536 |      16 |        128 |        187326     |        198364     |       221912     |                0.944355 |               0.844145 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |            131072 |      32 |         64 |        816813     |        827419     |       978836     |                0.987182 |               0.834473 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
|            1 |            131072 |      16 |        128 |        749693     |        845463     |       905696     |                0.886725 |               0.827754 |
+--------------+-------------------+---------+------------+-------------------+-------------------+------------------+-------------------------+------------------------+
```

## Script
``` Python
import os
import torch
from typing import Callable
from torch.nn.attention import SDPBackend, sdpa_kernel
from itertools import product
from tqdm import tqdm
from tabulate import tabulate

os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

causal = False

from triton.testing import do_bench
from torch.nn.functional import scaled_dot_product_attention as sdpa

def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    return do_bench(lambda: func(*args, **kwargs)) * 1e3

def run_attention_test(backend_name, backend_enum):
    results = []
    batch_sizes = [1]
    seq_lengths = [4096, 8192, 16384, 32768, 65536, 131072]

    torch.cuda.empty_cache()
    for b, s in tqdm(product(batch_sizes, seq_lengths), total=len(batch_sizes) * len(seq_lengths), desc=backend_name):
        for h, d in zip((32, 16), (64, 128)):
            q, k, v = torch.randn(
                b, s, h * d * 3, dtype=torch.bfloat16, device="cuda", requires_grad=False
            ).chunk(3, dim=-1)
            q = q.view(b, -1, h, d).transpose(1, 2)
            k = k.view(b, -1, h, d).transpose(1, 2)
            v = v.view(b, -1, h, d).transpose(1, 2)
            with torch.no_grad(), sdpa_kernel(backend_enum):
                time = benchmark_torch_function_in_microseconds(sdpa, q, k, v, is_causal=False)
            results.append((backend_name, b, s, h, d, time))
    return results

flash_results = run_attention_test("Flash Attention", SDPBackend.FLASH_ATTENTION)
cudnn_results = run_attention_test("CUDNN Attention", SDPBackend.CUDNN_ATTENTION)

# Combine results for comparison
combined_results = []
for flash, cudnn in zip(flash_results, cudnn_results):
    speedup = flash[5] / cudnn[5]
    combined_results.append(
        (flash[1], flash[2], flash[3], flash[4], flash[5], cudnn[5], speedup)
    )

# Tabulate the results
headers = [
    "Batch Size",
    "Sequence Length",
    "Heads",
    "Head Dim",
    "Flash Time (s)",
    "CUDNN Time (s)",
    "Speedup (CUDNN/Flash)",
]
table = tabulate(combined_results, headers, tablefmt="grid")
print(table)

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126122
Approved by: https://github.com/cpuhrsch
2024-05-14 03:34:19 +00:00
edad82fc90 Add private helper for determining which version of FA2 closest matches kernel version (#123653)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123653
Approved by: https://github.com/mikaylagawarecki
2024-05-02 21:28:23 +00:00
f5391dad82 Update docs to point to new sdpa_kernel context manager (#121180)
# Summary

Updates the SDPA docs to fix some small inaccuracies and point to the new sdpa_kernel context manager. The enum-like type bound from C++ (SDPBackend) does not render its fields for some reason; manually list them instead for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121180
Approved by: https://github.com/mikaylagawarecki
2024-03-05 22:19:48 +00:00
8a8e70477e Fix type hints on nn.attention.sdpa_kernel (#119140)
Fixes #119133
Altered type hint and assert to include SDPBackend; disallowed None in assert.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119140
Approved by: https://github.com/mikaylagawarecki, https://github.com/cpuhrsch, https://github.com/drisspg
2024-02-06 07:33:22 +00:00
4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`, which also changes the API for selecting backends.

For `sdp_kernel`, one specified the backend choice by negation: disable the kernels you do not want to run. The purpose of this backend manager was only to be a debugging tool: "turn off the math backend" and see whether one of the fused implementations can run.
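
To make the contrast concrete, a small before/after sketch (the fused backends assume a CUDA device; `sdp_kernel` is shown as it existed at the time of this PR):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 512, 64, dtype=torch.half, device="cuda")

# Old (deprecated): express the choice by negation -- disable everything but flash.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out_old = F.scaled_dot_product_attention(q, q, q)

# New: name the backend(s) you actually want to run.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out_new = F.scaled_dot_product_attention(q, q, q)
```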

Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users who reach for this context manager are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cuDNN backend, and this API means that ever more implementations have to be turned off whenever a user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in `torch.backends.cuda` when it now also controls the CPU fused-kernel behavior. I think centralizing it in the attention namespace will be helpful.

Other concerns:
- Typically, backends (kernels) for operators are entirely hidden from users as implementation details of the framework. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit lead to problems when we later want to remove existing backends (perhaps input shapes will get covered by newer backends)?

A nice side effect: now that we aren't using the `BACKEND_MAP` in test_transformers, many, many dynamo failures pass for the CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00
d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that lets users apply causal masking without actually creating the "realized" attention bias and passing it into SDPA. We originally added this flag since both fused kernels support it natively. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can also provide very large memory savings.

The flag was introduced early in the kernels' development, and at the time it implicitly meant "upper_left" causal attention. The distinction only matters when the attention bias is not square; for a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately matching it would be a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.
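
A small sketch of the distinction for a non-square bias, using the factory helpers from `torch.nn.attention.bias` (shapes are arbitrary):

``` Python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch, heads, head_dim = 2, 8, 64
q_len, kv_len = 256, 416  # non-square: this is where the two variants differ
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# For square biases the two coincide; when q_len != kv_len they differ in where
# the causal diagonal is anchored (top-left corner vs. bottom-right corner).
out_ul = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_upper_left(q_len, kv_len))
out_lr = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_lower_right(q_len, kv_len))
print(out_ul.shape, out_lr.shape)
```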

The larger theme, though, is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is the first of, hopefully, a few more attention biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.

Results from benchmarking: I improved the meff_attention perf, hence the slightly lower max speedup.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00