# Summary

See title ;)

## Design

Currently, once you install an impl there is no going back in the same Python process. This need not be the case (cc @mikaylagawarecki's work on being able to grab the original impl); I'll leave that for a follow-up.

Okay, I added an open registration, but I really want the backends to be discoverable, so there is some weird typing. In exchange we get:

<img width="523" height="197" alt="Screenshot 2025-11-07 at 3 30 32 PM" src="https://github.com/user-attachments/assets/586de943-bbed-40cf-abd1-131f747a4cf1" />

## Overheads

<img width="799" height="735" alt="Screenshot 2025-11-07 at 2 35 04 PM" src="https://github.com/user-attachments/assets/f9217f31-3e42-4816-8fb3-29ea8b49d735" />

First call to forward: the majority of the time is spent in the JIT for FA. First call to backward: ~3 sec. Interestingly, it doesn't appear that `with_stack` gets events in the backward loop; @albanD, is this expected?

<img width="948" height="385" alt="Screenshot 2025-11-07 at 2 35 50 PM" src="https://github.com/user-attachments/assets/a40bacd0-3fb0-4bd8-b33e-bec8fb3f36c0" />

Getting from the PyTorch op to the impl is about 43 µs, which is dwarfed by the other CPU overheads.

<img width="1227" height="649" alt="Screenshot 2025-11-07 at 2 37 41 PM" src="https://github.com/user-attachments/assets/51da0615-facd-41e1-a6e2-fb7778079ab6" />

Just invoking the JIT object from the CuTe DSL is hundreds of µs.

<img width="545" height="414" alt="Screenshot 2025-11-07 at 2 38 19 PM" src="https://github.com/user-attachments/assets/d20345a0-6c47-4dcb-892f-9ef9894a1cf5" />

### Example usage

```Py
#!/usr/bin/env python3
"""Minimal FA4 smoke test for scaled dot product attention."""

from __future__ import annotations

import sys

from jsonargparse import CLI

import torch
import torch.nn.functional as F
from torch.nn.attention import (
    install_flash_attention_impl,
    sdpa_kernel,
    SDPBackend,
)


def _map_dtype(kind: str) -> torch.dtype:
    return torch.bfloat16 if kind == "bf16" else torch.float16


# To infinity and beyond
install_flash_attention_impl("FA4")


@sdpa_kernel([SDPBackend.FLASH_ATTENTION])
def main(
    module_path: str = "flash_attn.cute.interface",
    batch: int = 4,
    seq: int = 81292,
    heads: int = 16,
    head_dim: int = 128,
    device: int = 0,
    dtype: str = "bf16",
) -> None:
    if not torch.cuda.is_available():
        sys.exit("CUDA is required for FA4 smoke testing")

    torch.cuda.set_device(device)
    dtype = _map_dtype(dtype)
    generator = torch.Generator(device="cuda").manual_seed(0)
    q = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )
    k = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )
    v = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )

    from transformer_nuggets.utils.benchmark import profiler

    with profiler("sdpa_FA4", with_stack=False):
        for _ in range(3):
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            loss = out.real.sum()
            loss.backward()

    print("Scaled dot product attention output norm:", out.norm().item())
    print("dq norm:", q.grad.norm().item())


if __name__ == "__main__":
    CLI(main)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167348
Approved by: https://github.com/albanD, https://github.com/malfet
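Side note on the example above: since `sdpa_kernel` is an ordinary context manager, the backend pin can also be scoped to individual calls instead of decorating `main`. A minimal sketch, assuming `install_flash_attention_impl("FA4")` has already been called as in the example (shapes here are arbitrary):

```Py
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Arbitrary (batch, heads, seq, head_dim) inputs purely for illustration.
q = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Pin the flash attention backend only for this call.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
```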