Fixes part of #163314
In particular bug: **Bug 1: H=None Broadcasting Produces Incorrect Results**
This fixes a shape bug when slicing a BlockMask on the Q-tile axis with an int (**mask[:, :, i]**). That form of indexing collapses the Q dimension, so kv_num_blocks/kv_indices lose their expected [B, H, Q_tiles, …] shape. Even though the mask_mod remains "interpretable", the kernel's stride math then reads from wrong offsets, producing silent numerical mismatches compared to regular SDPA, especially for single-position decoding with H broadcasting.
The fact that the B=None, H=None case works is accidental: with a singleton batch/head the kernel maps to index 0 via `sparse_idx_z = off_zq % 1` and `sparse_idx_hq = off_hq % 1`, and with a single Q tile `q_start // SPARSE_Q_MULTIPLE = 0`. The missing Q-tiles stride is multiplied by 0, so the bad offset from the collapsed Q axis doesn't move the pointer and the kernel happens to read the first tile correctly. Once H > 1 or there are multiple Q tiles, those terms become nonzero and the kernel indexes with wrong strides, causing silent errors.
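A minimal sketch (not taken from the PR's tests) of the shape difference described above; it assumes `create_block_mask` from `torch.nn.attention.flex_attention` with the default 128 block size:
```Py
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Block-sparse metadata is laid out as [B, H, Q_tiles, ...]; with Q_LEN=2048 and the
# default 128 block size there are 16 Q tiles.
mask = create_block_mask(causal, B=2, H=4, Q_LEN=2048, KV_LEN=2048, device="cpu")
print(mask.kv_num_blocks.shape)             # torch.Size([2, 4, 16])

# Slicing with i:i+1 keeps the Q-tile axis that the kernel's stride math expects ...
print(mask[:, :, 0:1].kv_num_blocks.shape)  # torch.Size([2, 4, 1])
# ... whereas int indexing (mask[:, :, 0]) collapsed it silently before this fix.
```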
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163426
Approved by: https://github.com/drisspg
# Summary
### Update
API
```Py
from typing import NamedTuple, Optional

from torch import Tensor
from torch.nn.attention.flex_attention import flex_attention


class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False


class AuxOutput(NamedTuple):
    """Auxiliary outputs from flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None


# Existing call: no aux outputs requested, a single output tensor is returned.
out_only = flex_attention(query, key, value, score_mod)

# Requesting aux outputs returns an (output, AuxOutput) pair.
out_max, aux_max = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(max_scores=True),
)
out_both, aux_both = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(lse=True, max_scores=True),
)
```
Returns the max post-mod scores from flex attention.
Not being able to break BC is somewhat annoying here, since it creates a combinatorial problem: any additional return value needs a new kwarg gating whether it gets returned, and we end up having to support the 2**N possible return groups.
Ideally there isn't much more we need to return, but we should think about how best to set this up for future expansion. For now I added it as a keyword-only argument.
Maybe we make an `ExtraReturns`-style kwarg that can grow, so we don't need to keep adding new top-level args.
We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.
### Req Grad
I currently don't return a max_scores that supports backpropagating gradients. This might be feasible, but since max is essentially a one-hot reduction over the inputs, we would either need to save an additional `max_location` from the forward pass, or recompute the max score in the backward and route the gradient only to the first occurrence when there are multiple equal scores (need to check whether that matches what we define for the vanilla max op in torch).
For now no grad, we can re-visit if needed.
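As a quick, hedged check of the "vanilla max" behavior mentioned above (gradient routing when scores tie); exact tie-breaking may vary across versions and devices:
```Py
import torch

# max over a dim with a tie: the gradient is routed to the single returned argmax
# location rather than being split across the tied entries.
x = torch.tensor([1.0, 3.0, 3.0], requires_grad=True)
value, index = x.max(dim=0)
value.backward()
print(index.item(), x.grad)  # e.g. 1, tensor([0., 1., 0.])
```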
## Perf
I am going to disable this for flex_decode, since, at least initially, the motivation is training. It is also harder than it should be to have ops return None or optional tensors; when max_scores is not requested, we should probably just create a size-zero tensor so that we don't slow down the hot path.
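A hedged sketch of the size-zero-tensor idea above; `maybe_empty_aux` is a hypothetical helper, not part of the PR:
```Py
import torch

def maybe_empty_aux(requested: bool, shape, dtype, device) -> torch.Tensor:
    # Hypothetical helper: if the aux output is not requested, hand back a zero-sized
    # tensor instead of None so op signatures stay tensor-typed and the hot path is untouched.
    if requested:
        return torch.empty(shape, dtype=dtype, device=device)
    return torch.empty(0, dtype=dtype, device=device)

max_scores = maybe_empty_aux(False, (2, 16, 1024), torch.float32, "cpu")
print(max_scores.numel())  # 0
```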
```Shell
🔝 Top 5 TFlops Deltas (by absolute %):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
🔺 Top 5 Positive TFlops Deltas (highest +%):
shape: (5, 7)
┌────────────────┬────────────────┬────────────────────────┬───────────────┬──────────────┬──────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪════════════════════════╪═══════════════╪══════════════╪══════════╪═══════════╡
│ causal ┆ torch.bfloat16 ┆ (4, 16, 2048, 16, ┆ 249.514658 ┆ 243.078974 ┆ 6.435684 ┆ 2.647569 │
│ ┆ ┆ 2048, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 57.971274 ┆ 56.633641 ┆ 1.337633 ┆ 2.361905 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ noop ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 280.71254 ┆ 275.686991 ┆ 5.025549 ┆ 1.822918 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 16384, 16, ┆ 152.970031 ┆ 150.489109 ┆ 2.480923 ┆ 1.648573 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ causal ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 161.031318 ┆ 158.597808 ┆ 2.43351 ┆ 1.534391 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴────────────────────────┴───────────────┴──────────────┴──────────┴───────────┘
🔻 Top 5 Negative TFlops Deltas (lowest -%):
shape: (5, 7)
┌────────────────┬────────────────┬───────────────────────┬───────────────┬──────────────┬───────────┬───────────┐
│ attn_type ┆ dtype ┆ shape(B,Hq,M,Hkv,N,D) ┆ TFlops (base) ┆ TFlops (max) ┆ delta ┆ pct_delta │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════════════╪════════════════╪═══════════════════════╪═══════════════╪══════════════╪═══════════╪═══════════╡
│ noop ┆ torch.bfloat16 ┆ (4, 16, 1024, 16, ┆ 244.052884 ┆ 248.65129 ┆ -4.598406 ┆ -1.849339 │
│ ┆ ┆ 1024, 64) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 4, ┆ 175.546923 ┆ 177.81205 ┆ -2.265127 ┆ -1.273888 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (4, 16, 16384, 4, ┆ 156.282597 ┆ 158.209134 ┆ -1.926537 ┆ -1.217715 │
│ ┆ ┆ 16384, 64) ┆ ┆ ┆ ┆ │
│ sliding_window ┆ torch.bfloat16 ┆ (2, 16, 2048, 16, ┆ 232.542929 ┆ 235.140136 ┆ -2.597207 ┆ -1.104536 │
│ ┆ ┆ 2048, 128) ┆ ┆ ┆ ┆ │
│ alibi ┆ torch.bfloat16 ┆ (2, 16, 1024, 16, ┆ 169.652791 ┆ 171.475986 ┆ -1.823195 ┆ -1.063236 │
│ ┆ ┆ 1024, 128) ┆ ┆ ┆ ┆ │
└────────────────┴────────────────┴───────────────────────┴───────────────┴──────────────┴───────────┴───────────┘
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
[#RFC153024](https://github.com/pytorch/pytorch/issues/153024)
**Motivation**
1. Attention is the critical performance bottleneck in current LLM models, and FlexAttention is a good way to cover the broad attention variants in the transformers family of models. With FlexAttention, it is easy for us to enable paged attention and fused SDPA in the transformers repo on the XPU device. It also provides a candidate for processing attention in LLM ecosystem libraries such as vLLM and SGLang on the XPU device.
2. FlexAttention is a good starting point for maturing the Intel Triton-based GEMM kernels. FlexAttention provides both a flexattention kernel and a flexdecoding kernel, covering compute-bound and memory-bound GEMM computation, and different shapes also need to be supported to serve LLM inference, e.g. head_dim = 64, 96, 128, 256.
**What does this PR do?**
1. Enable the XPU device type for the FlexAttention kernels and unit tests, ensuring all important UTs pass on the XPU device.
2. For E2E model inference, ensure that LLM inference with FlexAttention is functional on XPU (see the sketch below).
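A minimal, hedged sketch of the usage this enables on XPU; the shapes and causal mask are illustrative and assume an XPU-enabled PyTorch build:
```Py
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# B=1, Hq=Hkv=16, seq_len=1024, head_dim=64, all resident on the XPU device
q, k, v = (torch.randn(1, 16, 1024, 64, device="xpu", dtype=torch.bfloat16) for _ in range(3))
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="xpu")

# torch.compile lowers this through the Triton-based FlexAttention kernels enabled here
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```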
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143553
Approved by: https://github.com/EikanWang, https://github.com/drisspg
Co-authored-by: Mao Yunfei <yunfei.mao@intel.com>
Co-authored-by: Xingyuan Li <xingyuan.li@intel.com>
Co-authored-by: majing <jing1.ma@intel.com>
Co-authored-by: Xiao, Wang <wang.xiao@intel.com>
Fixes #147336
## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
After we brought this to the attention of Triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, which produced the slowdown.
To summarize:
In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.
This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).
I.e., when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the column-major layout. The load is therefore not a candidate for pipelining with cp.async; it just moves data to registers and then performs a series of single-byte stores.
## Fix summary
- To fix this, we enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, so that each tensor exists in HBM in the layout needed to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs (a caller-side sketch of the layout follows).
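A hedged, caller-side sketch of the layout in question (not the PR's internal change): making the last two dims of V column-major so block loads into SRAM stay contiguous.
```Py
import torch

B, H, S, D = 4, 16, 2048, 64
v = torch.randn(B, H, S, D, dtype=torch.bfloat16).to(torch.float8_e4m3fn)  # row-major by default

# Transpose, materialize, transpose back: same logical [B, H, S, D] shape, but the
# sequence dim now has stride 1, i.e. the last two dims are column-major.
v_col_major = v.transpose(-2, -1).contiguous().transpose(-2, -1)
print(v.stride()[-2:], v_col_major.stride()[-2:])  # (64, 1) vs (1, 2048)
```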
## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.
Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```
After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
# Summary
- Adds support for non-power-of-2 head_dim by launching blocks with head_dim rounded up to the next valid power of two (see the sketch below).
- The other option I considered was building up the final dot products with smaller blocks; this would probably work, but for the sake of code complexity I'm going with the rounding approach for now.
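A small, hedged illustration of the rounding referenced above (not the kernel code); `triton.next_power_of_2` is the usual helper for this:
```Py
import triton

# The block is launched with the padded head_dim; the padded tail is masked in the kernel.
for head_dim in (48, 72, 96, 160, 192):
    print(head_dim, "->", triton.next_power_of_2(head_dim))
# 48 -> 64, 72 -> 128, 96 -> 128, 160 -> 256, 192 -> 256
```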
### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up in the qk_head_dim != v_head_dim cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495
Approved by: https://github.com/Chillee
This PR brings FlexAttention inference support to the Inductor backend of torch.compile on CPUs (supported precisions: bf16 and fp32).
Based on the existing CPP template, this PR extends and implements a FlexAttention CPP template to support broad attention variants while delivering optimized performance on CPUs.
With this, users can transparently extend their FlexAttention usage to CPUs through torch.compile, with solid functionality and performance.
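A minimal, hedged sketch of the CPU path this enables; the shapes and the relative-position score_mod are illustrative only:
```Py
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # score_mod receives the raw score plus batch/head/position indices
    return score + (kv_idx - q_idx)

q, k, v = (torch.randn(1, 8, 512, 64, dtype=torch.bfloat16) for _ in range(3))

# On CPU, Inductor lowers this through the FlexAttention CPP template added in this PR
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=rel_bias)
```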
For unit tests, this PR includes a subset of the critical tests for CPUs (inference only), listed below:
```
pytest test/inductor/test_flex_attention.py
`TestFlexAttention`
#common functions:
run_test
preprocess_paged_attention
run_paged_attention
run_test_with_paged_attention
run_test_with_call
run_dynamic_test
run_automatic_dynamic_test
#test functions:
test_builtin_score_mods
test_builtin_score_mods_automatic_dynamic
test_builtin_score_mods_different_seqlen
test_builtin_score_mods_different_block_size
test_kv_batch_broadcast
test_GQA
test_cpu_error_message_return_lse
test_validate_cpu_dtype_error_message
`TestPagedAttention`
#test function:
test_paged_builtin_score_mods
```
For the remaining UTs in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, the size of the changes (1500+ LOC) would make this PR hard to review, so we will submit a separate PR to enable and refactor the CPU device UTs.
More optimizations are also planned in follow-up PRs, including:
- Block sparse computation
- Flash decoding tuning
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141453
Approved by: https://github.com/drisspg, https://github.com/leslie-fang-intel
Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>