Add FA4 to sdpa (#167348)
# Summary
See title ;) In short: make FA4 installable as the implementation behind SDPA's flash-attention backend.

## Design

Currently, once you install an impl there is no going back in the same Python process. This need not be the case (cc @mikaylagawarecki's work on being able to grab the original impl); I'll leave that for a follow-up.
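
For reference, here's a minimal sketch of the install-once semantics described above, using only the `install_flash_attention_impl` entry point from the example usage below:

```Py
from torch.nn.attention import install_flash_attention_impl

# Swap SDPA's flash-attention backend over to FA4 for this process.
# Per the note above, there is currently no way to restore the original
# impl afterwards, so the example below does this once, at module scope.
install_flash_attention_impl("FA4")
```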

Okay, I added an open registration mechanism, but I also really want the backends to be discoverable, so there is some slightly weird typing. The upshot:
<img width="523" height="197" alt="Screenshot 2025-11-07 at 3 30 32 PM" src="https://github.com/user-attachments/assets/586de943-bbed-40cf-abd1-131f747a4cf1" />

## Overheads
<img width="799" height="735" alt="Screenshot 2025-11-07 at 2 35 04 PM" src="https://github.com/user-attachments/assets/f9217f31-3e42-4816-8fb3-29ea8b49d735" />
First call to forward: the majority of the time is spent in JIT compilation for FA.

First call to backward is ~3 s. Interestingly, it doesn't appear that `with_stack` gets events in the backward loop; @albanD, is this expected?
<img width="948" height="385" alt="Screenshot 2025-11-07 at 2 35 50 PM" src="https://github.com/user-attachments/assets/a40bacd0-3fb0-4bd8-b33e-bec8fb3f36c0" />

Getting from the PyTorch op to the impl is about 43 us, which is dwarfed by other CPU overheads.
<img width="1227" height="649" alt="Screenshot 2025-11-07 at 2 37 41 PM" src="https://github.com/user-attachments/assets/51da0615-facd-41e1-a6e2-fb7778079ab6" />

Just invoking the JIT object from the CuTe DSL takes hundreds of microseconds.
<img width="545" height="414" alt="Screenshot 2025-11-07 at 2 38 19 PM" src="https://github.com/user-attachments/assets/d20345a0-6c47-4dcb-892f-9ef9894a1cf5" />

### Example usage
```Py
#!/usr/bin/env python3

"""Minimal FA4 smoke test for scaled dot product attention."""

from __future__ import annotations

import sys
from jsonargparse import CLI

import torch
import torch.nn.functional as F
from torch.nn.attention import (
    install_flash_attention_impl,
    sdpa_kernel,
    SDPBackend,
)

def _map_dtype(kind: str) -> torch.dtype:
    return torch.bfloat16 if kind == "bf16" else torch.float16

# To infinity and beyond: route SDPA's flash backend to FA4 for this process
install_flash_attention_impl("FA4")

# Constrain SDPA to the flash backend so the FA4 path is exercised
@sdpa_kernel([SDPBackend.FLASH_ATTENTION])
def main(
    module_path: str = "flash_attn.cute.interface",
    batch: int = 4,
    seq: int = 81292,
    heads: int = 16,
    head_dim: int = 128,
    device: int = 0,
    dtype: str = "bf16"
    ) -> None:
    if not torch.cuda.is_available():
        sys.exit("CUDA is required for FA4 smoke testing")
    torch.cuda.set_device(device)
    dtype = _map_dtype(dtype)
    generator = torch.Generator(device="cuda").manual_seed(0)
    # Random Q/K/V drawn sequentially from the fixed-seed generator
    q, k, v = (
        torch.randn(
            batch,
            heads,
            seq,
            head_dim,
            device="cuda",
            dtype=dtype,
            requires_grad=True,
            generator=generator,
        )
        for _ in range(3)
    )
    # Profiling helper (context manager) from the transformer_nuggets package
    from transformer_nuggets.utils.benchmark import profiler
    with profiler("sdpa_FA4", with_stack=False):
        for _ in range(3):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
            loss = out.real.sum()
            loss.backward()
    print("Scaled dot product attention output norm:", out.norm().item())
    print("dq norm:", q.grad.norm().item())

if __name__ == "__main__":
    CLI(main)
```
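
Because the entry point is wrapped in jsonargparse's `CLI`, the keyword arguments become command-line flags (e.g. `--batch 2 --seq 4096 --dtype fp16`), which makes it easy to sweep shapes while profiling.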

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167348
Approved by: https://github.com/albanD, https://github.com/malfet